import os
from jittor_utils.misc import download_and_extract_archive, check_integrity
from PIL import Image
import sys, pickle
import numpy as np
from jittor.dataset import Dataset, dataset_root
class CIFAR10(Dataset):
'''
`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset.
Args:
- root (str, optional): Root directory of the dataset. It should contain (or will receive, when ``download`` is True) the ``cifar-10-batches-py`` directory; with ``download=True`` the archive is downloaded and extracted automatically. Default: ``dataset_root + '/cifar_data/'``.
- train (bool, optional): If True, build the dataset from the training set; otherwise from the test set. Default: True.
- transform (callable, optional): A function or callable that takes a PIL image and returns a transformed image, e.g. ``transforms.RandomCrop``. Default: None.
- target_transform (callable, optional): A function or callable that takes the target label and returns a transformed label. Default: None.
- download (bool, optional): If True, download the dataset from the internet and save it under ``root``. If the dataset is already downloaded, it is not downloaded again. Default: True.
Attributes:
- root (str): Root directory of the dataset.
- data (numpy.ndarray): Image data in HWC layout, of shape (50000, 32, 32, 3) for the training set or (10000, 32, 32, 3) for the test set.
- targets (list): Labels, of length 50000 or 10000.
- classes (list): The 10 class names.
- class_to_idx (dict): Mapping from class name to index: ({'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}).
Example:
>>> from jittor.dataset.cifar import CIFAR10
>>> dataset = CIFAR10()
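A slightly longer sketch (illustrative values; ``set_attrs`` is assumed to be
the batching helper inherited from jittor's ``Dataset`` base class):
>>> dataset = CIFAR10(train=False).set_attrs(batch_size=16, shuffle=False)
>>> for imgs, labels in dataset:
...     print(imgs.shape)  # roughly (16, 32, 32, 3) when no transform is set
...     break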
'''
base_folder = 'cifar-10-batches-py'
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
['data_batch_4', '634d18415352ddfa80567beed471001a'],
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
]
test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'],
]
meta = {
'filename': 'batches.meta',
'key': 'label_names',
'md5': '5ff9c542aee3614f3951f8cda6e48888',
}
def __init__(self, root=dataset_root+"/cifar_data/", train=True, transform=None, target_transform=None,
download=True):
super(CIFAR10, self).__init__()
self.root = root
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
if self.train:
downloaded_list = self.train_list
else:
downloaded_list = self.test_list
self.data = []
self.targets = []
# now load the pickled numpy arrays
for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, 'rb') as f:
if sys.version_info[0] == 2:
entry = pickle.load(f)
else:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
if 'labels' in entry:
self.targets.extend(entry['labels'])
else:
self.targets.extend(entry['fine_labels'])
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
self._load_meta()
def _load_meta(self):
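# read the class names from the meta file and build the class_to_idx mapping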
path = os.path.join(self.root, self.base_folder, self.meta['filename'])
if not check_integrity(path, self.meta['md5']):
raise RuntimeError('Dataset metadata file not found or corrupted.' +
' You can use download=True to download it')
with open(path, 'rb') as infile:
if sys.version_info[0] == 2:
data = pickle.load(infile)
else:
data = pickle.load(infile, encoding='latin1')
self.classes = data[self.meta['key']]
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
def __getitem__(self, index):
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.data)
def _check_integrity(self):
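# verify the md5 checksums of all train/test batch files under root/base_folder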
root = self.root
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename)
if not check_integrity(fpath, md5):
return False
return True
def download(self):
if self._check_integrity():
print('Files already downloaded and verified')
return
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
def extra_repr(self):
return "Split: {}".format("Train" if self.train is True else "Test")
class CIFAR100(CIFAR10):
'''
`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset, a subclass of ``CIFAR10``.
The constructor is inherited from ``CIFAR10`` and takes the same arguments:
- root (str, optional): Root directory of the dataset. It should contain (or will receive, when ``download`` is True) the ``cifar-100-python`` directory. Default: ``dataset_root + '/cifar_data/'``.
- train (bool, optional): If True, build the dataset from the training set; otherwise from the test set. Default: True.
- transform (callable, optional): A function or callable that takes a PIL image and returns a transformed image. Default: None.
- target_transform (callable, optional): A function or callable that takes the target label and returns a transformed label. Default: None.
- download (bool, optional): If True, download the dataset from the internet and save it under ``root``. If the dataset is already downloaded, it is not downloaded again. Default: True.
Attributes:
- base_folder (str): Name of the directory the CIFAR100 archive extracts to.
- url (str): Download URL of the CIFAR100 dataset.
- filename (str): Filename of the CIFAR100 archive.
- tgz_md5 (str): MD5 checksum of the archive.
- train_list (list of list): Training-set file names with their MD5 checksums.
- test_list (list of list): Test-set file names with their MD5 checksums.
- meta (dict): Metadata dict giving the label-names file, its key ('fine_label_names') and its MD5 checksum.
Example:
>>> from jittor.dataset.cifar import CIFAR100
>>> dataset = CIFAR100(train=True)
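A slightly longer sketch (illustrative values; ``set_attrs`` is assumed to be
the batching helper inherited from jittor's ``Dataset`` base class):
>>> dataset = CIFAR100(train=False).set_attrs(batch_size=32, shuffle=False)
>>> len(dataset.classes)  # 100 fine-label class names read from the meta file
100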
'''
base_folder = 'cifar-100-python'
url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filename = "cifar-100-python.tar.gz"
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
train_list = [
['train', '16019d7e3df5f24257cddd939b257f8d'],
]
test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
]
meta = {
'filename': 'meta',
'key': 'fine_label_names',
'md5': '7973b15100ade9c7d40fb424638fde48',
}