# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers:
# Hao-Yang Peng
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from .dataset import Dataset
import numpy as np
from PIL import Image
[文档]
class Sampler:
    '''
    Base class for all samplers.

    Every subclass must provide an :meth:`__iter__` method that iterates
    over indices (or lists of indices, i.e. batches) of dataset elements,
    and a :meth:`__len__` method that returns the length of that iteration.

    Args:
        dataset (jittor.dataset.Dataset): the dataset to sample from.

    Attributes:
        dataset (jittor.dataset.Dataset): the dataset being sampled.

    Example:
        >>> from jittor.dataset import Dataset
        >>> from jittor.dataset import Sampler
        >>> class MyDataset(Dataset):
        ...     def __len__(self):
        ...         return 100
        ...     def __getitem__(self, index):
        ...         return index
        >>> class MySampler(Sampler):
        ...     def __iter__(self):
        ...         return iter(range(len(self.dataset)))
        >>> dataset = MyDataset()
        >>> sampler = MySampler(dataset)
        >>> list(iter(sampler))
        [0, 1, 2, ..., 99]
    '''

    def __init__(self, dataset):
        # Keep a handle on the dataset this sampler draws from.
        self.dataset = dataset
        # MUST set sampler here: the dataset gets a back-reference to
        # this sampler through its ``sampler`` attribute.
        dataset.sampler = self

    def __iter__(self):
        # Abstract: concrete samplers decide the iteration order.
        raise NotImplementedError()

    def __len__(self):
        # Abstract: concrete samplers report how many items they yield.
        raise NotImplementedError()
[文档]
class SequentialSampler(Sampler):
    '''
    Sample elements sequentially, always in the same order.

    The number of indices is taken from ``dataset.__real_len__()`` when the
    dataset defines that method, and from ``dataset.__len__()`` otherwise.

    Args:
        dataset (jittor.dataset.Dataset): the dataset to sample from.

    Attributes:
        dataset (jittor.dataset.Dataset): the dataset being sampled.

    Example:
        >>> import jittor as jt
        >>> from jittor.dataset import Dataset
        >>> from jittor.dataset import SequentialSampler
        >>>
        >>> class MyDataset(Dataset):
        ...     def __len__(self):
        ...         return 5
        ...
        ...     def __getitem__(self, index):
        ...         return index
        >>>
        >>> dataset = MyDataset()
        >>> sampler = SequentialSampler(dataset)
        >>> list(iter(sampler))
        [0, 1, 2, 3, 4]
    '''

    def __init__(self, dataset):
        # MUST set sampler here
        dataset.sampler = self
        self.dataset = dataset

    def _dataset_size(self):
        # Prefer the dataset's ``__real_len__`` when it exists; fall back
        # to the regular ``__len__`` otherwise.
        source = self.dataset
        if hasattr(source, "__real_len__"):
            return source.__real_len__()
        return source.__len__()

    def __iter__(self):
        # Yields 0, 1, ..., size - 1.
        return iter(range(self._dataset_size()))

    def __len__(self):
        return self._dataset_size()
