# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers:
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import numpy as np
from urllib import request
import gzip
import pickle
import os
from jittor.dataset.utils import get_random_list, get_order_list, collate_batch, HookTimer
from collections.abc import Sequence, Mapping
import pathlib
from PIL import Image
import multiprocessing as mp
import signal
from jittor_utils import LOG
import jittor as jt
import time
import jittor_utils as jit_utils
dataset_root = os.path.join(jit_utils.home(), ".cache", "jittor", "dataset")
mp_log_v = os.environ.get("mp_log_v", 0)
mpi = jt.mpi
img_open_hook = HookTimer(Image, "open")
CHECK_MEMORY = int(os.environ.get("CHECK_MEMORY", "0"))
if os.name == "nt":
    # Windows has no fork(): worker processes are spawned, so the ring
    # buffer must live in named shared memory that can be pickled and
    # re-attached inside the child process.
    from multiprocessing import shared_memory
    class RingBuffer:
        '''Windows replacement for the fork-based jt.RingBuffer, backed by
        multiprocessing.shared_memory so it survives process spawn.'''
        def __init__(self, size, shm=None):
            # Round the requested size up to the next power of two.
            for i in range(100):
                if (1<<i) >= size: break
            size = 1<<i
            init = False
            if shm is None:
                # Creating side: allocate the segment (+1024 bytes for the
                # ring buffer's control header).
                init = True
                shm = shared_memory.SharedMemory(create=True, size=size+1024)
            rb = jt.core.RingBuffer(size, id(shm.buf), init)
            self.size = size
            self.shm = shm
            self.rb = rb
        def __reduce__(self):
            # Pickle as (size, shm): the child re-attaches to the same
            # shared memory instead of copying the buffer contents.
            return (RingBuffer, (self.size, self.shm))
        def __del__(self):
            # Drop the core ring buffer before releasing the shared memory
            # it points into.
            del self.rb
            del self.shm
        # Thin delegating API mirroring jt.core.RingBuffer.
        def push(self, obj): self.send(obj)
        def pop(self): return self.recv()
        def send(self, obj): self.rb.push(obj)
        def recv(self): return self.rb.pop()
        def clear(self): return self.rb.clear()
        def stop(self): return self.rb.stop()
        def is_stop(self): return self.rb.is_stop()
        def total_pop(self): return self.rb.total_pop()
        def total_push(self): return self.rb.total_push()
        def __repr__(self): return repr(self.rb)
        def keep_numpy_array(self, keep): self.rb.keep_numpy_array(keep)
    jt.RingBuffer = RingBuffer
class Worker:
    '''One dataset worker process plus its IPC channels: a ring buffer for
    collated batches and a small shared float array for timing stats.'''
    def __init__(self, target, args, buffer_size, keep_numpy_array=False):
        # Channel the worker pushes finished batches into.
        self.buffer = jt.RingBuffer(buffer_size)
        self.buffer.keep_numpy_array(keep_numpy_array)
        # status: [other, load, send, total, img_open] seconds, written by
        # the worker and read by Dataset.display_worker_status().
        self.status = mp.Array('f', 5, lock=False)
        self.p = mp.Process(target=target, args=args+(self.buffer,self.status))
        # Daemonize so workers die with the main process.
        self.p.daemon = True
        self.p.start()
