jittor.dataset

这里是Jittor的数据集模块的API文档,您可以通过from jittor import dataset来获取该模块。

class jittor.dataset.BatchSampler(sampler, batch_size, drop_last)[源代码]

将数据集的索引分成多个批次, 每个批次包含一个索引列表。

参数:
  • sampler (jittor.dataset.Sampler): 用于提供索引的Sampler对象。

  • batch_size (int): 每批样本的大小。

  • drop_last (bool): 若设置为True, 则可能会丢弃最后一批, 如果它的大小小于 batch_size

代码示例:
>>> import jittor as jt
>>> from jittor.dataset import Dataset
>>> from jittor.dataset import BatchSampler, SequentialSampler
>>>
>>> class MyDataset(Dataset):
>>>      def __len__(self):
>>>          return 10
>>>
>>>      def __getitem__(self, index):
>>>          return index
>>>      
>>> dataset = MyDataset()
>>> sampler = BatchSampler(SequentialSampler(dataset), 3, True)
>>> list(iter(sampler))
>>> [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
>>> sampler = BatchSampler(SequentialSampler(dataset), 3, False) 
>>> list(iter(sampler))
>>> [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
class jittor.dataset.CIFAR10(root='/home/zwy/.cache/jittor/dataset/cifar_data/', train=True, transform=None, target_transform=None, download=True)[源代码]

CIFAR10 数据集。

参数:
  • root (str, optional): 数据集的根目录, 其中应该包含或将要保存包含 cifar-10-batches-py 的目录, 如果设置了download为True, 则会自动下载并解压数据集, 默认值: dataset_root + '/cifar_data/'

  • train (bool, optional): 如果为True, 则创建训练集的数据集对象;如果为False, 则创建测试集的数据集对象, 默认值: True。

  • transform (callable, optional): 一个用于图像变换的函数或可调用对象, 它接收一个PIL图片, 并返回经过变换的图片。例如, transforms.RandomCrop , 默认值: None。

  • target_transform (callable, optional): 一个用于标签变换的函数或可调用对象, 它接收目标标签, 并返回转换后的标签。默认值: None

  • download (bool, optional): 如果为True, 则从互联网下载数据集并将其保存到root目录。如果数据集已经下载, 不会再次下载, 默认值: True。

属性:
  • root (str): 数据集的根目录。

  • data (numpy.ndarray): 图像数据, 形状为(50000, 3, 32, 32)或(10000, 3, 32, 32)。

  • targets (numpy.ndarray): 标签数据, 形状为(50000,)或(10000,)。

  • classes (list): 类别名称列表, 形状为(10,)。

  • class_to_idx (dict): 类别名称到索引的映射:({‘airplane’: 0, ‘automobile’: 1, ‘bird’: 2, ‘cat’: 3, ‘deer’: 4, ‘dog’: 5, ‘frog’: 6, ‘horse’: 7,’ship’: 8, ‘truck’: 9})。

代码示例:
>>> from jittor.dataset.cifar import CIFAR10
>>> dataset = CIFAR10()
class jittor.dataset.CIFAR100(root='/home/zwy/.cache/jittor/dataset/cifar_data/', train=True, transform=None, target_transform=None, download=True)[源代码]

CIFAR100 数据集, 是 CIFAR10 数据集的子类。

参数:
  • split (str, optional) - 指定数据集分割的类型, ‘train’ 表示训练集, ‘test’ 表示测试集, 默认值: ‘train’ 。

  • transform (callable, optional) - 一个函数或transform对象, 用于对样本进行处理, 默认值: None

  • target_transform (callable, optional) - 一个函数或transform对象, 用于对标签进行处理, 默认值: None

  • download (bool, optional) - 是否下载数据集, 如果数据集未下载, 则设为 True, 默认值: False

  • batch_size (int, optional) - 每个batch中的样本数, 默认值: 1。

  • shuffle (bool, optional) - 是否在每个epoch开始时打乱数据, 默认值: False

  • num_workers (int, optional) - 加载数据时使用的子进程数量, 默认值: 0。

属性:
  • base_folder (str) - CIFAR100数据集解压后的文件夹名称。

  • url (str) - CIFAR100数据集下载链接。

  • filename (str) - CIFAR100数据集压缩文件的文件名。

  • tgz_md5 (str) - CIFAR100压缩文件的MD5校验码。

  • train_list (list of list) - 训练集文件的名称列表及其MD5校验码。

  • test_list (list of list) - 测试集文件的名称列表及其MD5校验码。

  • meta (dict) - 包含标签和层级结构信息的元数据字典。

代码示例:
>>> from jittor.dataset.cifar import CIFAR100
>>> dataset = CIFAR100(split='train')
jittor.dataset.DataLoader(dataset: Dataset, *args, **kargs)[源代码]

DataLoader结合了数据集和采样器, 并在给定的数据集上提供可迭代性。

参数:
  • dataset(Dataset): 需要进行封装的数据集。

  • args: 数据集属性设置的可变位置参数。

  • kargs: 数据集属性设置的可变关键字参数, 使用字典的形式传递参数值。

返回值:

封装好的数据加载器( DataLoader )。

代码示例:
>>> from jittor.dataset.cifar import CIFAR10 
>>> from jittor.dataset import DataLoader
>>> train_dataset = CIFAR10()
>>> dataloader = DataLoader(train_dataset, batch_size=8)
>>> for batch_idx, (x_, target) in enumerate(dataloader):
>>>     # 处理每个批次的数据
class jittor.dataset.Dataset(batch_size=16, shuffle=False, drop_last=False, num_workers=0, buffer_size=536870912, stop_grad=True, keep_numpy_array=False, endless=False)[源代码]

数据集的抽象类。用户需要继承此类并实现其中的 __getitem__ 方法以便遍历数据。

参数:
  • batch_size (int, optional): 批大小。默认值: 16

  • shuffle (bool, optional): 每个epoch是否进行随机打乱。默认值: False

  • drop_last (bool, optional): 若为True, 则可能丢弃每个epoch中最后不足batch_size的数据。默认值: False

  • num_workers (int, optional): 用于加载数据的工作进程数。默认值: 0

  • buffer_size (int, optional): 每个工作进程的缓存大小(字节)。默认值: 512MB(即512 * 1024 * 1024)

  • keep_numpy_array (bool, optional): 返回numpy数组而非jittor数组。默认值: False

  • endless (bool, optional): 此数据集是否无限循环地产生数据。默认值: False

属性:
  • total_len (int): 数据集的总长度

  • epoch_id (int): 当前的epoch编号

代码示例:
>>> class YourDataset(Dataset):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.set_attrs(total_len=1024)
>>> 
>>>     def __getitem__(self, k):
>>>         # 返回数据的逻辑
>>>         return k, k*k
>>> # 实例化你的数据集, 并设置批大小和是否打乱数据
>>> dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)
>>> # 遍历数据集
>>> for x, y in dataset:
>>>     # 处理数据的逻辑
collate_batch(batch)[源代码]

用于将输入变量的列表转换为 Jittor 变量的方法。默认情况下, 它会将输入变量的列表转换为 Jittor 变量的列表。如果输入变量的列表中的任何变量的维度为1, 则会将其从输出变量中去除。

参数:
  • batch (list): 需要转换的输入变量的列表。

返回值:
  • jt.Var: 一个包含输入变量的列表的 Jittor 变量。

代码示例:
>>> batch = [np.array([1,2,3]), np.array([4,5,6])]
>>> collate_batch(batch)
>>> jt.Var([[1, 2, 3], [4, 5, 6]])
set_attrs(**kw)[源代码]

设置数据集的属性。

参数:
  • **kw (dict): 数据集的属性字典。

返回:
  • Dataset: 返回设置了属性的数据集对象

代码示例:
>>> dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)
terminate()[源代码]