[文档]
class RandomSampler(Sampler):
    '''
    Sample dataset indices in random order.

    The dataset length is taken from ``dataset.__real_len__()`` when the
    dataset defines that method, and from ``dataset.__len__()`` otherwise.

    Args:
        dataset (jittor.dataset.Dataset): the dataset to sample from.
        replacement (bool, optional): if true, indices are drawn with
            replacement (the same index may appear several times).
            If false, indices come from random permutations of the
            dataset. Default: False.
        num_samples (int or None, optional): number of indices to draw.
            If None, the dataset length is used. Default: None.

    Attributes:
        dataset (Dataset): the dataset being sampled.
        rep (bool): whether sampling is done with replacement.
        _num_samples (int or None): requested number of samples.
        _shuffle_rng (np.random.Generator): random generator used to
            produce the index sequences.

    Note:
        One iteration always yields exactly ``len(sampler)`` indices.
        Without replacement, when ``num_samples`` exceeds the dataset
        length, additional fresh permutations are appended; when it is
        smaller, the permutation is truncated.

    Example:
        >>> import jittor as jt
        >>> from jittor.dataset import Dataset
        >>> from jittor.dataset import RandomSampler
        >>>
        >>> class MyDataset(Dataset):
        ...     def __len__(self):
        ...         return 5
        ...
        ...     def __getitem__(self, index):
        ...         return index
        >>>
        >>> dataset = MyDataset()
        >>> sampler = RandomSampler(dataset)
        >>> sorted(sampler)   # the indices come out in a shuffled order
        [0, 1, 2, 3, 4]
    '''

    def __init__(self, dataset, replacement=False, num_samples=None):
        # MUST set sampler here
        dataset.sampler = self
        self.dataset = dataset
        self.rep = replacement
        self._num_samples = num_samples
        # The generator is seeded with the constant 1, so every newly
        # constructed sampler produces the same sequence of shuffles.
        self._shuffle_rng = np.random.default_rng(1)

    def _dataset_len(self):
        # Prefer the dataset's ``__real_len__`` when it exists; fall back
        # to the regular ``__len__`` otherwise.
        dataset = self.dataset
        if hasattr(dataset, "__real_len__"):
            return dataset.__real_len__()
        return dataset.__len__()

    @property
    def num_samples(self):
        # Resolved lazily so that it follows the dataset length when no
        # explicit sample count was given.
        if self._num_samples is None:
            return self._dataset_len()
        return self._num_samples

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        n = self._dataset_len()
        total = self.num_samples
        if self.rep:
            drawn = self._shuffle_rng.integers(
                low=0, high=n, size=(total,), dtype=np.int64)
            return iter(drawn.tolist())

        # Without replacement: concatenate permutations until exactly
        # ``total`` indices are collected, so the iteration length agrees
        # with ``__len__``. In the default case (total == n) this performs
        # a single ``permutation(n)`` call, i.e. the generator state
        # advances exactly as it would for one plain shuffle.
        indices = []
        while len(indices) < total:
            perm = self._shuffle_rng.permutation(n).tolist()
            if not perm:
                # Empty dataset: nothing can be drawn, avoid looping forever.
                break
            indices.extend(perm[:total - len(indices)])
        return iter(indices)
[文档]
class SkipFirstBatchesSampler(Sampler):
    '''
    Sampler wrapper that skips the first N items of another sampler.

    The wrapped sampler is iterated normally and its first
    ``num_skip_batches`` items are dropped. Whether an "item" is a single
    index or a batch of indices depends on what the wrapped sampler yields.

    Args:
        sampler (jittor.dataset.Sampler): the sampler to wrap. It must
            expose a ``dataset`` attribute, because the wrapper registers
            itself as that dataset's sampler.
        num_skip_batches (int): number of leading items to skip.

    Attributes:
        sampler (jittor.dataset.Sampler): the wrapped sampler.
        dataset: the dataset of the wrapped sampler.
        num_skip_batches (int): number of leading items to skip.

    Example:
        >>> # assume an already constructed sampler
        >>> original_sampler = ...
        >>> # create a sampler that skips the first 2 items
        >>> skip_sampler = SkipFirstBatchesSampler(original_sampler, num_skip_batches=2)
        >>> for batch in skip_sampler:
        ...     pass  # iteration starts at the third item of original_sampler
    '''

    def __init__(self, sampler, num_skip_batches):
        # MUST set sampler here
        sampler.dataset.sampler = self
        self.sampler = sampler
        # Expose the dataset the same way the other dataset-bound samplers
        # do, so this wrapper can itself be wrapped or inspected.
        self.dataset = sampler.dataset
        self.num_skip_batches = num_skip_batches

    def __len__(self):
        # Clamp at zero: skipping more items than exist leaves an empty
        # iteration, and a negative ``__len__`` would make ``len()`` raise.
        remaining = len(self.sampler) - self.num_skip_batches
        return max(remaining, 0)

    def __iter__(self):
        # The wrapped sampler is fully materialized, then the leading
        # items are sliced off.
        items = list(iter(self.sampler))
        return iter(items[self.num_skip_batches:])
