# ***************************************************************
# Copyright(c) 2019
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import numpy as np
import os
from PIL import Image
from .dataset import Dataset, dataset_root
# [docs] -- documentation-page link label left over from the HTML export; not part of the code.
class VOC(Dataset):
    '''
    `Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ segmentation dataset.

    Every sample is an (image, label) pair read from the ``JPEGImages`` and
    ``SegmentationClass`` folders and resized to 513x513.

    Args:
        - data_root (str): root directory of the dataset. Default: ``dataset_root + '/voc/'``
        - split (str, optional): subset to load, 'train' for the training set, 'val' for the validation set. Default: 'train'

    Attributes:
        - data_root (str): root directory of the dataset
        - split (str): name of the selected subset
        - image_root (str): path of the image folder
        - label_root (str): path of the label folder
        - data_list_path (str): path of the file listing the sample names of the split
        - image_path (list of str): paths of the image files
        - label_path (list of str): paths of the label files

    Raises:
        FileNotFoundError: if the split list file, or an image / label file
            named in it, does not exist.

    Example:
        >>> from jittor.dataset.voc import VOC
        >>> train_loader = VOC(data_root='path/to/VOC').set_attrs(batch_size=16, shuffle=True)
        >>> for i, (imgs, target) in enumerate(train_loader):
        ...     pass  # process the images and labels
    '''

    # Number of segmentation classes (Pascal VOC: 20 object categories + background).
    NUM_CLASSES = 21

    def __init__(self, data_root=dataset_root+'/voc/', split='train'):
        super().__init__()
        # NOTE: total_len, batch_size and shuffle must be set (through set_attrs);
        # total_len is set at the end of this constructor.
        self.data_root = data_root
        self.split = split
        self.image_root = os.path.join(data_root, 'JPEGImages')
        self.label_root = os.path.join(data_root, 'SegmentationClass')
        self.data_list_path = os.path.join(self.data_root, 'ImageSets', 'Segmentation', self.split + '.txt')
        self.image_path = []
        self.label_path = []

        with open(self.data_list_path, "r") as f:
            lines = f.read().splitlines()

        for line in lines:
            name = line.strip()
            if not name:
                # A blank line would otherwise yield the bogus path '<root>/.jpg'.
                continue
            img_path = os.path.join(self.image_root, name + '.jpg')
            label_path = os.path.join(self.label_root, name + '.png')
            # Explicit check instead of `assert`: asserts are removed under
            # `python -O` and would not say which file is missing.
            for path in (img_path, label_path):
                if not os.path.isfile(path):
                    raise FileNotFoundError(
                        "VOC sample '%s' listed in '%s' is missing file: %s"
                        % (name, self.data_list_path, path))
            self.image_path.append(img_path)
            self.label_path.append(label_path)

        self.set_attrs(total_len=len(self.image_path))

    def __getitem__(self, index):
        '''
        Load the sample at position ``index``.

        Args:
            - index (int): position in ``image_path`` / ``label_path``

        Returns:
            tuple: ``(img, label)`` where ``img`` is a numpy array of shape
            (3, 513, 513) (channels first, RGB) and ``label`` is a numpy array
            holding the resized label image (513 x 513).
        '''
        size = (513, 513)

        # Context managers close the underlying files; convert()/resize()
        # return new, fully loaded images that outlive the `with` block.
        with Image.open(self.image_path[index]) as raw_img:
            # Force 3 channels so the transpose below also works for
            # grayscale (or other non-RGB) JPEG files.
            img = raw_img.convert('RGB').resize(size)
        with Image.open(self.label_path[index]) as raw_label:
            # Nearest-neighbour keeps label values intact; an interpolating
            # filter would blend neighbouring class ids into invalid ones.
            label = raw_label.resize(size, Image.NEAREST)

        img = np.array(img).transpose(2, 0, 1)  # HWC -> CHW
        label = np.array(label)
        return img, label