终止数据集的工作进程。

代码示例:
>>> dataset.terminate()
to_jittor(batch)[源代码]
用于将输入变量的列表转换为 Jittor 变量的方法。
参数:
  • batch (list): 需要转换的输入变量的列表。

返回值:
  • jt.Var: 一个包含输入变量的列表的 Jittor 变量。

代码示例:
>>> batch = [np.array([1,2,3]), np.array([4,5,6])]
>>> to_jittor(batch)
>>> [jt.Var([1, 2, 3]), jt.Var([4, 5, 6])]
class jittor.dataset.ImageFolder(root, transform=None)[源代码]

从目录中加载图像及其标签用于图像分类的数据集。

数据集的目录结构应如下所示:
  • root/label1/img1.png

  • root/label1/img2.png

  • root/label2/img1.png

  • root/label2/img2.png

参数:
  • root (str): 包含图像和标签子目录的根目录的路径

  • transform (callable, optional): 用于对样本进行转换的optional转换操作(例如, 数据增强)。默认值: None

属性:
  • classes (list): 类名的列表

  • class_to_idx (dict): 从类名映射到类索引的字典

  • imgs (list): 包含(image_path, class_index)元组的列表

代码示例:
>>> from jittor.dataset import ImageFolder
>>> dataset = ImageFolder(root="path_to_your_dataset")
class jittor.dataset.MNIST(data_root='/home/zwy/.cache/jittor/dataset/mnist_data/', train=True, download=True, batch_size=16, shuffle=False, transform=None)[源代码]