[文档]
class SubsetRandomSampler(Sampler):
    '''
    Sample randomly from a contiguous index range of the dataset.

    Args:
        dataset (jittor.dataset.Dataset): the dataset to sample from.
        indice (tuple[int]): the index range ``(start_index, end_index)``
            to sample from. The range is half-open: ``start_index`` is
            included and ``end_index`` is excluded. It must satisfy
            ``0 <= start_index < end_index <= dataset length``, where the
            dataset length is ``dataset.__real_len__()`` when the dataset
            defines it and ``dataset.__len__()`` otherwise.

    Attributes:
        dataset (jittor.dataset.Dataset): the dataset being sampled.
        indices (tuple[int]): the ``(start_index, end_index)`` range.

    Raises:
        AssertionError: if the index range does not satisfy the
            constraints above. The dataset is left untouched in that case.

    Example:
        >>> import jittor as jt
        >>> from jittor.dataset import Dataset
        >>> from jittor.dataset import SubsetRandomSampler
        >>> class MyDataset(Dataset):
        ...     def __len__(self):
        ...         return 10
        ...
        ...     def __getitem__(self, index):
        ...         return index
        >>>
        >>> dataset = MyDataset()
        >>> sampler = SubsetRandomSampler(dataset, (3, 7))
        >>> sorted(sampler)   # the indices come out in a shuffled order
        [3, 4, 5, 6]
    '''

    def __init__(self, dataset, indice):
        if hasattr(dataset, "__real_len__"):
            dlen = dataset.__real_len__()
        else:
            dlen = dataset.__len__()
        start, stop = indice[0], indice[1]
        # Validate before touching the dataset so that a rejected range
        # does not leave a half-initialized sampler attached to it.
        # ``stop`` is exclusive, so ``stop == dlen`` is a valid range that
        # reaches the last element.
        assert 0 <= start < stop <= dlen, (
            f"invalid index range ({start}, {stop}) for dataset of length "
            f"{dlen}: expected 0 <= start < end <= length")
        # MUST set sampler here
        dataset.sampler = self
        self.dataset = dataset
        self.indices = indice

    def __iter__(self):
        # Shuffle the offsets 0..span-1 and shift them by the start index.
        span = self.indices[1] - self.indices[0]
        return (int(i) + self.indices[0] for i in np.random.permutation(span))

    def __len__(self):
        return self.indices[1] - self.indices[0]
[文档]
class BatchSampler(Sampler):
    '''
    Group the indices produced by another sampler into batches.

    Each item yielded by this sampler is a list of indices.

    Args:
        sampler (jittor.dataset.Sampler): the sampler providing indices.
        batch_size (int): number of indices per batch.
        drop_last (bool): if True, a trailing batch with fewer than
            ``batch_size`` indices is dropped.

    Attributes:
        sampler (jittor.dataset.Sampler): the sampler providing indices.
        batch_size (int): number of indices per batch.
        drop_last (bool): whether an incomplete trailing batch is dropped.

    Example:
        >>> import jittor as jt
        >>> from jittor.dataset import Dataset
        >>> from jittor.dataset import BatchSampler, SequentialSampler
        >>>
        >>> class MyDataset(Dataset):
        ...     def __len__(self):
        ...         return 10
        ...
        ...     def __getitem__(self, index):
        ...         return index
        >>>
        >>> dataset = MyDataset()
        >>> sampler = BatchSampler(SequentialSampler(dataset), 3, True)
        >>> list(iter(sampler))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
        >>> sampler = BatchSampler(SequentialSampler(dataset), 3, False)
        >>> list(iter(sampler))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
    '''

    def __init__(self, sampler, batch_size, drop_last):
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        pending = []
        for index in self.sampler:
            pending.append(index)
            if len(pending) != self.batch_size:
                continue
            # A full batch has been collected: emit it and start a new one.
            yield pending
            pending = []
        # Leftover indices form a short final batch unless it is dropped.
        if pending and not self.drop_last:
            yield pending

    def __len__(self):
        total = len(self.sampler)
        size = self.batch_size
        if self.drop_last:
            # Only complete batches count.
            return total // size
        # Round up so that the short final batch is counted.
        return (total + size - 1) // size