# [docs] -- stray doc-site link artifact neutralized
class Dataset(object):
    '''Abstract dataset class. Subclass it and implement ``__getitem__``
    so the data can be iterated in batches.

    Args:
        batch_size (int, optional): batch size. Default: 16
        shuffle (bool, optional): shuffle samples every epoch. Default: False
        drop_last (bool, optional): if True, the final incomplete batch of
            each epoch may be dropped. Default: False
        num_workers (int, optional): number of worker processes used to
            load data. Default: 0
        buffer_size (int, optional): per-worker buffer size in bytes.
            Default: 512MB (512 * 1024 * 1024)
        keep_numpy_array (bool, optional): return numpy arrays instead of
            jittor Vars. Default: False
        endless (bool, optional): whether this dataset yields data
            endlessly (epoch after epoch). Default: False

    Attributes:
        total_len (int): total number of samples in the dataset.
        epoch_id (int): current epoch number.

    Example::

        >>> class YourDataset(Dataset):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.set_attrs(total_len=1024)
        >>>
        >>>     def __getitem__(self, k):
        >>>         # your data loading logic
        >>>         return k, k*k
        >>>
        >>> # instantiate, set batch size and shuffling
        >>> dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)
        >>> # iterate over the dataset
        >>> for x, y in dataset:
        >>>     # your processing logic
    '''
    def __init__(self,
                 batch_size = 16,
                 shuffle = False,
                 drop_last = False,
                 num_workers = 0,
                 buffer_size = 512*1024*1024,
                 stop_grad = True,
                 keep_numpy_array = False,
                 endless = False):
        super().__init__()
        # Debug escape hatch: force single-process loading via env var.
        if os.environ.get("DISABLE_MULTIPROCESSING", '0') == '1':
            num_workers = 0
        # total_len stays None until the subclass sets it (set_attrs) or it
        # is derived from len(self) lazily.
        self.total_len = None
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.num_workers = num_workers
        self.buffer_size = buffer_size
        self.stop_grad = stop_grad
        self.keep_numpy_array = keep_numpy_array
        self.endless = endless
        self.epoch_id = 0
        self.sampler = None
        # Set True by subclasses (e.g. VarDataset) whose samples must be
        # loaded in the main process.
        self._disable_workers = False
        # Fixed-seed RNG so every (MPI/worker) process produces the same
        # shuffle order.
        self._shuffle_rng = np.random.default_rng(1)
        # Self reference; re-pointed at the copy in __deepcopy__.
        self.dataset = self
    def __getitem__(self, index):
        # Subclasses must return the sample at `index`.
        raise NotImplementedError
def __batch_len__(self):
assert self.total_len >= 0
assert self.batch_size > 0
if self.drop_last:
return self.total_len // self.batch_size
return (self.total_len-1) // self.batch_size + 1
    def __len__(self):
        # len(dataset) is the number of batches, not the number of samples.
        return self.__batch_len__()
    # [docs] -- stray doc-site link artifact neutralized
def set_attrs(self, **kw):
'''
设置数据集的属性。
参数:
- **kw (dict): 数据集的属性字典。
返回:
- Dataset: 返回设置了属性的数据集对象
代码示例:
>>> dataset = YourDataset().set_attrs(batch_size=256, shuffle=True)
'''
for k,v in kw.items():
assert hasattr(self, k), k
setattr(self, k, v)
self.reset()
return self
    # [docs] -- stray doc-site link artifact neutralized
def to_jittor(self, batch):
'''
用于将输入变量的列表转换为 Jittor 变量的方法。
参数:
- batch (list): 需要转换的输入变量的列表。
返回值:
- jt.Var: 一个包含输入变量的列表的 Jittor 变量。
代码示例:
>>> batch = [np.array([1,2,3]), np.array([4,5,6])]
>>> to_jittor(batch)
>>> [jt.Var([1, 2, 3]), jt.Var([4, 5, 6])]
'''
if self.keep_numpy_array: return batch
if isinstance(batch, jt.Var): return batch
to_jt = lambda x: jt.array(x).stop_grad() \
if self.stop_grad else jt.array(x)
if isinstance(batch, np.ndarray):
return to_jt(batch)
if isinstance(batch, dict):
new_batch = {}
for k,v in batch.items():
new_batch[k] = self.to_jittor(v)
return new_batch
if not isinstance(batch, (list, tuple)):
return batch
new_batch = []
for a in batch:
if isinstance(a, np.ndarray):
new_batch.append(to_jt(a))
else:
new_batch.append(self.to_jittor(a))
return new_batch
    # [docs] -- stray doc-site link artifact neutralized
    def collate_batch(self, batch):
        '''Pack a list of samples into one batch.

        Delegates to :func:`jittor.dataset.utils.collate_batch`, which
        stacks the samples along a new first dimension. Override in a
        subclass for custom batching behaviour.

        Args:
            batch (list): samples as returned by ``__getitem__``.

        Returns:
            The collated batch.
        '''
        return collate_batch(batch)
    # [docs] -- stray doc-site link artifact neutralized
