jittor.dataset

This is the API documentation for Jittor's dataset module. You can access the module with from jittor import dataset.

class jittor.dataset.BatchSampler(sampler, batch_size, drop_last)
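The BatchSampler entry lists only its signature. The sketch below assumes the common batch-sampling convention (the one PyTorch's BatchSampler uses): indices from the wrapped sampler are grouped into lists of batch_size, and a trailing incomplete batch is dropped only when drop_last is True. batch_indices is a hypothetical helper, not part of jittor.dataset.

```python
from typing import Iterable, Iterator, List


def batch_indices(sampler: Iterable[int], batch_size: int,
                  drop_last: bool) -> Iterator[List[int]]:
    """Group the indices produced by `sampler` into lists of `batch_size`.

    Hypothetical helper illustrating typical BatchSampler semantics;
    it is not Jittor's own implementation.
    """
    batch: List[int] = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # A final, incomplete batch is kept unless drop_last is set.
    if batch and not drop_last:
        yield batch


print(list(batch_indices(range(10), 4, drop_last=False)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(list(batch_indices(range(10), 4, drop_last=True)))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With 10 indices and batch_size=4 the sketch yields three batches, or two when drop_last=True.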
class jittor.dataset.CIFAR10(root='~/.cache/jittor/dataset/cifar_data/', train=True, transform=None, target_transform=None, download=True)

CIFAR10 Dataset.

Args:
root (string): Root directory of the dataset, where the directory cifar-10-batches-py exists or will be saved to if download is set to True.

train (bool, optional): If True, creates the dataset from the training set; otherwise creates it from the test set.

transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed version, e.g. transforms.RandomCrop.

target_transform (callable, optional): A function/transform that takes in the target and transforms it.

download (bool, optional): If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.

Example:

from jittor.dataset.cifar import CIFAR10
a = CIFAR10()
a.set_attrs(batch_size=16)
for imgs, labels in a:
    print(imgs.shape, labels.shape)
    break
base_folder = 'cifar-10-batches-py'
download()
extra_repr()
filename = 'cifar-10-python.tar.gz'
meta = {'filename': 'batches.meta', 'key': 'label_names', 'md5': '5ff9c542aee3614f3951f8cda6e48888'}
test_list = [['test_batch', '40351d587109b95175f43aff81a1287e']]
tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
train_list = [['data_batch_1', 'c99cafc152244af753f735de768cd75f'], ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], ['data_batch_4', '634d18415352ddfa80567beed471001a'], ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb']]
url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
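Each entry in train_list and test_list pairs a file name with an MD5 checksum, and tgz_md5 does the same for the downloaded archive. The stdlib-only sketch below shows how such [filename, md5] lists can be used to decide whether the extracted files are present and intact. md5_of and files_intact are hypothetical helpers, and this is not necessarily the exact check Jittor performs internally.

```python
import hashlib
import os
from typing import List


def md5_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the hex MD5 digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def files_intact(root: str, base_folder: str, file_list: List[List[str]]) -> bool:
    """True if every [filename, md5] entry exists under root/base_folder
    and its checksum matches (hypothetical helper)."""
    for filename, expected_md5 in file_list:
        path = os.path.join(root, base_folder, filename)
        if not os.path.isfile(path) or md5_of(path) != expected_md5:
            return False
    return True
```

For example, passing base_folder='cifar-10-batches-py' and train_list + test_list would return False until all six batch files are present with matching checksums.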
class jittor.dataset.CIFAR100(root='~/.cache/jittor/dataset/cifar_data/', train=True, transform=None, target_transform=None, download=True)

CIFAR100 Dataset.

This is a subclass of the CIFAR10 Dataset.

Example:

from jittor.dataset.cifar import CIFAR100
a = CIFAR100()
a.set_attrs(batch_size=16)
for imgs, labels in a:
    print(imgs.shape, labels.shape)
    break
base_folder = 'cifar-100-python'
filename = 'cifar-100-python.tar.gz'
meta = {'filename': 'meta', 'key': 'fine_label_names', 'md5': '7973b15100ade9c7d40fb424638fde48'}
test_list = [['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc']]
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
train_list = [['train', '16019d7e3df5f24257cddd939b257f8d']]
url = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
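The meta attribute names the file that holds the class names (filename) and the key to read inside it (key). A sketch of reading it with the standard pickle module, assuming the extracted layout root/base_folder/meta['filename']; load_label_names is a hypothetical helper. The CIFAR python archives were pickled under Python 2, hence the latin1 decoding.

```python
import os
import pickle
from typing import Dict, List


def load_label_names(root: str, base_folder: str, meta: Dict[str, str]) -> List[str]:
    """Read the class names listed in a CIFAR meta file.

    `meta` is shaped like the `meta` class attribute:
    {'filename': ..., 'key': ..., 'md5': ...}.
    """
    path = os.path.join(root, base_folder, meta["filename"])
    with open(path, "rb") as f:
        # latin1 keeps Python 2 pickled strings readable under Python 3.
        data = pickle.load(f, encoding="latin1")
    return list(data[meta["key"]])
```

With CIFAR-100's meta (key 'fine_label_names') this returns the 100 fine-grained class names; with CIFAR-10's meta (key 'label_names') it returns the 10 class names.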
class jittor.dataset.Dataset(batch_size=16, shuffle=False, drop_last=False, num_workers=0, buffer_size=536870912, stop_grad=True, keep_numpy_array=False, endless=False)

Base class for reading data.

Args:

[in] batch_size(int): batch size, default 16.
[in] shuffle(bool): shuffle at each epoch, default False.
[in] drop_last(bool): if True, drop the last incomplete batch; otherwise the last batch may be smaller than batch_size. Default False.
[in] num_workers(int): number of worker processes for loading data, default 0.
[in] buffer_size(int): buffer size for each worker in bytes, default(512MB).
[in] stop_grad(bool): stop gradients for the loaded data, default(True).
[in] keep_numpy_array(bool): return numpy array rather than jittor array, default(False).
[in] endless(bool): if True, the dataset yields data forever, default(False).

Example:

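from jittor.dataset import Dataset

class YourDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.set_attrs(total_len=1024)

    def __getitem__(self, k):
        return k, k*k

dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)
for x, y in dataset:
    # x and y are Jittor arrays holding one batch of up to 256 samples
    print(x.shape, y.shape)
    break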
collate_batch(batch)

Puts each data field into a tensor with outer dimension batch size.

Args:

[in] batch(list): A list of variables, such as jt.Var, Image.Image, np.ndarray, int, float, str and so on.
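To illustrate what "one tensor per data field" means, the pure-Python sketch below only regroups samples field by field into nested lists. Jittor's collate_batch additionally stacks numeric fields into arrays whose outer dimension is the batch size, which this sketch does not do; collate_fields is a hypothetical name.

```python
from typing import Any, List


def collate_fields(batch: List[Any]) -> Any:
    """Regroup a list of samples field by field (hypothetical helper).

    [(x0, y0), (x1, y1), ...] becomes [[x0, x1, ...], [y0, y1, ...]].
    Nested tuples/lists are regrouped recursively; any other value
    (int, float, str, array, ...) is a leaf and is collected into a list.
    """
    if not batch:
        return []
    if isinstance(batch[0], (tuple, list)):
        # zip(*batch) transposes a list of samples into a list of fields.
        return [collate_fields(list(field)) for field in zip(*batch)]
    return list(batch)


# Samples shaped like the (k, k*k) pairs from the Dataset example.
print(collate_fields([(0, 0), (1, 1), (2, 4)]))
# [[0, 1, 2], [0, 1, 4]]
```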
display_worker_status()

Displays the status of the dataset workers. When dataset.num_workers > 0, it prints information like the following:

progress:479/5005
batch(s): 0.302 wait(s):0.000
recv(s): 0.069  to_jittor(s):0.021
recv_raw_call: 6720.0
last 10 workers: [6, 7, 3, 0, 2, 4, 7, 5, 6, 1]
ID      wait(s) load(s) send(s) total
#0      0.000   1.340   2.026   3.366   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
#1      0.000   1.451   3.607   5.058   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
#2      0.000   1.278   1.235   2.513   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
#3      0.000   1.426   1.927   3.353   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
#4      0.000   1.452   1.074   2.526   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
#5      0.000   1.422   3.204   4.625   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
#6      0.000   1.445   1.953   3.398   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
#7      0.000   1.582   0.507   2.090   Buffer(free=0.000% l=308283552 r=308283552 size=536870912)

Meaning of the outputs:

  • progress: dataset loading progress (current/total)

  • batch: batch time, excluding data loading time

  • wait: time the main process spends waiting for worker processes

  • recv: time spent receiving batch data

  • to_jittor: time spent converting batch data to Jittor variables

  • recv_raw_call: total number of calls to the underlying recv_raw

  • last 10 workers: IDs of the last 10 workers the main process loaded data from

  • table meaning
    • ID: worker id

    • wait: worker wait time

    • open: worker image open time

    • load: worker load time

    • send: worker send time

    • total: total worker time, i.e. wait + load + send (up to display rounding)

    • buffer: ring buffer status: free space percentage (free), left index (l), right index (r) and total size in bytes (size).

Example:

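from jittor.dataset import Dataset

class YourDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.set_attrs(total_len=1024)
    def __getitem__(self, k):
        return k, k*k

dataset = YourDataset().set_attrs(num_workers=8)
for x, y in dataset:
    dataset.display_worker_status()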
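The worker table is plain text, so it can be post-processed if you capture it. The sketch below is a hypothetical parser for one worker row, assuming the column layout of the sample output above (ID, wait, load, send, total, then a Buffer(...) summary); the layout may differ between Jittor versions. In the sample, total matches wait + load + send to within the 3-decimal display rounding, which the usage line checks.

```python
import re
from typing import Dict, Union


def parse_worker_row(line: str) -> Dict[str, Union[int, float]]:
    """Parse one worker row of the display_worker_status() table.

    Assumes '#<id> wait load send total' followed by
    'Buffer(free=<pct>% l=<int> r=<int> size=<int>)'.
    """
    parts = line.split()
    row: Dict[str, Union[int, float]] = {
        "id": int(parts[0].lstrip("#")),
        "wait": float(parts[1]),
        "load": float(parts[2]),
        "send": float(parts[3]),
        "total": float(parts[4]),
    }
    m = re.search(r"free=([\d.]+)% l=(\d+) r=(\d+) size=(\d+)", line)
    if m:
        row["free_pct"] = float(m.group(1))
        row["l"], row["r"], row["size"] = (int(g) for g in m.groups()[1:])
    return row


sample = ("#0      0.000   1.340   2.026   3.366   "
          "Buffer(free=0.000% l=462425368 r=462425368 size=536870912)")
row = parse_worker_row(sample)
# Values are printed with 3 decimals, so allow a small rounding tolerance.
print(abs(row["wait"] + row["load"] + row["send"] - row["total"]) < 2e-3)  # True
```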
reset()
set_attrs(**kw)

Sets attributes of the dataset, including total_len, batch_size, shuffle, drop_last, num_workers, buffer_size and stop_grad. It returns the dataset itself, so the call can be chained after the constructor.

Example:

dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)

Attrs:

  • batch_size(int): batch size, default 16.

  • total_len(int): total length, i.e. the number of samples.

  • shuffle(bool): shuffle at each epoch, default False.

  • drop_last(bool): if True, drop the last incomplete batch; otherwise the last batch may be smaller than batch_size. Default False.

  • num_workers(int): number of worker processes for loading data, default 0.

  • buffer_size: buffer size for each worker in bytes, default(512MB).

  • stop_grad: stop grad for data, default(True).
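The chaining style used throughout this page (YourDataset().set_attrs(...)) works because set_attrs returns the dataset itself. A pure-Python sketch of that pattern with a hypothetical AttrDataset class; it is not jittor.dataset.Dataset, and rejecting unknown attribute names is a choice made in this sketch.

```python
from typing import Any, Optional


class AttrDataset:
    """Hypothetical stand-in illustrating the set_attrs chaining pattern."""

    def __init__(self) -> None:
        # total_len has no default in the Dataset signature; None is a placeholder.
        self.total_len: Optional[int] = None
        # The remaining defaults follow the jittor.dataset.Dataset signature.
        self.batch_size = 16
        self.shuffle = False
        self.drop_last = False
        self.num_workers = 0

    def set_attrs(self, **kw: Any) -> "AttrDataset":
        for key, value in kw.items():
            # Rejecting unknown names catches typos such as batchsize=8.
            if not hasattr(self, key):
                raise AttributeError(f"unknown dataset attribute: {key}")
            setattr(self, key, value)
        return self  # returning self is what allows chaining


dataset = AttrDataset().set_attrs(batch_size=256, shuffle=True)
print(dataset.batch_size, dataset.shuffle)  # 256 True
```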

terminate()

Terminates the worker processes used for multi-process data loading.

to_jittor(batch)

Converts batch data, such as np.ndarray, int and float values, to Jittor arrays.

class jittor.dataset.ImageFolder(root, transform=None)

An image classification dataset that loads images and labels from a directory laid out as:

* root/label1/img1.png
* root/label1/img2.png
* ...
* root/label2/img1.png
* root/label2/img2.png
* ...

Args:

[in] root(string): Root directory path.
[in] transform(callable, optional): transform applied to each loaded image, default None.

Attributes:

* classes(list): List of the class names.
* class_to_idx(dict): map from class_name to class_index.
* imgs(list): List of (image_path, class_index) tuples
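A stdlib sketch of how classes, class_to_idx and imgs can be derived from the root/label/image layout shown above. scan_image_folder is a hypothetical helper; it assumes class names are the sorted subdirectory names and treats every regular file in a class directory as an image, whereas Jittor's ImageFolder may filter by file extension.

```python
import os
from typing import Dict, List, Tuple


def scan_image_folder(root: str) -> Tuple[List[str], Dict[str, int], List[Tuple[str, int]]]:
    """Build (classes, class_to_idx, imgs) from root/<label>/<image> files."""
    # Each subdirectory of root is one class; sorting makes indices stable.
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    imgs: List[Tuple[str, int]] = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            path = os.path.join(class_dir, fname)
            if os.path.isfile(path):
                imgs.append((path, class_to_idx[name]))
    return classes, class_to_idx, imgs
```

For a tree with root/cat/a.png, root/cat/b.png and root/dog/c.png, this gives classes ['cat', 'dog'] and labels 0, 0, 1.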

Example:

from jittor.dataset import ImageFolder

batch_size = 64  # example value
train_dir = './data/celebA_train'
train_loader = ImageFolder(train_dir).set_attrs(batch_size=batch_size, shuffle=True)
for batch_idx, (x_, target) in enumerate(train_loader):
    ...
class jittor.dataset.MNIST(data_root='~/.cache/jittor/dataset/mnist_data/', train=True, download=True, batch_size=16, shuffle=False, transform=None)

Jittor's built-in class for loading the MNIST dataset.

Args:

[in] data_root(str): your data root.
[in] train(bool): if True, load the training set; otherwise load the validation (test) set.
[in] download(bool): if True, download the data automatically.
[in] batch_size(int): data batch size.
[in] shuffle(bool): if True, shuffle the data.
[in] transform(jittor.transform): transform applied to the data.

Example:

from jittor.dataset.mnist import MNIST
train_loader = MNIST(train=True).set_attrs(batch_size=16, shuffle=True)
for i, (imgs, target) in enumerate(train_loader):
    ...
download_url()

Downloads the MNIST dataset. This function is called when download is True.

class jittor.dataset.RandomSampler(dataset, replacement=False, num_samples=None)
property num_samples
class jittor.dataset.Sampler(dataset)
class jittor.dataset.SequentialSampler(dataset)
class jittor.dataset.SubsetRandomSampler(dataset, indice)
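Only signatures are listed for the samplers. Assuming they follow the usual conventions of the identically named PyTorch samplers, SequentialSampler yields indices in order, and RandomSampler yields a random permutation, or num_samples independent draws when replacement=True. The pure-Python sketch below illustrates those index streams; the function names are hypothetical, and in this sketch num_samples only affects sampling with replacement.

```python
import random
from typing import List, Optional


def sequential_indices(n: int) -> List[int]:
    """Indices 0..n-1 in order, like a sequential sampler."""
    return list(range(n))


def random_indices(n: int, replacement: bool = False,
                   num_samples: Optional[int] = None,
                   rng: Optional[random.Random] = None) -> List[int]:
    """Random indices over a dataset of length n.

    Without replacement: a permutation of range(n).
    With replacement: num_samples (default n) independent uniform draws.
    """
    rng = rng or random.Random()
    if replacement:
        k = n if num_samples is None else num_samples
        return [rng.randrange(n) for _ in range(k)]
    indices = list(range(n))
    rng.shuffle(indices)
    return indices
```

Passing an explicit random.Random(seed) as rng makes the index order reproducible, which is convenient when debugging shuffled loading.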
jittor.dataset.TensorDataset

alias of jittor.dataset.dataset.VarDataset

class jittor.dataset.VOC(data_root='~/.cache/jittor/dataset/voc/', split='train')

Jittor's built-in class for loading the VOC dataset.

Args:

[in] data_root(str): your data root.
[in] split(str): which split you want to use, train or val.

Attribute:

NUM_CLASSES: Number of total categories, default is 21.

Example:

from jittor.dataset.voc import VOC
train_loader = VOC(data_root='...').set_attrs(batch_size=16, shuffle=True)
for i, (imgs, target) in enumerate(train_loader):
    ...
NUM_CLASSES = 21
class jittor.dataset.VarDataset(*args)

A dataset that uses Vars directly. TensorDataset is an alias of VarDataset.

Example:

import jittor as jt
from jittor.dataset import VarDataset

x = jt.array([1,2,3])
y = jt.array([4,5,6])
z = jt.array([7,8,9])
dataset = VarDataset(x, y, z)
dataset.set_attrs(batch_size=1)

for a,b,c in dataset:
    print(a,b,c)

# will print
# 1,4,7
# 2,5,8
# 3,6,9
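Conceptually, VarDataset treats its arguments as parallel columns: sample k is (x[k], y[k], z[k]), which is why the example yields 1,4,7 then 2,5,8 then 3,6,9. A pure-Python analogue with a hypothetical class name, ParallelSequences; the equal-length check is this sketch's own choice.

```python
from typing import Any, Sequence, Tuple


class ParallelSequences:
    """Hypothetical pure-Python analogue of VarDataset: item k is the
    tuple of the k-th element of every input sequence."""

    def __init__(self, *seqs: Sequence[Any]) -> None:
        lengths = {len(s) for s in seqs}
        if len(lengths) != 1:
            raise ValueError("all inputs must have the same, non-zero count")
        self.seqs = seqs

    def __len__(self) -> int:
        return len(self.seqs[0])

    def __getitem__(self, k: int) -> Tuple[Any, ...]:
        return tuple(s[k] for s in self.seqs)


ds = ParallelSequences([1, 2, 3], [4, 5, 6], [7, 8, 9])
print([ds[k] for k in range(len(ds))])
# [(1, 4, 7), (2, 5, 8), (3, 6, 9)]
```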

collate_batch(batch)

Puts each data field into a tensor with outer dimension batch size.

Args:

[in] batch(list): A list of variables, such as jt.Var, Image.Image, np.ndarray, int, float, str and so on.