MNIST 数据集。

参数:
  • data_root (str): 数据根目录的路径, 默认值: ‘dataset_root/mnist_data/’

  • train (bool, optional): 选择是训练模式还是验证模式, 默认值: True。

  • download (bool, optional): 如果设置为 True, 则会自动下载数据。默认值: True。

  • batch_size (int, optional): 数据批次大小, 默认值: 16。

  • shuffle (bool, optional): 若为 True, 则会在每个 epoch 对数据进行打乱, 默认值: False。

  • transform (jittor.transform, optional): 数据的转换操作, 默认值: None。

属性:
  • data_root (str): 数据根目录的路径。

  • is_train (bool): 选择是训练模式还是验证模式。

  • transform (jittor.transform): 数据的转换操作。

  • batch_size (int): 数据批次大小。

  • shuffle (bool): 是否打乱数据。

  • mnist (dict): mnist 数据集。

  • total_len (int): 数据集的长度。

代码示例:
>>> from jittor.dataset.mnist import MNIST
>>> train_loader = MNIST(train=True, batch_size=16, shuffle=True)
>>> for i, (imgs, target) in enumerate(train_loader):
>>> ...     # 处理训练数据
class jittor.dataset.RandomSampler(dataset, replacement=False, num_samples=None)[源代码]

随机对元素进行采样。

参数:
  • dataset (jittor.dataset.Dataset): 需要采样的数据集。

  • replacement (bool, optional): 是否采用替换的方式进行采样。False为不替换, 即每个样本仅被采样一次。默认值: False。

  • num_samples (int or None, optional): 期望采样的样本总数。如果为None, 则采样数量为数据集的实际长度。默认值: None。

属性:
  • dataset (Dataset): 被采样的数据集。

  • rep (bool): 采样时是否允许替换。

  • _num_samples (int or None): 采样的样本总数。

  • _shuffle_rng (np.random.Generator): 随机数生成器, 用于生成随机序列。

代码示例:
>>> import jittor as jt
>>> from jittor.dataset import Dataset
>>> from jittor.dataset import RandomSampler
>>>
>>> class MyDataset(Dataset):
>>>      def __len__(self):
>>>          return 5
>>>
>>>      def __getitem__(self, index):
>>>          return index
>>>
>>>
>>> dataset = MyDataset()
>>> sampler = RandomSampler(dataset)
>>> list(iter(sampler))
[2, 1, 0, 3, 4]
class jittor.dataset.Sampler(dataset)[源代码]

所有采样器的基类。每个采样器子类都必须提供一个 __iter__() 方法, 提供对数据集元素的索引或索引列表(批次)进行迭代的方法, 以及一个返回返回迭代器长度的 __len__() 方法。

参数:
  • dataset (jittor.dataset.Dataset): 需要采样的数据集。

属性:
  • dataset (jittor.dataset.Dataset): 被采样的数据集。

代码示例:
>>> from jittor.dataset import Dataset
>>> from jittor.dataset import Sampler
>>> class MyDataset(Dataset):
...     def __len__(self):
...         return 100
...     def __getitem__(self, index):
...         return index
>>> class MySampler(Sampler):
...     def __iter__(self):
...         return iter(range(len(self.dataset)))
>>> dataset = MyDataset()
>>> sampler = MySampler(dataset)
>>> list(iter(sampler))
[0, 1, 2, ..., 99]
class jittor.dataset.SequentialSampler(dataset)[源代码]

按顺序对元素进行采样, 采样顺序始终保持不变。

参数:
  • dataset (jittor.dataset.Dataset): 需要采样的数据集。

属性:
  • dataset (jittor.dataset.Dataset): 被采样的数据集。

代码示例:
>>> import jittor as jt
>>> from jittor.dataset import Dataset
>>> from jittor.dataset import SequentialSampler
>>>
>>> class MyDataset(Dataset):
>>>      def __len__(self):
>>>          return 5
>>>
>>>      def __getitem__(self, index):
>>>          return index
>>> 
>>> dataset = MyDataset()
>>> sampler = SequentialSampler(dataset)
>>> list(iter(sampler))
[0, 1, 2, 3, 4]
class jittor.dataset.SkipFirstBatchesSampler(sampler, num_skip_batches)[源代码]