def terminate(self):
'''
终止数据集的工作进程。
代码示例:
>>> dataset.terminate()
'''
if hasattr(self, "workers"):
for w in self.workers:
w.p.terminate()
    def _worker_main(self, worker_id, buffer, status):
        '''Entry point of a worker process.

        Repeatedly claims the next batch id from the shared counter,
        loads and collates that batch, and pushes it into this worker's
        ring buffer; loops until the parent stops the buffers.

        Args:
            worker_id (int): index of this worker.
            buffer (RingBuffer): channel for sending batches to the main
                process.
            status: shared float array [other, load, send, total,
                img_open] seconds, read by display_worker_status().
        '''
        import jittor_utils
        jt.flags.use_cuda_host_allocator = 0
        jittor_utils.cc.init_subprocess()
        jt.jt_init_subprocess()
        # Derive a per-worker seed so workers don't produce identical
        # random augmentations.
        seed = jt.get_seed()
        wseed = (seed ^ (worker_id*1167)) ^ 1234
        jt.set_global_seed(wseed)
        # parallel_op_compiler still problematic,
        # it does not work on ubuntu 16.04 but worked on ubuntu 20.04;
        # it seems the static state of the parallel compiler is not
        # correctly initialized in subprocesses.
        jt.flags.use_parallel_op_compiler = 0
        import time
        try:
            gid_obj = self.gid.get_obj()
            gid_lock = self.gid.get_lock()
            start = time.time()
            while True:
                # Claim the next batch id under the shared lock.
                with gid_lock:
                    # Sleep while stopped or the epoch is exhausted; the
                    # main process resets gid and notifies gidc to start
                    # the next epoch.
                    while buffer.is_stop() or self.idqueue.is_stop() or \
                            gid_obj.value >= self.batch_len:
                        self.num_idle.value += 1
                        self.num_idle_c.notify()
                        self.gidc.wait()
                        self.num_idle.value -= 1
                    # Copy this batch's sample indices out of the shared
                    # index array before releasing the lock.
                    cid = gid_obj.value
                    batch_index_list = self.index_list_numpy[
                        cid*self.real_batch_size:
                        min(self.real_len, (cid+1)*self.real_batch_size)
                    ].copy()
                    gid_obj.value += 1
                    # Announce which worker owns this batch; the main
                    # process consumes batches in claim order.
                    with self.idqueue_lock:
                        self.idqueue.push(worker_id)
                now = time.time()
                other_time = now - start
                start = now
                # Load and transform the samples of this batch.
                batch = []
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} load batch", cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size))
                for i in batch_index_list:
                    batch.append(self[i])
                batch = self.collate_batch(batch)
                now = time.time()
                data_time = now - start
                start = now
                # Send the collated batch to the main process.
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [ type(b).__name__ for b in batch ], buffer)
                try:
                    buffer.send(batch)
                except:
                    # A send interrupted by stop() is expected during
                    # shutdown/reset; anything else is a real error.
                    if buffer.is_stop():
                        continue
                    raise
                now = time.time()
                send_time = now - start
                start = now
                # Publish timing stats for display_worker_status().
                status[0], status[1], status[2], status[3], status[4] = \
                    other_time, data_time, send_time, \
                    other_time + data_time + send_time, \
                    img_open_hook.duration
                img_open_hook.duration = 0.0
        except:
            # Print the traceback and interrupt the parent so a crashed
            # worker doesn't hang the training loop.
            import traceback
            line = traceback.format_exc()
            print(line)
            os.kill(os.getppid(), signal.SIGINT)
            exit(0)
    def display_worker_status(self):
        '''Log a per-worker timing table built from Worker.status.

        Columns: wait/open/load/send/total seconds plus each worker's ring
        buffer state. No-op when multiprocessing is not in use.
        '''
        if not hasattr(self, "workers"):
            return
        msg = [""]
        msg.append(f"progress:{self.batch_id}/{self.batch_len}")
        msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}")
        msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}")
        msg.append(f"last 10 workers: {self.last_ids}")
        msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)")
        for i in range(self.num_workers):
            w = self.workers[i]
            # s = [other, load, send, total, img_open]; printed in
            # wait/open/load/send/total order.
            s = w.status
            msg.append(f"#{i}\t{s[0]:.3f}\t{s[4]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer}")
        LOG.i('\n'.join(msg))
    def _stop_all_workers(self):
        '''Stop every worker and drain the shared queues/buffers.

        Marks all buffers stopped, waits until every worker reports idle,
        then clears the buffers and resets the shared batch counter.
        '''
        # Stop the workers' batch buffers and the id queue.
        for w in self.workers:
            w.buffer.stop()
        self.idqueue.stop()
        # Wait until all workers are idle.
        if self.num_idle.value < self.num_workers:
            with self.gid.get_lock():
                # Jump the shared counter past the end so running workers
                # fall into their idle-wait loop.
                self.gid.get_obj().value = self.batch_len
                if mp_log_v:
                    print("idle num", self.num_idle.value)
                while self.num_idle.value < self.num_workers:
                    self.num_idle_c.wait()
                    if mp_log_v:
                        print("idle num", self.num_idle.value)
        # Drop any batches left in the workers' buffers.
        for w in self.workers:
            w.buffer.clear()
        self.idqueue.clear()
        self.gid.value = 0
    def _init_workers(self, index_list):
        '''Create the shared state and spawn ``num_workers`` processes.

        Args:
            index_list: sample indices for the first epoch; copied into a
                shared int32 array that workers slice their batches from.
        '''
        # Move/free device memory before forking so children don't inherit
        # GPU allocations.
        jt.migrate_all_to_cpu()
        jt.clean()
        jt.gc()
        # Shared (lock-free) index array, wrapped below by a numpy view.
        self.index_list = mp.Array('i', self.real_len, lock=False)
        workers = []
        # Queue of worker ids, in the order batches were claimed.
        self.idqueue = jt.RingBuffer(2048)
        self.idqueue_lock = mp.Lock()
        # Global batch-id counter shared with the workers.
        self.gid = mp.Value('i', self.batch_len)
        self.gid.value = 0
        # Condition used to wake idle workers when a new epoch starts.
        self.gidc = mp.Condition(self.gid.get_lock())
        # Number of idle workers (lock-free; protected by gid's lock).
        self.num_idle = mp.Value('i', 0, lock=False)
        # Condition signalled by workers when they become idle.
        self.num_idle_c = mp.Condition(self.gid.get_lock())
        self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)
        self.index_list_numpy[:] = index_list
        for i in range(self.num_workers):
            w = Worker(target=self._worker_main, args=(i,),
                       buffer_size=self.buffer_size,
                       keep_numpy_array=self.keep_numpy_array)
            workers.append(w)
        self.workers = workers
