jittor.dataset.sampler 源代码

# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved. 
# Maintainers: 
#     Hao-Yang Peng
#     Dun Liang <randonlang@gmail.com>. 
# 
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from .dataset import Dataset
import numpy as np
from PIL import Image


class Sampler:
    '''Base class for all samplers.

    Every subclass must provide an :meth:`__iter__` method that iterates
    over dataset indices (or lists of indices, i.e. batches) and a
    :meth:`__len__` method that returns the length of that iterator.

    Args:
        dataset (jittor.dataset.Dataset): The dataset to sample from.

    Attributes:
        dataset (jittor.dataset.Dataset): The dataset being sampled.

    Example:
        >>> from jittor.dataset import Dataset
        >>> from jittor.dataset import Sampler
        >>> class MyDataset(Dataset):
        ...     def __len__(self):
        ...         return 100
        ...     def __getitem__(self, index):
        ...         return index
        >>> class MySampler(Sampler):
        ...     def __iter__(self):
        ...         return iter(range(len(self.dataset)))
        >>> dataset = MyDataset()
        >>> sampler = MySampler(dataset)
        >>> list(iter(sampler))
        [0, 1, 2, ..., 99]
    '''

    def __init__(self, dataset):
        # The sampler MUST register itself on the dataset here; the dataset
        # presumably consults ``dataset.sampler`` when producing indices.
        dataset.sampler = self
        self.dataset = dataset

    def __iter__(self):
        # Abstract: concrete samplers return an iterator over indices.
        raise NotImplementedError

    def __len__(self):
        # Abstract: concrete samplers return how many items they yield.
        raise NotImplementedError
class SequentialSampler(Sampler):
    '''Sample elements sequentially, always in the same order.

    The number of indices is taken from ``dataset.__real_len__()`` when the
    dataset defines that method, and from ``dataset.__len__()`` otherwise.

    Args:
        dataset (jittor.dataset.Dataset): The dataset to sample from.

    Attributes:
        dataset (jittor.dataset.Dataset): The dataset being sampled.

    Example:
        >>> import jittor as jt
        >>> from jittor.dataset import Dataset
        >>> from jittor.dataset import SequentialSampler
        >>> class MyDataset(Dataset):
        ...     def __len__(self):
        ...         return 5
        ...     def __getitem__(self, index):
        ...         return index
        >>> dataset = MyDataset()
        >>> sampler = SequentialSampler(dataset)
        >>> list(iter(sampler))
        [0, 1, 2, 3, 4]
    '''

    def __init__(self, dataset):
        self.dataset = dataset
        # The sampler MUST register itself on the dataset here.
        dataset.sampler = self

    def __len__(self):
        source = self.dataset
        # Prefer ``__real_len__`` when the dataset provides it.
        if hasattr(source, "__real_len__"):
            return source.__real_len__()
        return source.__len__()

    def __iter__(self):
        source = self.dataset
        if hasattr(source, "__real_len__"):
            count = source.__real_len__()
        else:
            count = source.__len__()
        return iter(range(count))
class RandomSampler(Sampler):
    '''Sample elements in random order.

    Args:
        dataset (jittor.dataset.Dataset): The dataset to sample from.
        replacement (bool, optional): If True, indices are drawn with
            replacement. If False, indices are drawn from random
            permutations of the dataset. Default: False.
        num_samples (int or None, optional): Number of indices to draw. If
            None, the dataset length is used. Default: None.

    Attributes:
        dataset (Dataset): The dataset being sampled.
        rep (bool): Whether sampling is done with replacement.
        _num_samples (int or None): The requested number of samples.
        _shuffle_rng (np.random.Generator): Random generator used to draw
            indices. It is created with the fixed seed 1, so every new
            sampler starts from the same random state.

    Example:
        >>> import jittor as jt
        >>> from jittor.dataset import Dataset
        >>> from jittor.dataset import RandomSampler
        >>> class MyDataset(Dataset):
        ...     def __len__(self):
        ...         return 5
        ...     def __getitem__(self, index):
        ...         return index
        >>> dataset = MyDataset()
        >>> sampler = RandomSampler(dataset)
        >>> sorted(iter(sampler))
        [0, 1, 2, 3, 4]
    '''

    def __init__(self, dataset, replacement=False, num_samples=None):
        # The sampler MUST register itself on the dataset here.
        dataset.sampler = self
        self.dataset = dataset
        self.rep = replacement
        self._num_samples = num_samples
        self._shuffle_rng = np.random.default_rng(1)

    def _dataset_len(self):
        # Dataset length, preferring ``__real_len__`` when it is defined.
        dataset = self.dataset
        if hasattr(dataset, "__real_len__"):
            return dataset.__real_len__()
        return dataset.__len__()

    @property
    def num_samples(self):
        '''Number of indices produced per iteration.'''
        if self._num_samples is None:
            return self._dataset_len()
        return self._num_samples

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        n = self._dataset_len()
        total = self.num_samples
        rng = self._shuffle_rng
        if self.rep:
            drawn = rng.integers(low=0, high=n, size=(total,), dtype=np.int64)
            return iter(drawn.tolist())
        # Without replacement: concatenate whole permutations until enough
        # indices are available, then truncate, so that the number of yielded
        # indices always equals ``len(self)``. With the default
        # ``num_samples=None`` this is exactly one ``permutation(n)`` call.
        indices = []
        if n > 0:
            while len(indices) < total:
                indices.extend(rng.permutation(n).tolist())
        return iter(indices[:total])