在数据采样过程中, 跳过前N个批次的数据采样器。

参数:
  • sampler (jittor.dataset.Sampler): 被包装的原始采样器。被包装的采样器 sampler 必须是 Sampler 的实例或其子类之一。

  • num_skip_batches (int): 要跳过的批次数量

属性:
  • sampler (jittor.dataset.Sampler): 被包装的原始采样器。

  • num_skip_batches (int): 要跳过的批次数量。

代码示例:
>>> # 假设有一个已经创建好的采样器
>>> original_sampler = ...
>>> # 创建一个新的采样器实例, 跳过开始的2个批次
>>> skip_sampler = SkipFirstBatchesSampler(original_sampler, num_skip_batches=2)
>>> for batch in skip_sampler:
>>>     # 这里开始迭代的批次是原采样器的第三个批次
class jittor.dataset.SubsetRandomSampler(dataset, indice)[源代码]

从给定的索引列表中随机抽取元素。

参数:
  • dataset (jittor.dataset.Dataset): 需要采样的数据集。

  • indices (tuple[int]): 用于采样的索引范围, 包括开始索引和结束索引, 形如(start_index, end_index)。索引范围`indices`是左闭右开区间, 即包括开始索引但不包括结束索引。开始索引需大于等于0, 结束索引需小于数据集的长度且大于开始索引。如果索引范围不满足要求, 则抛出AssertionError。

代码示例:
>>> import jittor as jt
>>> from jittor.dataset import Dataset
>>> from jittor.dataset import SubsetRandomSampler
>>> class MyDataset(Dataset):
>>>      def __len__(self):
>>>          return 10
>>>
>>>      def __getitem__(self, index):
>>>          return index
>>>
>>>
>>> dataset = MyDataset()
>>> sampler = SubsetRandomSampler(dataset, (3, 7))
>>> list(iter(sampler))
[4, 3, 5, 6]
jittor.dataset.TensorDataset

VarDataset 的别名

class jittor.dataset.VOC(data_root='/home/zwy/.cache/jittor/dataset/voc/', split='train')[源代码]

Pascal VOC 数据集

参数:
  • data_root (str): 数据集的根目录

  • split (str, optional): 选择数据集的子集, ‘train’表示训练集, ‘val’表示验证集。默认值: ‘train’

属性:
  • data_root (str): 数据集的根目录

  • split (str): 数据集的子集

  • image_root (str): 图像文件夹的路径

  • label_root (str): 标签文件夹的路径

  • data_list_path (str): 数据列表文件的路径

  • image_path (list of str): 图像文件的路径列表

  • label_path (list of str): 标签文件的路径列表

代码示例:
>>> from jittor.dataset.voc import VOC
>>> train_loader = VOC(data_root='path/to/VOC').set_attrs(batch_size=16, shuffle=True)
>>> for i, (imgs, target) in enumerate(train_loader):
>>>     # 处理图像和标签
NUM_CLASSES = 21
class jittor.dataset.VarDataset(*args)[源代码]

使用 Var 对象直接创建数据集的类, TensorDataset 是 VarDataset 的别名。这个类允许用户直接从 Jittor 变量中创建数据集, 而无需对数据执行任何预处理。数据集中的每个元素都是根据相应的索引从给定的 Jittor 变量中提取的。所有输入变量的第一个维度长度必须相等, 否则创建 VarDataset 时会触发错误。

参数:
  • *args (jt.Var): 一个或多个 Jittor 变量。所有变量的长度必须相同, 且变量的维度数可以是任意的。这些变量将会并行被索引, 以创建数据集中的条目。

代码示例:
>>> import jittor as jt
>>> from jittor.dataset import VarDataset
>>> x = jt.array([1,2,3])
>>> y = jt.array([4,5,6])
>>> z = jt.array([7,8,9])
>>> dataset = VarDataset(x, y, z)
>>> dataset.set_attrs(batch_size=1)
>>> for a, b, c in dataset:
>>>     print(a, b, c)
>>> #  1, 4, 7
>>> #  2, 5, 8
>>> #  3, 6, 9
collate_batch(batch)[源代码]

用于将输入的Jittor变量的列表转换为特定格式 Jittor 变量的方法。

参数:
  • batch (list of jt.Var): 需要转换的输入变量的列表。

返回值:
  • jt.Var: 一个包含输入变量的列表的 Jittor 变量。

代码示例:
>>> batch = [jt.array([1,2,3]), jt.array([4,5,6])]
>>> collate_batch(batch)
>>> jt.Var([[1, 2, 3], [4, 5, 6]])