def reset(self):
if not hasattr(self, "workers"):
return
self._stop_all_workers()
self.terminate()
del self.index_list
del self.idqueue
del self.idqueue_lock
del self.gid
del self.gidc
del self.num_idle
del self.num_idle_c
del self.workers
del self.index_list_numpy
    def __del__(self):
        # Best-effort cleanup of worker processes; a destructor must never
        # raise.
        if mp_log_v:
            print("dataset deleted")
        try:
            self.terminate()
        except:
            pass
def __deepcopy__(self, memo=None, _nil=[]):
from copy import deepcopy
if memo is None:
memo = {}
d = id(self)
y = memo.get(d, _nil)
if y is not _nil:
return y
obj = self.__class__.__new__(self.__class__)
memo[d] = id(obj)
exclude_key = {"index_list", "idqueue", "idqueue_lock", "gid", "gidc", "num_idle", "num_idle_c", "workers", "index_list_numpy", "dataset", "idqueue", "idqueue_lock"}
for k,v in self.__dict__.items():
if k in exclude_key: continue
obj.__setattr__(k, deepcopy(v))
obj.dataset = obj
return obj
def __real_len__(self):
if self.total_len is None:
self.total_len = len(self)
return self.total_len
    def _get_index_list(self):
        '''Build the (possibly shuffled, possibly MPI-scattered) sample
        index list for one epoch.

        Side effects: sets ``self.real_len``, ``self.real_batch_size`` and
        ``self.batch_len`` (number of batches this rank will see).

        Returns:
            Sequence of sample indices for this process.
        '''
        if self.total_len is None:
            self.total_len = len(self)
        # May be rewritten by a sampler below.
        total_len = self.total_len
        if self.sampler:
            index_list = list(self.sampler.__iter__())
            total_len = len(index_list)
            # Check this is not a batch sampler.
            if len(index_list):
                assert not isinstance(index_list[0], (list,tuple)), "Batch sampler not support yet."
        elif self.shuffle == False:
            index_list = get_order_list(self.total_len)
        else:
            # Use _shuffle_rng (fixed seed) so every process generates a
            # consistent shuffle order.
            # index_list = get_random_list(self.total_len)
            index_list = self._shuffle_rng.permutation(range(self.total_len))
        # Scatter index_list across all MPI processes.
        # Scatter rule (world_size = 3):
        #    batch 1   batch 2
        # [........] [........] ...
        #  00011122   00011122
        # If the last batch is smaller than world_size, pad it:
        #  last batch
        # [.] -> [012]
        if jt.in_mpi:
            world_size = mpi.world_size()
            world_rank = mpi.world_rank()
            index_list = np.int32(index_list)
            # TODO: mpi broadcast in subprocess has a bug, fix it.
            # mpi.broadcast(index_list, 0)

            assert self.batch_size >= world_size, \
                f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})"
            # Per-rank slice of each global batch (ceil division).
            real_batch_size = (self.batch_size-1) // world_size + 1
            if real_batch_size * world_size != self.batch_size:
                LOG.w("Batch size is not divisible by MPI world size, "
                      "The distributed version may be different from "
                      "the single-process version.")
            # Split into full batches and a possibly-short trailing batch.
            fix_batch = total_len // self.batch_size
            last_batch = total_len - fix_batch * self.batch_size
            # Take this rank's columns from every full batch.
            fix_batch_l = index_list[0:fix_batch*self.batch_size] \
                .reshape(-1,self.batch_size)
            fix_batch_l = fix_batch_l[
                :,real_batch_size*world_rank:real_batch_size*(world_rank+1)]
            real_batch_size = fix_batch_l.shape[1]
            fix_batch_l = fix_batch_l.flatten()
            if not self.drop_last and last_batch > 0:
                # Slice this rank's share of the trailing batch, shifting
                # the window left when it would run past the end.
                last_batch_l = index_list[-last_batch:]
                real_last_batch = (last_batch-1)//world_size+1
                l = real_last_batch * world_rank
                r = l + real_last_batch
                if r > last_batch:
                    r = last_batch
                    l = r-real_last_batch
                index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]])
            else:
                index_list = fix_batch_l
            self.real_len = len(index_list)
            self.real_batch_size = real_batch_size
        else:
            self.real_len = len(index_list)
            self.real_batch_size = self.batch_size
        # Number of batches this process will produce.
        if self.drop_last:
            self.batch_len = self.real_len // self.real_batch_size
        else:
            self.batch_len = (self.real_len-1) // self.real_batch_size + 1
        return index_list
