# Source code for jittor.dataset.sampler

# ***************************************************************
# Copyright (c) 2022 Jittor. All Rights Reserved. 
# Maintainers: 
#     Hao-Yang Peng
#     Dun Liang <randonlang@gmail.com>. 
# 
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from .dataset import Dataset
import numpy as np
from PIL import Image


class Sampler:
    """Base class for objects that decide which dataset indices are visited.

    The ``__iter__`` and ``__len__`` defined here only raise
    ``NotImplementedError``; concrete samplers are expected to override both.
    """

    def __init__(self, dataset):
        """Bind this sampler and ``dataset`` to each other.

        Args:
            dataset: The object to sample from. Its ``sampler`` attribute is
                overwritten with this instance.
        """
        self.dataset = dataset
        # The sampler MUST be registered on the dataset at this point.
        dataset.sampler = self

    def __iter__(self):
        """Yield dataset indices; must be provided by subclasses."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of indices produced; must be provided by subclasses."""
        raise NotImplementedError
class SequentialSampler(Sampler):
    """Sampler that visits every index once, in ascending order.

    The indices run from ``0`` up to ``dataset.__real_len__() - 1``.
    """

    def __init__(self, dataset):
        """Attach this sampler to ``dataset``.

        Args:
            dataset: The object to sample from; it must provide
                ``__real_len__()``. Its ``sampler`` attribute is overwritten
                with this instance.
        """
        # The sampler MUST be registered on the dataset at this point.
        dataset.sampler = self
        self.dataset = dataset

    def __len__(self):
        """Return the number of indices, i.e. ``dataset.__real_len__()``."""
        return self.dataset.__real_len__()

    def __iter__(self):
        """Return an iterator over ``0 .. dataset.__real_len__() - 1``."""
        total = self.dataset.__real_len__()
        return iter(range(total))
class RandomSampler(Sampler):
    """Sampler that yields dataset indices in random order.

    With ``replacement=True`` every index is drawn independently and
    uniformly from ``[0, dataset.__real_len__())``. With ``replacement=False``
    the indices come from random permutations of the whole index range.

    In both modes exactly ``num_samples`` indices are produced, so the number
    of items yielded by ``__iter__`` always agrees with ``__len__``.
    """

    def __init__(self, dataset, replacement=False, num_samples=None):
        """Attach this sampler to ``dataset``.

        Args:
            dataset: The object to sample from; it must provide
                ``__real_len__()``. Its ``sampler`` attribute is overwritten
                with this instance.
            replacement: If true, draw indices with replacement.
            num_samples: Number of indices to draw per iteration. ``None``
                means "as many as ``dataset.__real_len__()``".
        """
        # The sampler MUST be registered on the dataset at this point.
        dataset.sampler = self
        self.dataset = dataset
        self.rep = replacement
        self._num_samples = num_samples

    @property
    def num_samples(self):
        """Number of indices drawn per iteration.

        Resolved lazily so that it follows the current dataset length when
        no explicit ``num_samples`` was given.
        """
        if self._num_samples is None:
            return self.dataset.__real_len__()
        return self._num_samples

    def __len__(self):
        """Return the number of indices produced by one iteration."""
        return self.num_samples

    def __iter__(self):
        """Return an iterator over ``num_samples`` randomly chosen indices."""
        n = self.dataset.__real_len__()
        if self.rep:
            return iter(np.random.randint(low=0, high=n, size=(self.num_samples,), dtype=np.int64).tolist())

        if n == 0:
            # Nothing to permute; also prevents the loop below from spinning
            # forever when an explicit num_samples was requested.
            return iter([])

        wanted = self.num_samples
        indices = []
        # Concatenate fresh permutations until enough indices are available.
        # In the default case (wanted == n) this draws exactly one
        # permutation, the same single RNG call as before.
        while len(indices) < wanted:
            indices.extend(np.random.permutation(n).tolist())
        # Trim the surplus so the iterator length matches __len__.
        del indices[wanted:]
        return iter(indices)
class SubsetRandomSampler(Sampler):
    """Sampler that yields a contiguous index range in random order.

    The range is half-open: ``indice = (start, end)`` produces every index
    ``start, start + 1, ..., end - 1`` exactly once, shuffled.

    Example::

        testdataset = TestSamplerDataset()
        subsetsampler = SubsetRandomSampler(testdataset, (20, 30))

        for i, data in enumerate(testdataset):
            # data between 20 ~ 29
            ......
    """

    def __init__(self, dataset, indice):
        """Attach this sampler to ``dataset`` and store the index range.

        Args:
            dataset: The object to sample from; it must provide
                ``__real_len__()``. Its ``sampler`` attribute is overwritten
                with this instance.
            indice: Indexable pair ``(start, end)`` with
                ``0 <= start < end <= dataset.__real_len__()``; ``end`` is
                exclusive.

        Raises:
            AssertionError: If the range is empty, reversed, or out of bounds.
        """
        # The sampler MUST be registered on the dataset at this point.
        dataset.sampler = self
        self.dataset = dataset
        self.indices = indice
        # `end` is exclusive, so it may equal the dataset length; a strict
        # `<` here would make the last element impossible to sample.
        assert indice[0] >= 0 and indice[1] <= dataset.__real_len__() and indice[0] < indice[1], \
            "indice must satisfy 0 <= indice[0] < indice[1] <= dataset.__real_len__()"

    def __iter__(self):
        """Return an iterator over the indices in ``[start, end)``, shuffled."""
        # Permute the offsets 0 .. (end - start - 1), then shift by `start`.
        return (int(i) + self.indices[0] for i in np.random.permutation(self.indices[1] - self.indices[0]))

    def __len__(self):
        """Return the size of the index range, ``end - start``."""
        return self.indices[1] - self.indices[0]
class BatchSampler(Sampler):
    """Group the indices produced by another sampler into lists.

    Each yielded batch is a list of ``batch_size`` indices, except possibly
    the last one, which may be shorter and is emitted only when ``drop_last``
    is false.
    """

    def __init__(self, sampler, batch_size, drop_last):
        """Store the wrapped sampler and the batching options.

        Args:
            sampler: Iterable of indices that also supports ``len()``.
            batch_size: Number of indices per batch.
            drop_last: If true, discard a trailing batch that is shorter
                than ``batch_size``.
        """
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        """Yield lists of indices taken in order from the wrapped sampler."""
        pending = []
        for index in self.sampler:
            pending.append(index)
            if len(pending) == self.batch_size:
                yield pending
                # Start a new list; the one just yielded belongs to the consumer.
                pending = []
        # Emit the leftover partial batch unless asked to drop it.
        if pending and not self.drop_last:
            yield pending

    def __len__(self):
        """Return the number of batches one iteration produces."""
        total = len(self.sampler)
        if self.drop_last:
            # Only full batches count.
            return total // self.batch_size
        # Round up so a trailing partial batch is counted.
        return (total + self.batch_size - 1) // self.batch_size