# Source code for jittor.dataset.voc

# ***************************************************************
# Copyright(c) 2019
#     Meng-Hao Guo <guomenghao1997@gmail.com>
#     Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

import numpy as np
import os
from PIL import Image
from .dataset import Dataset, dataset_root

class VOC(Dataset):
    ''' Jittor's own class for loading VOC dataset.

    Args::

        [in] data_root(str): your data root.
        [in] split(str): which split you want to use, train or val.

    Attribute::

        NUM_CLASSES: Number of total categories, default is 21.

    Example::

        from jittor.dataset.voc import VOC
        train_loader = VOC(data_root='...').set_attrs(batch_size=16, shuffle=True)
        for i, (imgs, target) in enumerate(train_loader):
            ...
    '''
    NUM_CLASSES = 21

    # Fixed (width, height) that every image and label is resized to in
    # __getitem__.
    _IMAGE_SIZE = (513, 513)

    def __init__(self, data_root=dataset_root+'/voc/', split='train'):
        ''' Index the image/label pairs listed in the split file.

        Reads ``<data_root>/ImageSets/Segmentation/<split>.txt`` (one sample
        name per line) and records, for each name, the paths
        ``<data_root>/JPEGImages/<name>.jpg`` and
        ``<data_root>/SegmentationClass/<name>.png``.

        Raises::

            FileNotFoundError: if a listed image or label file does not exist.
                Errors from opening the split file itself propagate unchanged.
        '''
        super().__init__()
        # total_len, batch_size and shuffle must be set on the dataset;
        # total_len is set at the end of this constructor.
        self.data_root = data_root
        self.split = split
        self.image_root = os.path.join(data_root, 'JPEGImages')
        self.label_root = os.path.join(data_root, 'SegmentationClass')
        self.data_list_path = os.path.join(
            self.data_root, 'ImageSets', 'Segmentation', self.split + '.txt')
        self.image_path = []
        self.label_path = []

        with open(self.data_list_path, "r") as f:
            lines = f.read().splitlines()

        for line in lines:
            name = line.strip()
            if not name:
                # Tolerate blank / whitespace-only lines (e.g. a trailing
                # newline) instead of building a bogus '<root>/.jpg' path.
                continue
            _img_path = os.path.join(self.image_root, name + '.jpg')
            _label_path = os.path.join(self.label_root, name + '.png')
            # Explicit checks rather than `assert`, which is stripped under
            # `python -O` and would let missing files slip through silently.
            if not os.path.isfile(_img_path):
                raise FileNotFoundError(
                    "VOC image file not found: " + _img_path)
            if not os.path.isfile(_label_path):
                raise FileNotFoundError(
                    "VOC label file not found: " + _label_path)
            self.image_path.append(_img_path)
            self.label_path.append(_label_path)

        self.set_attrs(total_len=len(self.image_path))

    def __getitem__(self, index):
        ''' Load the image/label pair at ``index``.

        Returns::

            (img, label): ``img`` is a numpy array of the RGB image resized to
            513x513 and transposed to channel-first layout, i.e. shape
            (3, 513, 513). ``label`` is a numpy array of the label image
            resized to 513x513 (shape (513, 513) for single-channel / palette
            label images).
        '''
        with Image.open(self.image_path[index]) as _img_file:
            # Force three channels so the transpose below is always valid,
            # even for grayscale (or other non-RGB) JPEG files.
            _img = _img_file.convert('RGB').resize(self._IMAGE_SIZE)
        with Image.open(self.label_path[index]) as _label_file:
            # Labels hold class ids: nearest-neighbour resampling guarantees
            # that no interpolated (non-existent) class values are produced.
            _label = _label_file.resize(self._IMAGE_SIZE,
                                        resample=Image.NEAREST)

        _img = np.array(_img)
        _label = np.array(_label)
        # HWC -> CHW
        _img = _img.transpose(2, 0, 1)
        return _img, _label