# Source code for jittor.dataset.voc

# ***************************************************************
# Copyright(c) 2019
#     Meng-Hao Guo <guomenghao1997@gmail.com>
#     Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

import numpy as np
import os
from PIL import Image
from .dataset import Dataset, dataset_root

class VOC(Dataset):
    """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ dataset.

    Args:
        data_root (str): Root directory of the dataset.
        split (str, optional): Subset of the dataset to load; 'train' selects
            the training set and 'val' the validation set. Default: 'train'.

    Attributes:
        data_root (str): Root directory of the dataset.
        split (str): Selected subset of the dataset.
        image_root (str): Path of the image folder (``JPEGImages``).
        label_root (str): Path of the label folder (``SegmentationClass``).
        data_list_path (str): Path of the file listing the samples of the
            selected split.
        image_path (list of str): Paths of the image files.
        label_path (list of str): Paths of the label files.

    Raises:
        FileNotFoundError: If the split list file, or an image/label file it
            refers to, does not exist.

    Example:
        >>> from jittor.dataset.voc import VOC
        >>> train_loader = VOC(data_root='path/to/VOC').set_attrs(batch_size=16, shuffle=True)
        >>> for i, (imgs, target) in enumerate(train_loader):
        >>>     # process the images and labels
    """

    NUM_CLASSES = 21

    def __init__(self, data_root=dataset_root+'/voc/', split='train'):
        super().__init__()
        # total_len, batch_size and shuffle must be set on the dataset;
        # total_len is set at the end of this constructor.
        self.data_root = data_root
        self.split = split
        self.image_root = os.path.join(data_root, 'JPEGImages')
        self.label_root = os.path.join(data_root, 'SegmentationClass')
        self.data_list_path = os.path.join(
            self.data_root, 'ImageSets', 'Segmentation', self.split + '.txt')
        self.image_path = []
        self.label_path = []

        with open(self.data_list_path, "r") as f:
            names = f.read().splitlines()

        for name in names:
            # Tolerate blank lines in the list file instead of building a
            # bogus '<root>/.jpg' path from them.
            if not name.strip():
                continue
            img_file = os.path.join(self.image_root, name + '.jpg')
            label_file = os.path.join(self.label_root, name + '.png')
            # Explicit checks rather than `assert`: asserts are removed when
            # Python runs with -O, and they would not name the missing file.
            if not os.path.isfile(img_file):
                raise FileNotFoundError(
                    "VOC image file not found: " + img_file)
            if not os.path.isfile(label_file):
                raise FileNotFoundError(
                    "VOC label file not found: " + label_file)
            self.image_path.append(img_file)
            self.label_path.append(label_file)

        self.set_attrs(total_len=len(self.image_path))

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Both the image and its label are resized to a fixed 513x513.

        Returns:
            tuple: ``(img, label)`` as numpy arrays. ``img`` has its last
            axis moved to the front (assumes a multi-channel image such as
            RGB, so the result is channel-first); ``label`` is returned as
            decoded by PIL after resizing.
        """
        size = (513, 513)

        # Context managers make sure the underlying file handles are closed
        # as soon as the pixel data has been read.
        with Image.open(self.image_path[index]) as img_src:
            img = np.array(img_src.resize(size))

        with Image.open(self.label_path[index]) as label_src:
            # Labels hold class ids, so they must never be interpolated.
            # Pillow only forces nearest-neighbour for mode '1'/'P' images;
            # request it explicitly so other label modes are safe too.
            label = np.array(label_src.resize(size, Image.NEAREST))

        img = img.transpose(2, 0, 1)
        return img, label