jittor.dataset
This is the API documentation for Jittor's dataset module. You can access the module with from jittor import dataset.
- class jittor.dataset.CIFAR10(root='/home/zwy/.cache/jittor/dataset/cifar_data/', train=True, transform=None, target_transform=None, download=True)
CIFAR10 Dataset.
Args:
- root (string): Root directory of the dataset, where the directory cifar-10-batches-py exists or will be saved if download is set to True.
- train (bool, optional): If True, creates the dataset from the training set; otherwise creates it from the test set.
- transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version, e.g., transforms.RandomCrop.
- target_transform (callable, optional): A function/transform that takes in the target and transforms it.
- download (bool, optional): If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
Example:
    from jittor.dataset.cifar import CIFAR10

    a = CIFAR10()
    a.set_attrs(batch_size=16)
    for imgs, labels in a:
        print(imgs.shape, labels.shape)
        break
- base_folder = 'cifar-10-batches-py'
- filename = 'cifar-10-python.tar.gz'
- meta = {'filename': 'batches.meta', 'key': 'label_names', 'md5': '5ff9c542aee3614f3951f8cda6e48888'}
- test_list = [['test_batch', '40351d587109b95175f43aff81a1287e']]
- tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
- train_list = [['data_batch_1', 'c99cafc152244af753f735de768cd75f'], ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], ['data_batch_4', '634d18415352ddfa80567beed471001a'], ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb']]
- url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
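The md5 values listed in these attributes suggest that downloaded files are checked against known checksums. The following is a minimal sketch of such a check using only the Python standard library. The helper names file_md5 and matches_md5 are hypothetical and are not part of Jittor; how Jittor itself performs the check is not stated here.

```python
import hashlib


def file_md5(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of the file at `path`, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel stops when read() returns an empty bytes object.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_md5(path, expected_md5):
    """Return True if the file's MD5 digest equals the expected hex string."""
    return file_md5(path) == expected_md5
```

Under these assumptions, a downloaded archive could be checked with matches_md5(archive_path, 'c58f30108f718f92721af3b95e74349a'), where archive_path points to cifar-10-python.tar.gz.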
- class jittor.dataset.CIFAR100(root='/home/zwy/.cache/jittor/dataset/cifar_data/', train=True, transform=None, target_transform=None, download=True)
CIFAR100 Dataset.
This is a subclass of the CIFAR10 Dataset.
Example:
    from jittor.dataset.cifar import CIFAR100

    a = CIFAR100()
    a.set_attrs(batch_size=16)
    for imgs, labels in a:
        print(imgs.shape, labels.shape)
        break
- base_folder = 'cifar-100-python'
- filename = 'cifar-100-python.tar.gz'
- meta = {'filename': 'meta', 'key': 'fine_label_names', 'md5': '7973b15100ade9c7d40fb424638fde48'}
- test_list = [['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc']]
- tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
- train_list = [['train', '16019d7e3df5f24257cddd939b257f8d']]
- url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
- jittor.dataset.DataLoader(dataset: jittor.dataset.dataset.Dataset, *args, **kargs)
Simple dataloader.
Example:
    import jittor as jt
    from jittor.dataset import ImageFolder

    train_dir = './data/celebA_train'
    train_dataset = ImageFolder(train_dir)
    dataloader = jt.dataset.DataLoader(train_dataset, batch_size=8)
- class jittor.dataset.Dataset(batch_size=16, shuffle=False, drop_last=False, num_workers=0, buffer_size=536870912, stop_grad=True, keep_numpy_array=False, endless=False)
Base class for reading data.
Args:
    [in] batch_size(int): batch size, default 16.
    [in] shuffle(bool): shuffle at each epoch, default False.
    [in] drop_last(bool): if True, the last incomplete batch is dropped; if False, the last batch may be smaller than batch_size. Default False, per the signature above.
    [in] num_workers(int): number of workers for loading data.
    [in] buffer_size(int): buffer size for each worker in bytes, default 536870912 (512MB).
    [in] keep_numpy_array(bool): return numpy arrays rather than jittor arrays, default False.
    [in] endless(bool): whether this dataset yields data forever, default False.
Example:
    from jittor.dataset import Dataset

    class YourDataset(Dataset):
        def __init__(self):
            super().__init__()
            self.set_attrs(total_len=1024)

        def __getitem__(self, k):
            return k, k*k

    dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)
    for x, y in dataset:
        ...
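To make the interaction between total_len, batch_size, and drop_last concrete, here is a framework-free sketch of the batch arithmetic. It assumes drop_last follows the usual convention of discarding a final incomplete batch; the function names num_batches and batch_sizes are hypothetical and do not come from Jittor.

```python
def num_batches(total_len, batch_size, drop_last=False):
    """Number of batches produced in one pass over total_len samples."""
    if total_len < 0 or batch_size <= 0:
        raise ValueError("total_len must be >= 0 and batch_size must be > 0")
    if drop_last:
        # Only full batches are kept.
        return total_len // batch_size
    # Ceiling division: a trailing partial batch counts as one batch.
    return (total_len + batch_size - 1) // batch_size


def batch_sizes(total_len, batch_size, drop_last=False):
    """List the size of every batch in one pass."""
    full, rest = divmod(total_len, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_last:
        sizes.append(rest)
    return sizes
```

With total_len=1024 and batch_size=256, as in the example above, both settings give four full batches. With total_len=1000, the final batch holds 232 samples unless drop_last is set, in which case only three batches remain.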
- collate_batch(batch)
Puts each data field into a tensor whose outer dimension is the batch size.
Args:
    [in] batch(list): A list of variables, such as jt.Var, Image.Image, np.ndarray, int, float, str, and so on.
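The idea of grouping a batch field by field can be illustrated without any framework. The sketch below is an assumption-laden simplification: it returns plain Python lists rather than tensors, and the name collate_fields is hypothetical. It only shows how a list of (x, y) samples becomes one sequence of all x values and one of all y values.

```python
def collate_fields(batch):
    """Group a list of samples field by field.

    Each sample is a tuple or list with the same number of fields.
    Returns a tuple with one list per field, each of length len(batch).
    Samples that are not tuples or lists are returned as a single list.
    """
    if not batch:
        return ()
    first = batch[0]
    if not isinstance(first, (tuple, list)):
        # Single-field samples: the batch is already one field.
        return list(batch)
    width = len(first)
    if any(len(sample) != width for sample in batch):
        raise ValueError("all samples must have the same number of fields")
    # zip(*batch) transposes samples-by-fields into fields-by-samples.
    return tuple(list(field) for field in zip(*batch))
```

For samples produced by a __getitem__ that returns (k, k*k), a batch of three samples yields two fields of length three. A real implementation would additionally convert each field into an array or tensor.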
- display_worker_status()
Displays the dataset worker status. When dataset.num_workers > 0, it displays the information below:
    progress:479/5005
    batch(s): 0.302 wait(s):0.000
    recv(s): 0.069 to_jittor(s):0.021
    recv_raw_call: 6720.0
    last 10 workers: [6, 7, 3, 0, 2, 4, 7, 5, 6, 1]
    ID  wait(s)  load(s)  send(s)  total
    #0  0.000    1.340    2.026    3.366  Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
    #1  0.000    1.451    3.607    5.058  Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
    #2  0.000    1.278    1.235    2.513  Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
    #3  0.000    1.426    1.927    3.353  Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
    #4  0.000    1.452    1.074    2.526  Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
    #5  0.000    1.422    3.204    4.625  Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
    #6  0.000    1.445    1.953    3.398  Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
    #7  0.000    1.582    0.507    2.090  Buffer(free=0.000% l=308283552 r=308283552 size=536870912)
Meaning of the outputs:
    progress: dataset loading progress (current/total)
    batch: batch time, excluding data loading time
    wait: time the main process spends waiting for worker processes
    recv: time spent receiving batch data
    to_jittor: time spent converting batch data to jittor variables
    recv_raw_call: total number of calls to the underlying recv_raw
    last 10 workers: IDs of the last 10 workers the main process loaded from
Meaning of the table:
    ID: worker id
    wait: worker wait time
    open: worker image open time
    load: worker load time
    buffer: ring buffer status, such as the amount of free space, left index, right index, and total size in bytes
Example:
    from jittor.dataset import Dataset

    class YourDataset(Dataset):
        pass

    dataset = YourDataset().set_attrs(num_workers=8)
    for x, y in dataset:
        dataset.display_worker_status()
- set_attrs(**kw)
Sets attributes of the dataset, including total_len, batch_size, shuffle, drop_last, num_workers, and buffer_size.
Example:
    dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)
Attrs:
    batch_size(int): batch size, default 16.
    total_len(int): total length.
    shuffle(bool): shuffle at each epoch, default False.
    drop_last(bool): if True, the last incomplete batch is dropped; if False, the last batch may be smaller than batch_size. Default False, per the Dataset signature.
    num_workers: number of workers for loading data.
    buffer_size: buffer size for each worker in bytes, default 536870912 (512MB).
    stop_grad: stop grad for data, default True.
- class jittor.dataset.ImageFolder(root, transform=None)
An image classification dataset that loads images and labels from a directory laid out as:
    * root/label1/img1.png
    * root/label1/img2.png
    * ...
    * root/label2/img1.png
    * root/label2/img2.png
    * ...
Args:
    [in] root(string): Root directory path.
Attributes:
    * classes(list): List of the class names.
    * class_to_idx(dict): Map from class_name to class_index.
    * imgs(list): List of (image_path, class_index) tuples.
Example:
    from jittor.dataset import ImageFolder

    batch_size = 16  # example value
    train_dir = './data/celebA_train'
    train_loader = ImageFolder(train_dir).set_attrs(batch_size=batch_size, shuffle=True)
    for batch_idx, (x_, target) in enumerate(train_loader):
        ...
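The directory layout and the three attributes described above can be sketched with the standard library alone. This is an illustration under stated assumptions, not Jittor's implementation: the function name scan_image_folder, the accepted extensions, and the alphabetical ordering of classes are all choices made here for the example.

```python
import os

# Assumed set of image extensions for this sketch.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def scan_image_folder(root):
    """Build (classes, class_to_idx, imgs) from a root/label/image layout."""
    # Every subdirectory of root is treated as one class, sorted by name.
    classes = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )
    class_to_idx = {name: idx for idx, name in enumerate(classes)}

    imgs = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            # Skip files that do not look like images.
            if fname.lower().endswith(IMAGE_EXTENSIONS):
                imgs.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return classes, class_to_idx, imgs
```

Given a root directory containing label1/ and label2/, the sketch maps label1 to index 0 and label2 to index 1, and pairs every image path with the index of its parent directory.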
- class jittor.dataset.MNIST(data_root='/home/zwy/.cache/jittor/dataset/mnist_data/', train=True, download=True, batch_size=16, shuffle=False, transform=None)
Jittor's own class for loading the MNIST dataset.
Args:
    [in] data_root(str): your data root.
    [in] train(bool): whether to load the training set (True) or the validation set (False).
    [in] download(bool): download the data automatically if download is True.
    [in] batch_size(int): data batch size.
    [in] shuffle(bool): shuffle the data if True.
    [in] transform(jittor.transform): transform applied to the data.
Example:
    from jittor.dataset.mnist import MNIST

    train_loader = MNIST(train=True).set_attrs(batch_size=16, shuffle=True)
    for i, (imgs, target) in enumerate(train_loader):
        ...
- class jittor.dataset.RandomSampler(dataset, replacement=False, num_samples=None)
- property num_samples
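No description is given for RandomSampler here, so the following is only a guess at what the parameter names replacement and num_samples conventionally mean for a random index sampler. It is a standard-library sketch; the function random_sample_indices and its rng argument are hypothetical, and Jittor's actual behavior may differ.

```python
import random


def random_sample_indices(dataset_len, replacement=False, num_samples=None, rng=None):
    """Return a list of randomly chosen indices in range(dataset_len).

    Assumed semantics: without replacement every index appears at most once;
    with replacement indices are drawn independently; num_samples defaults
    to dataset_len.
    """
    rng = rng or random.Random()
    if num_samples is None:
        num_samples = dataset_len
    if replacement:
        return [rng.randrange(dataset_len) for _ in range(num_samples)]
    if num_samples > dataset_len:
        raise ValueError("cannot draw more samples than items without replacement")
    return rng.sample(range(dataset_len), num_samples)
```

Under these assumptions, the default call yields a random permutation of all indices, which is the same effect as shuffling a dataset once per epoch.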
- jittor.dataset.TensorDataset
alias of jittor.dataset.dataset.VarDataset
- class jittor.dataset.VOC(data_root='/home/zwy/.cache/jittor/dataset/voc/', split='train')
Jittor's own class for loading the VOC dataset.
Args:
    [in] data_root(str): your data root.
    [in] split(str): which split to use, train or val.
Attribute:
    NUM_CLASSES: total number of categories, default is 21.
Example:
    from jittor.dataset.voc import VOC

    train_loader = VOC(data_root='...').set_attrs(batch_size=16, shuffle=True)
    for i, (imgs, target) in enumerate(train_loader):
        ...
- NUM_CLASSES = 21
- class jittor.dataset.VarDataset(*args)
Dataset that uses Var directly. TensorDataset is an alias of VarDataset.
Example:
    import jittor as jt
    from jittor.dataset import VarDataset

    x = jt.array([1,2,3])
    y = jt.array([4,5,6])
    z = jt.array([7,8,9])
    dataset = VarDataset(x, y, z)
    dataset.set_attrs(batch_size=1)

    for a,b,c in dataset:
        print(a,b,c)
        # will print
        # 1,4,7
        # 2,5,8
        # 3,6,9