Source code for jittor.dataset.dataset

# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved. 
# Maintainers: 
#     Meng-Hao Guo <guomenghao1997@gmail.com>
#     Dun Liang <randonlang@gmail.com>. 
# 
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import numpy as np
from urllib import request
import gzip
import pickle
import os
from jittor.dataset.utils import get_random_list, get_order_list, collate_batch, HookTimer
from collections.abc import Sequence, Mapping
import pathlib
from PIL import Image
import multiprocessing as mp
import signal
from jittor_utils import LOG
import jittor as jt
import time
import jittor_utils as jit_utils

dataset_root = os.path.join(jit_utils.home(), ".cache", "jittor", "dataset")
mp_log_v = os.environ.get("mp_log_v", 0) 
mpi = jt.mpi
img_open_hook = HookTimer(Image, "open")
CHECK_MEMORY = int(os.environ.get("CHECK_MEMORY", "0"))

if os.name == "nt":
    from multiprocessing import shared_memory
    class RingBuffer:
        def __init__(self, size, shm=None):
            # round the requested buffer size up to the next power of two
            for i in range(100):
                if (1<<i) >= size: break
            size = 1<<i
            init = False
            if shm is None:
                init = True
                shm = shared_memory.SharedMemory(create=True, size=size+1024)
            rb = jt.core.RingBuffer(size, id(shm.buf), init)
            self.size = size
            self.shm = shm
            self.rb = rb

        def __reduce__(self):
            return (RingBuffer, (self.size, self.shm))
            
        def __del__(self):
            del self.rb
            del self.shm

        def push(self, obj): self.send(obj)
        def pop(self): return self.recv()
        def send(self, obj): self.rb.push(obj)
        def recv(self): return self.rb.pop()
        def clear(self): return self.rb.clear()
        def stop(self): return self.rb.stop()
        def is_stop(self): return self.rb.is_stop()
        def total_pop(self): return self.rb.total_pop()
        def total_push(self): return self.rb.total_push()
        def __repr__(self): return repr(self.rb)
        def keep_numpy_array(self, keep): self.rb.keep_numpy_array(keep)

    jt.RingBuffer = RingBuffer
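
# Illustrative sketch only, not part of the original module: how the shared
# ring buffer is typically exercised. It assumes jt.RingBuffer is available on
# the current platform (on non-Windows builds it comes from jittor's core
# instead of the wrapper above); _ring_buffer_usage_sketch is a hypothetical name.
def _ring_buffer_usage_sketch():
    buf = jt.RingBuffer(1 << 20)           # requested size is rounded up to a power of two
    buf.keep_numpy_array(True)             # pop numpy arrays instead of converting to jt.Var
    buf.push(np.arange(8, dtype='int32'))  # serialize the payload into the buffer
    return buf.pop()                       # deserialize it back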

class Worker:
    def __init__(self, target, args, buffer_size, keep_numpy_array=False):
        self.buffer = jt.RingBuffer(buffer_size)
        self.buffer.keep_numpy_array(keep_numpy_array)

        self.status = mp.Array('f', 5, lock=False)
        self.p = mp.Process(target=target, args=args+(self.buffer,self.status))
        self.p.daemon = True
        self.p.start()
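
# Illustrative sketch only, not part of the original module: how the Worker
# wrapper above is used. The target runs in a child process and receives the
# worker's ring buffer and its 5-float status array as trailing arguments; the
# parent pops results from worker.buffer. The names produce and
# _worker_usage_sketch are hypothetical, and the nested target assumes a
# fork-based multiprocessing start method (on spawn-based platforms the target
# must be picklable).
def _worker_usage_sketch():
    def produce(tag, buffer, status):
        # child process: push one payload and record a dummy timing entry
        buffer.push(np.arange(4, dtype='int32'))
        status[0] = 0.0
    w = Worker(target=produce, args=("demo",),
               buffer_size=1 << 20, keep_numpy_array=True)
    item = w.buffer.pop()   # parent process: blocks until the child pushes
    w.p.join()
    return item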

class Dataset(object):
    '''
    Base class for datasets. Users should inherit from this class and
    implement ``__getitem__`` in order to iterate over their data.

    Args:

    - batch_size (int, optional): batch size. Default: 16
    - shuffle (bool, optional): whether to shuffle the data at every epoch. Default: False
    - drop_last (bool, optional): if True, the last incomplete batch of an epoch may be dropped. Default: False
    - num_workers (int, optional): number of worker processes used to load data. Default: 0
    - buffer_size (int, optional): buffer size in bytes for each worker process. Default: 512MB (512 * 1024 * 1024)
    - keep_numpy_array (bool, optional): return numpy arrays instead of jittor arrays. Default: False
    - endless (bool, optional): whether this dataset yields data endlessly. Default: False

    Attributes:

    - total_len (int): total length of the dataset
    - epoch_id (int): current epoch id

    Example::

        >>> class YourDataset(Dataset):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.set_attrs(total_len=1024)
        >>>
        >>>     def __getitem__(self, k):
        >>>         # logic that returns one sample
        >>>         return k, k*k
        >>> # instantiate your dataset, set the batch size and whether to shuffle
        >>> dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)
        >>> # iterate over the dataset
        >>> for x, y in dataset:
        >>>     # logic that consumes the data
    '''
    def __init__(self,
                 batch_size = 16,
                 shuffle = False,
                 drop_last = False,
                 num_workers = 0,
                 buffer_size = 512*1024*1024,
                 stop_grad = True,
                 keep_numpy_array = False,
                 endless = False):
        super().__init__()
        if os.environ.get("DISABLE_MULTIPROCESSING", '0') == '1':
            num_workers = 0
        self.total_len = None
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.num_workers = num_workers
        self.buffer_size = buffer_size
        self.stop_grad = stop_grad
        self.keep_numpy_array = keep_numpy_array
        self.endless = endless
        self.epoch_id = 0
        self.sampler = None
        self._disable_workers = False
        self._shuffle_rng = np.random.default_rng(1)
        self.dataset = self

    def __getitem__(self, index):
        raise NotImplementedError

    def __batch_len__(self):
        assert self.total_len >= 0
        assert self.batch_size > 0
        if self.drop_last:
            return self.total_len // self.batch_size
        return (self.total_len-1) // self.batch_size + 1

    def __len__(self):
        return self.__batch_len__()

    def set_attrs(self, **kw):
        '''
        Set attributes of the dataset.

        Args:

        - **kw (dict): attributes of the dataset to set.

        Returns:

        - Dataset: the dataset itself, with the attributes applied.

        Example::

            >>> dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)
        '''
        for k,v in kw.items():
            assert hasattr(self, k), k
            setattr(self, k, v)
        self.reset()
        return self

    def to_jittor(self, batch):
        '''
        Convert a (possibly nested) list of input variables into Jittor variables.

        Args:

        - batch (list): the list of input variables to convert.

        Returns:

        - jt.Var: the converted Jittor variables.

        Example::

            >>> batch = [np.array([1,2,3]), np.array([4,5,6])]
            >>> to_jittor(batch)
            >>> [jt.Var([1, 2, 3]), jt.Var([4, 5, 6])]
        '''
        if self.keep_numpy_array:
            return batch
        if isinstance(batch, jt.Var):
            return batch
        to_jt = lambda x: jt.array(x).stop_grad() \
            if self.stop_grad else jt.array(x)
        if isinstance(batch, np.ndarray):
            return to_jt(batch)
        if isinstance(batch, dict):
            new_batch = {}
            for k,v in batch.items():
                new_batch[k] = self.to_jittor(v)
            return new_batch
        if not isinstance(batch, (list, tuple)):
            return batch
        new_batch = []
        for a in batch:
            if isinstance(a, np.ndarray):
                new_batch.append(to_jt(a))
            else:
                new_batch.append(self.to_jittor(a))
        return new_batch

    def collate_batch(self, batch):
        '''
        Collate a list of samples into a batch. By default the corresponding
        fields of the samples are stacked into batched Jittor variables.

        Args:

        - batch (list): the list of samples to collate.

        Returns:

        - jt.Var: the collated batch.

        Example::

            >>> batch = [np.array([1,2,3]), np.array([4,5,6])]
            >>> collate_batch(batch)
            >>> jt.Var([[1, 2, 3], [4, 5, 6]])
        '''
        return collate_batch(batch)

    def terminate(self):
        '''
        Terminate the dataset's worker processes.

        Example::

            >>> dataset.terminate()
        '''
        if hasattr(self, "workers"):
            for w in self.workers:
                w.p.terminate()

    def _worker_main(self, worker_id, buffer, status):
        import jittor_utils
        jt.flags.use_cuda_host_allocator = 0
        jittor_utils.cc.init_subprocess()
        jt.jt_init_subprocess()
        seed = jt.get_seed()
        wseed = (seed ^ (worker_id*1167)) ^ 1234
        jt.set_global_seed(wseed)
        # the parallel op compiler is still problematic:
        # it does not work on ubuntu 16.04 but works on ubuntu 20.04;
        # it seems the static state of the parallel compiler is not
        # correctly initialized, so disable it in workers.
        jt.flags.use_parallel_op_compiler = 0
        import time
        try:
            gid_obj = self.gid.get_obj()
            gid_lock = self.gid.get_lock()
            start = time.time()
            while True:
                # get the id of the next batch to produce
                with gid_lock:
                    while buffer.is_stop() or self.idqueue.is_stop() or \
                        gid_obj.value >= self.batch_len:
                        self.num_idle.value += 1
                        self.num_idle_c.notify()
                        self.gidc.wait()
                        self.num_idle.value -= 1
                    cid = gid_obj.value
                    batch_index_list = self.index_list_numpy[
                        cid*self.real_batch_size:
                        min(self.real_len, (cid+1)*self.real_batch_size)].copy()
                    gid_obj.value += 1
                    with self.idqueue_lock:
                        self.idqueue.push(worker_id)
                now = time.time()
                other_time = now - start
                start = now

                # load and transform data
                batch = []
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} load batch",
                          cid*self.real_batch_size,
                          min(self.real_len, (cid+1)*self.real_batch_size))
                for i in batch_index_list:
                    batch.append(self[i])
                batch = self.collate_batch(batch)
                now = time.time()
                data_time = now - start
                start = now

                # send data to the main process
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} send",
                          type(batch).__name__,
                          [ type(b).__name__ for b in batch ], buffer)
                try:
                    buffer.send(batch)
                except:
                    if buffer.is_stop():
                        continue
                    raise
                now = time.time()
                send_time = now - start
                start = now
                status[0], status[1], status[2], status[3], status[4] = \
                    other_time, data_time, send_time, \
                    other_time + data_time + send_time, \
                    img_open_hook.duration
                img_open_hook.duration = 0.0
        except:
            import traceback
            line = traceback.format_exc()
            print(line)
            os.kill(os.getppid(), signal.SIGINT)
            exit(0)

    def display_worker_status(self):
        if not hasattr(self, "workers"):
            return
        msg = [""]
        msg.append(f"progress:{self.batch_id}/{self.batch_len}")
        msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}")
        msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}")
        msg.append(f"last 10 workers: {self.last_ids}")
        msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)")
        for i in range(self.num_workers):
            w = self.workers[i]
            s = w.status
            msg.append(f"#{i}\t{s[0]:.3f}\t{s[4]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer}")
        LOG.i('\n'.join(msg))

    def _stop_all_workers(self):
        # stop workers
        for w in self.workers:
            w.buffer.stop()
        self.idqueue.stop()
        # wait until all workers are idle
        if self.num_idle.value < self.num_workers:
            with self.gid.get_lock():
                self.gid.get_obj().value = self.batch_len
                if mp_log_v:
                    print("idle num", self.num_idle.value)
                while self.num_idle.value < self.num_workers:
                    self.num_idle_c.wait()
                    if mp_log_v:
                        print("idle num", self.num_idle.value)
        # clear the workers' buffers
        for w in self.workers:
            w.buffer.clear()
        self.idqueue.clear()
        self.gid.value = 0

    def _init_workers(self, index_list):
        jt.migrate_all_to_cpu()
        jt.clean()
        jt.gc()
        self.index_list = mp.Array('i', self.real_len, lock=False)
        workers = []
        # queue of worker ids that have a batch ready
        self.idqueue = jt.RingBuffer(2048)
        self.idqueue_lock = mp.Lock()
        # global token index
        self.gid = mp.Value('i', self.batch_len)
        self.gid.value = 0
        # global token index condition
        self.gidc = mp.Condition(self.gid.get_lock())
        # number of idle workers
        self.num_idle = mp.Value('i', 0, lock=False)
        # number of idle workers condition
        self.num_idle_c = mp.Condition(self.gid.get_lock())
        self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len,
                                           buffer=self.index_list)
        self.index_list_numpy[:] = index_list
        for i in range(self.num_workers):
            w = Worker(target=self._worker_main, args=(i,),
                       buffer_size=self.buffer_size,
                       keep_numpy_array=self.keep_numpy_array)
            workers.append(w)
        self.workers = workers

    def reset(self):
        if not hasattr(self, "workers"):
            return
        self._stop_all_workers()
        self.terminate()
        del self.index_list
        del self.idqueue
        del self.idqueue_lock
        del self.gid
        del self.gidc
        del self.num_idle
        del self.num_idle_c
        del self.workers
        del self.index_list_numpy

    def __del__(self):
        if mp_log_v:
            print("dataset deleted")
        try:
            self.terminate()
        except:
            pass

    def __deepcopy__(self, memo=None, _nil=[]):
        from copy import deepcopy
        if memo is None:
            memo = {}
        d = id(self)
        y = memo.get(d, _nil)
        if y is not _nil:
            return y
        obj = self.__class__.__new__(self.__class__)
        memo[d] = id(obj)
        exclude_key = {"index_list", "idqueue", "idqueue_lock",
                       "gid", "gidc", "num_idle", "num_idle_c",
                       "workers", "index_list_numpy", "dataset"}
        for k,v in self.__dict__.items():
            if k in exclude_key:
                continue
            obj.__setattr__(k, deepcopy(v))
        obj.dataset = obj
        return obj

    def __real_len__(self):
        if self.total_len is None:
            self.total_len = len(self)
        return self.total_len

    def _get_index_list(self):
        if self.total_len is None:
            self.total_len = len(self)
        # total_len may be overridden by the sampler
        total_len = self.total_len
        if self.sampler:
            index_list = list(self.sampler.__iter__())
            total_len = len(index_list)
            # check that this is not a batch sampler
            if len(index_list):
                assert not isinstance(index_list[0], (list, tuple)), \
                    "Batch sampler not supported yet."
        elif self.shuffle == False:
            index_list = get_order_list(self.total_len)
        else:
            # use _shuffle_rng to generate a shuffle order that is
            # consistent across processes
            # index_list = get_random_list(self.total_len)
            index_list = self._shuffle_rng.permutation(range(self.total_len))

        # scatter index_list over all mpi processes
        # scatter rule:
        #    batch 1      batch 2
        #   [........]   [........] ...
        #    00011122     00011122
        # if the last batch is smaller than world_size,
        # pad it to world_size:
        #    last batch
        #   [.] -> [012]
        if jt.in_mpi:
            world_size = mpi.world_size()
            world_rank = mpi.world_rank()
            index_list = np.int32(index_list)
            # TODO: mpi broadcast in subprocess has a bug, fix it
            # mpi.broadcast(index_list, 0)

            assert self.batch_size >= world_size, \
                f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})"
            real_batch_size = (self.batch_size-1) // world_size + 1
            if real_batch_size * world_size != self.batch_size:
                LOG.w("Batch size is not divisible by MPI world size, "
                      "the distributed version may differ from "
                      "the single-process version.")
            fix_batch = total_len // self.batch_size
            last_batch = total_len - fix_batch * self.batch_size
            fix_batch_l = index_list[0:fix_batch*self.batch_size] \
                .reshape(-1, self.batch_size)
            fix_batch_l = fix_batch_l[
                :, real_batch_size*world_rank:real_batch_size*(world_rank+1)]
            real_batch_size = fix_batch_l.shape[1]
            fix_batch_l = fix_batch_l.flatten()
            if not self.drop_last and last_batch > 0:
                last_batch_l = index_list[-last_batch:]
                real_last_batch = (last_batch-1)//world_size+1
                l = real_last_batch * world_rank
                r = l + real_last_batch
                if r > last_batch: r = last_batch
                l = r - real_last_batch
                index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]])
            else:
                index_list = fix_batch_l
            self.real_len = len(index_list)
            self.real_batch_size = real_batch_size
        else:
            self.real_len = len(index_list)
            self.real_batch_size = self.batch_size

        if self.drop_last:
            self.batch_len = self.real_len // self.real_batch_size
        else:
            self.batch_len = (self.real_len-1) // self.real_batch_size + 1
        return index_list

    def _epochs(self):
        if self.endless:
            while True:
                yield
                self.epoch_id += 1
        else:
            yield

    def __iter__(self):
        if self._disable_workers:
            self.num_workers = 0
        index_list = self._get_index_list()
        if not hasattr(self, "workers") and self.num_workers:
            self._init_workers(index_list)
        self.last_ids = [-1] * 10
        if self.num_workers:
            start = time.time()
            self.batch_time = 0
            gid_obj = self.gid.get_obj()
            gid_lock = self.gid.get_lock()
            for _ in self._epochs():
                with gid_lock:
                    if self.num_idle.value:
                        self.gidc.notify_all()
                for i in range(self.batch_len):
                    # if workers went idle after finishing the previous epoch,
                    # refill the index list and wake them up
                    if self.num_idle.value:
                        with gid_lock:
                            if self.num_idle.value and \
                                gid_obj.value >= self.batch_len:
                                index_list = self._get_index_list()
                                self.index_list_numpy[:] = index_list
                                gid_obj.value = 0
                                self.gidc.notify_all()
                    # get which worker has this batch
                    worker_id = self.idqueue.pop()
                    now = time.time()
                    self.wait_time = now - start
                    start = now

                    self.last_ids[i%10] = worker_id
                    self.batch_id = i
                    w = self.workers[worker_id]
                    if mp_log_v:
                        print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
                    batch = w.buffer.recv()
                    now = time.time()
                    self.recv_time = now - start
                    start = now

                    if mp_log_v:
                        print(f"#{worker_id} {os.getpid()} recv",
                              type(batch).__name__,
                              [ type(b).__name__ for b in batch ])
                    batch = self.to_jittor(batch)
                    now = time.time()
                    self.to_jittor_time = now - start
                    start = now

                    yield batch
                    now = time.time()
                    self.batch_time = now - start
                    start = now
                    if CHECK_MEMORY and self.batch_id % CHECK_MEMORY == 0:
                        jt.display_memory_info()
        else:
            for _ in self._epochs():
                self.batch_id = 0
                batch_data = []
                for idx in index_list:
                    batch_data.append(self[int(idx)])
                    if len(batch_data) == self.real_batch_size:
                        batch_data = self.collate_batch(batch_data)
                        batch_data = self.to_jittor(batch_data)
                        yield batch_data
                        self.batch_id += 1
                        if CHECK_MEMORY and self.batch_id % CHECK_MEMORY == 0:
                            jt.display_memory_info()
                        batch_data = []

                # whether the trailing partial batch is yielded depends on drop_last
                if not self.drop_last and len(batch_data) > 0:
                    batch_data = self.collate_batch(batch_data)
                    batch_data = self.to_jittor(batch_data)
                    self.batch_id += 1
                    yield batch_data
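
# Illustrative sketch only, not part of the original module. Two pieces of the
# arithmetic above, worked through on small numbers:
# * __batch_len__: with total_len=10 and batch_size=4 the dataset yields
#   10 // 4 = 2 batches when drop_last is True, and ceil(10/4) = 3 otherwise.
# * MPI scatter in _get_index_list: with 10 samples, batch_size=4 and
#   world_size=2, the full batches [[0,1,2,3],[4,5,6,7]] are split column-wise
#   per rank and the short last batch [8,9] is divided as well, giving
#   rank 0 -> [0,1,4,5,8] and rank 1 -> [2,3,6,7,9].
def _batch_len_sketch(total_len, batch_size, drop_last):
    # mirrors Dataset.__batch_len__: floor division with drop_last, ceiling otherwise
    if drop_last:
        return total_len // batch_size
    return (total_len - 1) // batch_size + 1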

def DataLoader(dataset: Dataset, *args, **kargs):
    '''
    DataLoader combines a dataset and a sampler and provides iteration over
    the given dataset.

    Args:

    - dataset (Dataset): the dataset to wrap.
    - args: variable positional arguments used to set dataset attributes.
    - kargs: variable keyword arguments used to set dataset attributes.

    Returns:

    - the wrapped data loader (``DataLoader``).

    Example::

        >>> from jittor.dataset.cifar import CIFAR10
        >>> from jittor.dataset import DataLoader
        >>> train_dataset = CIFAR10()
        >>> dataloader = DataLoader(train_dataset, batch_size=8)
        >>> for batch_idx, (x_, target) in enumerate(dataloader):
        >>>     # process each batch
    '''
    return dataset.set_attrs(*args, **kargs)

class ImageFolder(Dataset):
    '''
    A dataset that loads images and their labels from a directory, for image
    classification.

    The directory structure should look like this:

    * root/label1/img1.png
    * root/label1/img2.png
    * ...
    * root/label2/img1.png
    * root/label2/img2.png
    * ...

    Args:

    - root (str): path to the root directory that contains one subdirectory of images per label.
    - transform (callable, optional): an optional transform applied to each sample (e.g. data augmentation). Default: None

    Attributes:

    - classes (list): list of class names
    - class_to_idx (dict): mapping from class name to class index
    - imgs (list): list of (image_path, class_index) tuples

    Example::

        >>> from jittor.dataset import ImageFolder
        >>> dataset = ImageFolder(root="path_to_your_dataset")
    '''
    def __init__(self, root, transform=None):
        super().__init__()
        self.root = root
        self.transform = transform
        self.classes = sorted([d.name for d in os.scandir(root) if d.is_dir()])
        self.class_to_idx = {v:k for k,v in enumerate(self.classes)}
        self.imgs = []
        image_exts = set(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'))

        for i, class_name in enumerate(self.classes):
            class_dir = os.path.join(root, class_name)
            for dname, _, fnames in sorted(os.walk(class_dir, followlinks=True)):
                for fname in sorted(fnames):
                    if os.path.splitext(fname)[-1].lower() in image_exts:
                        path = os.path.join(class_dir, fname)
                        self.imgs.append((path, i))
        LOG.i(f"Found {len(self.classes)} classes and {len(self.imgs)} images.")
        self.set_attrs(total_len=len(self.imgs))

    def __getitem__(self, k):
        with open(self.imgs[k][0], 'rb') as f:
            img = Image.open(f).convert('RGB')
            if self.transform:
                img = self.transform(img)
            return img, self.imgs[k][1]

class VarDataset(Dataset):
    '''
    A dataset built directly from Var objects; TensorDataset is an alias of
    VarDataset. This class lets users create a dataset directly from Jittor
    variables without any preprocessing. Each element of the dataset is taken
    from the given Jittor variables at the corresponding index. The first
    dimension of all input variables must have the same length, otherwise an
    error is raised when the VarDataset is created.

    Args:

    - *args (jt.Var): one or more Jittor variables. All variables must have
      the same length and may have any number of dimensions. They are indexed
      in parallel to form the entries of the dataset.

    Example::

        >>> import jittor as jt
        >>> from jittor.dataset import VarDataset
        >>> x = jt.array([1,2,3])
        >>> y = jt.array([4,5,6])
        >>> z = jt.array([7,8,9])
        >>> dataset = VarDataset(x, y, z)
        >>> dataset.set_attrs(batch_size=1)
        >>> for a, b, c in dataset:
        >>>     print(a, b, c)
        >>> # 1, 4, 7
        >>> # 2, 5, 8
        >>> # 3, 6, 9
    '''
    def __init__(self, *args):
        super().__init__()
        self.args = args
        self._disable_workers = True
        assert len(args), "At least one arg is required."
        l = len(args[0])
        for a in args:
            assert l == len(a), "All args should have the same length."
        self.set_attrs(total_len=l)

    def __getitem__(self, idx):
        return [ a[idx] for a in self.args ]

    def collate_batch(self, batch):
        '''
        Collate a list of Jittor variables into a batched Jittor variable.

        Args:

        - batch (list of jt.Var): the list of input variables to collate.

        Returns:

        - jt.Var: the collated batch.

        Example::

            >>> batch = [jt.array([1,2,3]), jt.array([4,5,6])]
            >>> collate_batch(batch)
            >>> jt.Var([[1, 2, 3], [4, 5, 6]])
        '''
        b = collate_batch(batch)
        for i in range(len(self.args)):
            x = b[i]
            if jt.is_var(self.args[i]) and self.args[i].ndim == 1:
                x.assign(x.squeeze(-1))
        return b

TensorDataset = VarDataset
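
# Illustrative sketch only, not part of the original module: TensorDataset /
# VarDataset with a batch size larger than 1. For 1-D input vars the overridden
# collate_batch squeezes the trailing dimension, so each yielded field should
# have shape (batch_size,) rather than (batch_size, 1). The function name is
# hypothetical.
def _tensor_dataset_usage_sketch():
    x = jt.array([1, 2, 3, 4])
    y = jt.array([10, 20, 30, 40])
    dataset = TensorDataset(x, y).set_attrs(batch_size=2)
    for a, b in dataset:
        # two batches: (a, b) = ([1,2], [10,20]) then ([3,4], [30,40])
        print(a, b)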