Source code for jittor.dataset.cifar


import os
from jittor_utils.misc import download_and_extract_archive, check_integrity
from PIL import Image
import sys, pickle
import numpy as np
from jittor.dataset import Dataset, dataset_root

class CIFAR10(Dataset):
    ''' `CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset.

    Args:

    - root (str, optional): Root directory of the dataset, which contains (or, when
      ``download`` is True, will contain) the ``cifar-10-batches-py`` directory.
      Default: ``dataset_root + '/cifar_data/'``.
    - train (bool, optional): If True, build the dataset object from the training set;
      if False, from the test set. Default: True.
    - transform (callable, optional): A function or callable object that takes a PIL
      image and returns a transformed image, e.g. ``transforms.RandomCrop``.
      Default: None.
    - target_transform (callable, optional): A function or callable object that takes
      a target label and returns a transformed label. Default: None.
    - download (bool, optional): If True, download the dataset from the internet and
      save it under ``root``. If the dataset is already downloaded, it is not
      downloaded again. Default: True.

    Attributes:

    - root (str): Root directory of the dataset.
    - data (numpy.ndarray): Image data in HWC layout, of shape (50000, 32, 32, 3)
      for the training set or (10000, 32, 32, 3) for the test set.
    - targets (list): Label list of length 50000 or 10000.
    - classes (list): List of the 10 class names.
    - class_to_idx (dict): Mapping from class name to index:
      ({'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5,
      'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}).

    Example:

    >>> from jittor.dataset.cifar import CIFAR10
    >>> dataset = CIFAR10()
    '''
    base_folder = 'cifar-10-batches-py'
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]
    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(self, root=dataset_root+"/cifar_data/",
                 train=True, transform=None, target_transform=None,
                 download=True):
        super(CIFAR10, self).__init__()
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        if self.train:
            downloaded_list = self.train_list
        else:
            downloaded_list = self.test_list

        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC

        self._load_meta()

    def _load_meta(self):
        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
        if not check_integrity(path, self.meta['md5']):
            raise RuntimeError('Dataset metadata file not found or corrupted.' +
                               ' You can use download=True to download it')
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                data = pickle.load(infile)
            else:
                data = pickle.load(infile, encoding='latin1')
            self.classes = data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data)

    def _check_integrity(self):
        root = self.root
        for fentry in (self.train_list + self.test_list):
            filename, md5 = fentry[0], fentry[1]
            fpath = os.path.join(root, self.base_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        if self._check_integrity():
            print('Files already downloaded and verified')
            return
        download_and_extract_archive(self.url, self.root,
                                     filename=self.filename, md5=self.tgz_md5)

    def extra_repr(self):
        return "Split: {}".format("Train" if self.train is True else "Test")
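For reference, a minimal usage sketch (not part of the module source): it loads the training split with a plain callable as ``transform`` and iterates over it in batches. The ``set_attrs(batch_size=..., shuffle=...)`` call is assumed from Jittor's ``Dataset`` base class, as used throughout Jittor's examples; the transform is an ordinary function, so no dedicated transform library is required.

    import numpy as np
    from jittor.dataset.cifar import CIFAR10

    def to_float_chw(img):
        # PIL image (HWC, uint8) -> float32 CHW array scaled to [0, 1]
        arr = np.asarray(img, dtype='float32') / 255.0
        return arr.transpose(2, 0, 1)

    train_set = CIFAR10(train=True, transform=to_float_chw)
    # batching/shuffling hooks assumed from jittor.dataset.Dataset
    train_set.set_attrs(batch_size=64, shuffle=True)

    for images, labels in train_set:
        # images: (64, 3, 32, 32) float32 batch; labels: (64,) int batch
        break

Because ``transform`` receives the PIL image produced in ``__getitem__``, any callable that accepts a PIL image works here.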
class CIFAR100(CIFAR10):
    ''' `CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset, a subclass
    of ``CIFAR10``. It accepts the same constructor arguments as ``CIFAR10`` and only
    overrides the dataset metadata below.

    Args:

    - root (str, optional): Root directory of the dataset.
      Default: ``dataset_root + '/cifar_data/'``.
    - train (bool, optional): If True, build the dataset object from the training set;
      if False, from the test set. Default: True.
    - transform (callable, optional): A function or callable object that takes a PIL
      image and returns a transformed image. Default: None.
    - target_transform (callable, optional): A function or callable object that takes
      a target label and returns a transformed label. Default: None.
    - download (bool, optional): If True, download the dataset from the internet and
      save it under ``root``. Default: True.

    Attributes:

    - base_folder (str): Name of the directory the CIFAR100 archive extracts to.
    - url (str): Download URL of the CIFAR100 archive.
    - filename (str): File name of the CIFAR100 archive.
    - tgz_md5 (str): MD5 checksum of the CIFAR100 archive.
    - train_list (list of list): File names of the training set files and their MD5 checksums.
    - test_list (list of list): File names of the test set files and their MD5 checksums.
    - meta (dict): Metadata dict describing the label file and its key ('fine_label_names').

    Example:

    >>> from jittor.dataset.cifar import CIFAR100
    >>> dataset = CIFAR100(train=True)
    '''
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]
    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }
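And a short companion sketch (again not part of the module source) showing the inherited constructor and the attributes documented above, on the CIFAR100 test split. It assumes the archive is already present or can be downloaded, and uses only methods defined in this listing.

    from jittor.dataset.cifar import CIFAR100

    test_set = CIFAR100(train=False)     # download=True by default
    print(test_set.data.shape)           # (10000, 32, 32, 3), HWC uint8
    print(len(test_set.classes))         # 100 fine label names from the 'meta' file
    img, target = test_set[0]            # a PIL.Image and an int label index
    print(test_set.classes[target])      # the fine label name for sample 0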