jittor.dataset

This is the API documentation for Jittor's dataset module. You can access the module with from jittor import dataset.

class jittor.dataset.Dataset(batch_size=16, shuffle=False, drop_last=False, num_workers=0, buffer_size=536870912, stop_grad=True)

Base class for reading data.

Args:

[in] batch_size(int): batch size, default 16.
[in] shuffle(bool): shuffle at each epoch, default False.
[in] drop_last(bool): if True, the last batch is dropped when it is smaller than batch_size; if False, the last batch may be smaller than batch_size. Default False.
[in] num_workers(int): number of workers for loading data, default 0.
[in] buffer_size(int): buffer size for each worker in bytes, default 536870912 (512 MB).
[in] stop_grad(bool): stop grad for data, default True.
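
The interaction between batch_size and drop_last can be checked with a little arithmetic. The following is a minimal sketch in plain Python that does not call Jittor: num_batches is a hypothetical helper, and it assumes the conventional meaning of drop_last (True discards a final partial batch).

```python
def num_batches(total_len, batch_size=16, drop_last=False):
    """Number of batches per epoch for a dataset of total_len samples.

    Hypothetical helper for illustration; not part of jittor.dataset.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    full, remainder = divmod(total_len, batch_size)
    if drop_last or remainder == 0:
        # Either every batch is full, or the partial batch is discarded
        return full
    # Keep the final, smaller batch
    return full + 1


print(num_batches(1024, batch_size=256))                  # 4 full batches
print(num_batches(1000, batch_size=256))                  # 3 full batches plus one of 232
print(num_batches(1000, batch_size=256, drop_last=True))  # 3, partial batch discarded

# The default buffer_size is 512 MB expressed in bytes
print(512 * 1024 * 1024)  # 536870912
```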

Example:

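from jittor.dataset import Dataset

class YourDataset(Dataset):
    def __init__(self):
        super().__init__()
        # total_len tells the base class how many samples exist
        self.set_attrs(total_len=1024)

    def __getitem__(self, k):
        # Return the k-th sample as an (input, target) pair
        return k, k*k

dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)
for x, y in dataset:
    ...  # x and y each hold one batch

collate_batch(batch)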

Puts each data field into a tensor with outer dimension batch size.

Args:

[in] batch(list): A list of samples of types such as jt.Var, Image.Image, np.ndarray, int, float, str and so on.
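
The idea of grouping each data field across a batch can be sketched without Jittor. The function below, collate_fields, is a hypothetical stand-in: it only transposes per-sample tuples into per-field lists, whereas the real collate_batch goes further and puts each field into a tensor.

```python
def collate_fields(batch):
    """Group a list of per-sample tuples into one list per field.

    Hypothetical illustration; not Jittor's collate_batch implementation.
    """
    if not batch:
        raise ValueError("batch must contain at least one sample")
    first = batch[0]
    if isinstance(first, (tuple, list)):
        # Transpose: [(x0, y0), (x1, y1)] -> ([x0, x1], [y0, y1])
        return tuple(list(field) for field in zip(*batch))
    # Single-field samples are returned as one list
    return list(batch)


samples = [(k, k * k) for k in range(4)]
xs, ys = collate_fields(samples)
print(xs)  # [0, 1, 2, 3]
print(ys)  # [0, 1, 4, 9]
```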
display_worker_status()

Displays the dataset worker status. When dataset.num_workers > 0, it prints information like the following:

progress:479/5005
batch(s): 0.302 wait(s):0.000
recv(s): 0.069  to_jittor(s):0.021
recv_raw_call: 6720.0
last 10 workers: [6, 7, 3, 0, 2, 4, 7, 5, 6, 1]
ID      wait(s) load(s) send(s) total
#0      0.000   1.340   2.026   3.366   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
#1      0.000   1.451   3.607   5.058   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
#2      0.000   1.278   1.235   2.513   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
#3      0.000   1.426   1.927   3.353   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
#4      0.000   1.452   1.074   2.526   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
#5      0.000   1.422   3.204   4.625   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
#6      0.000   1.445   1.953   3.398   Buffer(free=0.000% l=462425368 r=462425368 size=536870912)
#7      0.000   1.582   0.507   2.090   Buffer(free=0.000% l=308283552 r=308283552 size=536870912)

Meaning of the outputs:

  • progress: dataset loading progress (current/total)

  • batch: time per batch, excluding data loading time

  • wait: time the main process spends waiting for worker processes

  • recv: time spent receiving batch data

  • to_jittor: time spent converting batch data to Jittor variables

  • recv_raw_call: total number of calls to the underlying recv_raw

  • last 10 workers: IDs of the last 10 workers the main process loaded from

  • table columns
    • ID: worker ID

    • wait: worker wait time

    • open: worker image open time

    • load: worker load time

    • send: worker send time

    • total: total worker time (in the sample above, the sum of wait, load, and send)

    • buffer: ring buffer status: percentage of free space, left index, right index, and total size in bytes
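
If the status text needs to be inspected programmatically, a worker row can be parsed with a regular expression. This is only a sketch: parse_worker_row is a hypothetical helper, and the pattern assumes the exact row layout of the sample output shown above, which may differ between Jittor versions.

```python
import re

# One row copied from the sample output above.
ROW = ("#0      0.000   1.340   2.026   3.366   "
       "Buffer(free=0.000% l=462425368 r=462425368 size=536870912)")

# Assumed layout: ID, wait, load, send, total, then the ring buffer status.
ROW_RE = re.compile(
    r"#(?P<id>\d+)\s+(?P<wait>[\d.]+)\s+(?P<load>[\d.]+)"
    r"\s+(?P<send>[\d.]+)\s+(?P<total>[\d.]+)"
    r"\s+Buffer\(free=(?P<free>[\d.]+)% l=(?P<l>\d+) r=(?P<r>\d+) size=(?P<size>\d+)\)"
)


def parse_worker_row(row):
    """Parse one worker row into a dict (hypothetical helper)."""
    match = ROW_RE.match(row.strip())
    if match is None:
        raise ValueError("row does not match the assumed layout")
    fields = match.groupdict()
    parsed = {name: float(fields[name]) for name in ("wait", "load", "send", "total", "free")}
    parsed.update({name: int(fields[name]) for name in ("id", "l", "r", "size")})
    return parsed


info = parse_worker_row(ROW)
print(info["id"], info["total"])
# In the sample, total matches wait + load + send up to rounding
print(abs(info["wait"] + info["load"] + info["send"] - info["total"]) < 2e-3)
```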

Example:

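from jittor.dataset import Dataset

class YourDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.set_attrs(total_len=1024)

    def __getitem__(self, k):
        return k, k*k

dataset = YourDataset().set_attrs(num_workers=8)
for x, y in dataset:
    dataset.display_worker_status()

set_attrs(**kw)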

Sets attributes of the dataset, including total_len, batch_size, shuffle, drop_last, num_workers, buffer_size, and stop_grad.

Example:

dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)

Attrs:

  • batch_size(int): batch size, default 16.

  • total_len(int): total length of the dataset.

  • shuffle(bool): shuffle at each epoch, default False.

  • drop_last(bool): if True, the last batch is dropped when it is smaller than batch_size; if False, the last batch may be smaller than batch_size. Default False.

  • num_workers(int): number of workers for loading data, default 0.

  • buffer_size(int): buffer size for each worker in bytes, default 536870912 (512 MB).

  • stop_grad(bool): stop grad for data, default True.
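
The example above chains set_attrs directly onto the constructor, which works because the call hands back the dataset itself. A minimal sketch of that pattern follows, using a hypothetical MiniDataset class rather than Jittor's actual implementation:

```python
class MiniDataset:
    """Hypothetical stand-in that mimics the set_attrs chaining pattern."""

    def __init__(self):
        # Defaults taken from the Dataset signature documented above
        self.batch_size = 16
        self.shuffle = False
        self.drop_last = False
        self.num_workers = 0

    def set_attrs(self, **kw):
        # Store every keyword argument as an attribute on the instance
        for name, value in kw.items():
            setattr(self, name, value)
        # Returning self is what allows MiniDataset().set_attrs(...)
        return self


ds = MiniDataset().set_attrs(batch_size=256, shuffle=True)
print(ds.batch_size, ds.shuffle, ds.drop_last)  # 256 True False
```

Because this sketch accepts any keyword, a misspelled name would be stored silently, so it is worth double-checking attribute names such as total_len.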

terminate()

Terminates the worker processes used for multi-process data loading.

to_jittor(batch)

Converts batch data, such as np.ndarray, int, and float, to Jittor arrays.

class jittor.dataset.ImageFolder(root, transform=None)

An image classification dataset that loads images and labels from a directory laid out as follows:

* root/label1/img1.png
* root/label1/img2.png
* ...
* root/label2/img1.png
* root/label2/img2.png
* ...

Args:

[in] root(string): Root directory path.
[in] transform(jittor.transform): transform applied to the data, default None.

Attributes:

* classes(list): List of the class names.
* class_to_idx(dict): map from class_name to class_index.
* imgs(list): List of (image_path, class_index) tuples
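
To make the three attributes concrete, the sketch below builds them from a directory tree using only the standard library. It is not ImageFolder's code: scan_image_folder and IMG_EXTENSIONS are hypothetical, and sorting class names alphabetically and filtering by file extension are assumptions made for this illustration.

```python
import os
import tempfile

# Assumption: which file extensions count as images
IMG_EXTENSIONS = (".png", ".jpg", ".jpeg")


def scan_image_folder(root):
    """Build classes, class_to_idx and imgs from root/<label>/<image> (sketch)."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    imgs = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            if fname.lower().endswith(IMG_EXTENSIONS):
                imgs.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return classes, class_to_idx, imgs


# Demo on a throwaway directory with empty placeholder files
with tempfile.TemporaryDirectory() as root:
    for label, fname in [("label1", "img1.png"), ("label1", "img2.png"), ("label2", "img1.png")]:
        os.makedirs(os.path.join(root, label), exist_ok=True)
        open(os.path.join(root, label, fname), "wb").close()
    classes, class_to_idx, imgs = scan_image_folder(root)
    rel_imgs = [(os.path.relpath(path, root), idx) for path, idx in imgs]

print(classes)       # ['label1', 'label2']
print(class_to_idx)  # {'label1': 0, 'label2': 1}
print(rel_imgs)
```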

Example:

from jittor.dataset import ImageFolder

train_dir = './data/celebA_train'
batch_size = 16  # example value
train_loader = ImageFolder(train_dir).set_attrs(batch_size=batch_size, shuffle=True)
for batch_idx, (x_, target) in enumerate(train_loader):
    ...

class jittor.dataset.MNIST(data_root='/home/zwy/.cache/jittor/dataset/mnist_data/', train=True, download=True, batch_size=16, shuffle=False, transform=None)

Jittor's own class for loading the MNIST dataset.

Args:

[in] data_root(str): your data root.
[in] train(bool): load the training split if True, otherwise the validation split.
[in] download(bool): download the data automatically if True.
[in] batch_size(int): Data batch size.
[in] shuffle(bool): Shuffle data if true.
[in] transform(jittor.transform): transform data.

Example:

from jittor.dataset.mnist import MNIST
train_loader = MNIST(train=True).set_attrs(batch_size=16, shuffle=True)
for i, (imgs, target) in enumerate(train_loader):
    ...

download_url()

Downloads the MNIST dataset. This function is called when download is True.

class jittor.dataset.VOC(data_root='/home/zwy/.cache/jittor/dataset/voc/', split='train')

Jittor's own class for loading the VOC dataset.

Args:

[in] data_root(str): your data root.
[in] split(str): which split you want to use, train or val.

Attribute:

NUM_CLASSES: Number of total categories, default is 21.

Example:

from jittor.dataset.voc import VOC
train_loader = VOC(data_root='...').set_attrs(batch_size=16, shuffle=True)
for i, (imgs, target) in enumerate(train_loader):
    ...

NUM_CLASSES = 21