class SkipFirstBatchesSampler(Sampler):
    '''Wrap a sampler and skip the first N items it produces.

    The items skipped are whatever the wrapped sampler yields (single
    indices for index samplers); the parameter name calls them "batches".

    Args:
        sampler (jittor.dataset.Sampler): The sampler to wrap. It must
            expose a ``dataset`` attribute.
        num_skip_batches (int): Number of leading items to skip.

    Attributes:
        sampler (jittor.dataset.Sampler): The wrapped sampler.
        dataset: The dataset of the wrapped sampler.
        num_skip_batches (int): Number of leading items to skip.

    Example:
        >>> # given an already constructed sampler
        >>> original_sampler = ...
        >>> # skip the first two items of the original sampler
        >>> skip_sampler = SkipFirstBatchesSampler(original_sampler, num_skip_batches=2)
        >>> for batch in skip_sampler:
        ...     pass  # iteration starts at the third item
    '''

    def __init__(self, sampler, num_skip_batches):
        dataset = sampler.dataset
        # The sampler MUST register itself on the dataset here, replacing the
        # wrapped sampler as the dataset's active sampler.
        dataset.sampler = self
        # Expose ``dataset`` like the other samplers so that this wrapper can
        # itself be wrapped.
        self.dataset = dataset
        self.sampler = sampler
        self.num_skip_batches = num_skip_batches

    def __len__(self):
        # Clamp at zero: skipping more items than exist yields nothing, and
        # ``len()`` raises ValueError on a negative result.
        return max(0, len(self.sampler) - self.num_skip_batches)

    def __iter__(self):
        # Materialize the wrapped sampler's output, then drop the prefix.
        return iter(list(self.sampler)[self.num_skip_batches:])
class SubsetRandomSampler(Sampler):
    '''Sample randomly from a contiguous range of dataset indices.

    Args:
        dataset (jittor.dataset.Dataset): The dataset to sample from.
        indice (tuple[int]): ``(start_index, end_index)`` describing the
            half-open range ``[start_index, end_index)``. It must satisfy
            ``0 <= start_index < end_index <= len(dataset)``.

    Attributes:
        dataset (jittor.dataset.Dataset): The dataset being sampled.
        indices (tuple[int]): The ``(start_index, end_index)`` range.

    Raises:
        AssertionError: If the index range does not satisfy the constraint
            above. Nothing is registered on the dataset in that case.

    Example:
        >>> import jittor as jt
        >>> from jittor.dataset import Dataset
        >>> from jittor.dataset import SubsetRandomSampler
        >>> class MyDataset(Dataset):
        ...     def __len__(self):
        ...         return 10
        ...     def __getitem__(self, index):
        ...         return index
        >>> dataset = MyDataset()
        >>> sampler = SubsetRandomSampler(dataset, (3, 7))
        >>> sorted(iter(sampler))
        [3, 4, 5, 6]
    '''

    def __init__(self, dataset, indice):
        if hasattr(dataset, "__real_len__"):
            dlen = dataset.__real_len__()
        else:
            dlen = dataset.__len__()
        start, stop = indice[0], indice[1]
        # Validate before touching the dataset. The range is half-open, so
        # ``stop == dlen`` is valid and makes the last element reachable.
        # AssertionError is raised explicitly so the check survives ``-O``.
        if not (0 <= start < stop <= dlen):
            raise AssertionError(
                "invalid index range ({}, {}) for dataset of length {}".format(
                    start, stop, dlen))
        # The sampler MUST register itself on the dataset here.
        dataset.sampler = self
        self.dataset = dataset
        self.indices = indice

    def __iter__(self):
        # Permute the offsets 0..(stop - start) and shift them by ``start``.
        return (int(i) + self.indices[0]
                for i in np.random.permutation(self.indices[1] - self.indices[0]))

    def __len__(self):
        return self.indices[1] - self.indices[0]
class BatchSampler(Sampler):
    '''Group the indices of another sampler into batches (lists of indices).

    Unlike the other samplers in this module, the constructor does not
    register itself on any dataset; it only stores its arguments.

    Args:
        sampler (jittor.dataset.Sampler): Sampler that provides the indices.
        batch_size (int): Number of indices per batch.
        drop_last (bool): If True, a trailing batch with fewer than
            ``batch_size`` indices is dropped.

    Example:
        >>> import jittor as jt
        >>> from jittor.dataset import Dataset
        >>> from jittor.dataset import BatchSampler, SequentialSampler
        >>> class MyDataset(Dataset):
        ...     def __len__(self):
        ...         return 10
        ...     def __getitem__(self, index):
        ...         return index
        >>> dataset = MyDataset()
        >>> sampler = BatchSampler(SequentialSampler(dataset), 3, True)
        >>> list(iter(sampler))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        >>> sampler = BatchSampler(SequentialSampler(dataset), 3, False)
        >>> list(iter(sampler))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
    '''

    def __init__(self, sampler, batch_size, drop_last):
        self.drop_last = drop_last
        self.batch_size = batch_size
        self.sampler = sampler

    def __iter__(self):
        pending = []
        for index in self.sampler:
            pending.append(index)
            if len(pending) != self.batch_size:
                continue
            # A full batch is ready: emit it and start a fresh list.
            yield pending
            pending = []
        # Emit the incomplete trailing batch unless it must be dropped.
        if len(pending) == 0 or self.drop_last:
            return
        yield pending

    def __len__(self):
        total = len(self.sampler)
        size = self.batch_size
        if self.drop_last:
            # Only complete batches count.
            return total // size
        # Round up so that the partial trailing batch is counted.
        return (total + size - 1) // size