def _epochs(self):
if self.endless:
while True:
yield
self.epoch_id += 1
else:
yield
    def __iter__(self):
        '''Yield batches, one epoch at a time (forever when ``endless``).

        Uses the worker-process pipeline when ``num_workers`` > 0,
        otherwise loads and collates batches inline.
        '''
        if self._disable_workers:
            self.num_workers = 0
        index_list = self._get_index_list()
        # Spawn workers lazily on first iteration.
        if not hasattr(self, "workers") and self.num_workers:
            self._init_workers(index_list)
        self.last_ids = [-1] * 10
        if self.num_workers:
            start = time.time()
            self.batch_time = 0
            gid_obj = self.gid.get_obj()
            gid_lock = self.gid.get_lock()
            for _ in self._epochs():
                # Wake idle workers at the start of each epoch.
                with gid_lock:
                    if self.num_idle.value:
                        self.gidc.notify_all()
                for i in range(self.batch_len):
                    # If workers drained the previous epoch, install the
                    # next epoch's index list and restart the counter.
                    if self.num_idle.value:
                        with gid_lock:
                            if self.num_idle.value and \
                                    gid_obj.value >= self.batch_len:
                                index_list = self._get_index_list()
                                self.index_list_numpy[:] = index_list
                                gid_obj.value = 0
                                self.gidc.notify_all()
                    # Find out which worker produced this batch.
                    worker_id = self.idqueue.pop()
                    now = time.time()
                    self.wait_time = now - start
                    start = now
                    self.last_ids[i%10] = worker_id
                    self.batch_id = i
                    w = self.workers[worker_id]
                    if mp_log_v:
                        print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
                    batch = w.buffer.recv()
                    now = time.time()
                    self.recv_time = now - start
                    start = now
                    if mp_log_v:
                        print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ])
                    batch = self.to_jittor(batch)
                    now = time.time()
                    self.to_jittor_time = now - start
                    start = now
                    yield batch
                    now = time.time()
                    self.batch_time = now - start
                    start = now
                    if CHECK_MEMORY and self.batch_id % CHECK_MEMORY == 0:
                        jt.display_memory_info()
        else:
            # Single-process path: load samples inline and collate each
            # time a full batch accumulates.
            for _ in self._epochs():
                self.batch_id = 0
                batch_data = []
                for idx in index_list:
                    batch_data.append(self[int(idx)])
                    if len(batch_data) == self.real_batch_size:
                        batch_data = self.collate_batch(batch_data)
                        tmp = batch_data  # leftover debug alias (unused)
                        batch_data = self.to_jittor(batch_data)
                        yield batch_data
                        self.batch_id += 1
                        if CHECK_MEMORY and self.batch_id % CHECK_MEMORY == 0:
                            jt.display_memory_info()
                        batch_data = []
                # Trailing incomplete batch: only yielded when drop_last
                # is False.
                if not self.drop_last and len(batch_data) > 0:
                    batch_data = self.collate_batch(batch_data)
                    batch_data = self.to_jittor(batch_data)
                    self.batch_id += 1
                    yield batch_data
# [docs] -- stray doc-site link artifact neutralized
def DataLoader(dataset: Dataset, *args, **kargs):
    '''Combine a dataset with loading options and return an iterable.

    In jittor a ``Dataset`` is already iterable, so this merely forwards
    the arguments to ``dataset.set_attrs`` and returns the same object;
    it exists to mirror the familiar PyTorch-style API.

    Args:
        dataset (Dataset): the dataset to wrap.
        *args: positional arguments forwarded to ``set_attrs``.
        **kargs: keyword arguments forwarded to ``set_attrs``
            (e.g. ``batch_size=8``).

    Returns:
        Dataset: the configured dataset itself.

    Example::

        >>> from jittor.dataset.cifar import CIFAR10
        >>> from jittor.dataset import DataLoader
        >>> train_dataset = CIFAR10()
        >>> dataloader = DataLoader(train_dataset, batch_size=8)
        >>> for batch_idx, (x_, target) in enumerate(dataloader):
        >>>     # process each batch
    '''
    configured = dataset.set_attrs(*args, **kargs)
    return configured
# [docs] -- stray doc-site link artifact neutralized
class ImageFolder(Dataset):
    '''Image-classification dataset loaded from a directory tree.

    Expected layout::

        root/label1/img1.png
        root/label1/img2.png
        ...
        root/label2/img1.png
        root/label2/img2.png
        ...

    Args:
        root (str): root directory containing one sub-directory per class.
        transform (callable, optional): transform applied to every sample
            (e.g. data augmentation). Default: None

    Attributes:
        classes (list): sorted list of class names.
        class_to_idx (dict): mapping from class name to class index.
        imgs (list): list of (image_path, class_index) tuples.

    Example::

        >>> from jittor.dataset import ImageFolder
        >>> dataset = ImageFolder(root="path_to_your_dataset")
    '''
    def __init__(self, root, transform=None):
        super().__init__()
        self.root = root
        self.transform = transform
        self.classes = sorted([d.name for d in os.scandir(root) if d.is_dir()])
        self.class_to_idx = {v:k for k,v in enumerate(self.classes)}
        self.imgs = []
        image_exts = set(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'))
        for i, class_name in enumerate(self.classes):
            class_dir = os.path.join(root, class_name)
            # followlinks so symlinked image directories are included.
            for dname, _, fnames in sorted(os.walk(class_dir, followlinks=True)):
                for fname in sorted(fnames):
                    if os.path.splitext(fname)[-1].lower() in image_exts:
                        # BUG FIX: join with the directory os.walk is
                        # currently visiting (dname), not class_dir --
                        # otherwise images found in nested sub-directories
                        # get a path that does not exist.
                        path = os.path.join(dname, fname)
                        self.imgs.append((path, i))
        LOG.i(f"Found {len(self.classes)} classes and {len(self.imgs)} images.")
        self.set_attrs(total_len=len(self.imgs))

    def __getitem__(self, k):
        '''Return (image, class_index) for sample *k*; the image is opened
        as RGB and passed through ``transform`` when one is set.'''
        with open(self.imgs[k][0], 'rb') as f:
            img = Image.open(f).convert('RGB')
            if self.transform:
                img = self.transform(img)
            return img, self.imgs[k][1]
# [docs] -- stray doc-site link artifact neutralized
class VarDataset(Dataset):
    '''Dataset built directly from Vars (``TensorDataset`` is an alias).

    Sample *k* is the tuple ``(args[0][k], args[1][k], ...)``; no other
    preprocessing is applied. All inputs must share the same
    first-dimension length, otherwise construction fails.

    Args:
        *args (jt.Var): one or more variables of equal length, indexed in
            parallel to form the samples.

    Example::

        >>> import jittor as jt
        >>> from jittor.dataset import VarDataset
        >>> x = jt.array([1,2,3])
        >>> y = jt.array([4,5,6])
        >>> z = jt.array([7,8,9])
        >>> dataset = VarDataset(x, y, z)
        >>> dataset.set_attrs(batch_size=1)
        >>> for a, b, c in dataset:
        >>>     print(a, b, c)
        >>> # 1, 4, 7
        >>> # 2, 5, 8
        >>> # 3, 6, 9
    '''
    def __init__(self, *args):
        super().__init__()
        self.args = args
        # Samples index live Vars, which cannot be done safely from worker
        # processes; always load in the main process.
        self._disable_workers = True
        # BUG FIX: corrected typo in the error message ("At lease one args").
        assert len(args), "At least one arg is required"
        l = len(args[0])
        for a in args:
            assert l == len(a), "Len should be the same"
        self.set_attrs(total_len=l)

    def __getitem__(self, idx):
        # One sample: the idx-th element of every input variable.
        return [ a[idx] for a in self.args ]
    # [docs] -- stray doc-site link artifact neutralized
    def collate_batch(self, batch):
        '''Collate a list of samples into batched jittor Vars.

        Runs the default collation, then squeezes the last axis of every
        output whose source Var is 1-D, so batches built from 1-D inputs
        come out with shape (batch,) rather than a trailing singleton
        dimension.

        Args:
            batch (list): samples, each a list with one entry per input
                variable.

        Returns:
            list: batched values, one per input variable.
        '''
        b = collate_batch(batch)
        for i in range(len(self.args)):
            x = b[i]
            if jt.is_var(self.args[i]) and self.args[i].ndim == 1:
                # In-place squeeze keeps b's entries as the same Var objects.
                x.assign(x.squeeze(-1))
        return b
TensorDataset = VarDataset