-
data_loader (Base Line/python 기초 코드, 2022. 4. 4. 15:29)
# Created by ylab604
# ------------------------------------------------------------------------------
import os
import random
import glob

import torch
import torch.utils.data as data
import pandas as pd
from PIL import Image, ImageFile
import numpy as np
import torchvision.transforms as transforms

from lib.utils.transforms import fliplr_joints, crop, generate_target, transform_pixel

# Let PIL open truncated image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# --------------------------------------------------------------------
# Load rendered images, normal maps and depth maps, apply the
# per-modality transforms, and return all three from __getitem__.
# --------------------------------------------------------------------


class Thuman(data.Dataset):
    """THuman dataset of paired render / normal / depth images.

    Files are collected with these glob patterns::

        <DATASET.ROOT>/<NNNN>_OBJ/RENDER/<NNNN>/*
        <DATASET.ROOT>/<NNNN>_OBJ/RENDER_NORMAL/<NNNN>/*
        <DATASET.ROOT>/<NNNN>_OBJ/RENDER_DEPTH/<NNNN>/*

    where ``NNNN`` is a zero-padded 4-digit index from 0 up to (number of
    entries directly under DATASET.ROOT) - 1.

    Args:
        cfg: config object; reads DATASET.ROOT, MODEL.IMAGE_SIZE,
            MODEL.HEATMAP_SIZE, MODEL.SIGMA and MODEL.TARGET_TYPE.
        is_train: accepted for API compatibility; not used by this class.
        render_transforms: list of transforms for the rendered image,
            combined with ``transforms.Compose``. ``None`` means no transform.
        normal_transforms: same, for the normal map.
        depth_transforms: same, for the depth map.
    """

    def __init__(
        self,
        cfg,
        is_train=True,
        render_transforms=None,
        normal_transforms=None,
        depth_transforms=None,
    ):
        self.data_root = cfg.DATASET.ROOT
        self.input_size = cfg.MODEL.IMAGE_SIZE
        self.output_size = cfg.MODEL.HEATMAP_SIZE
        self.sigma = cfg.MODEL.SIGMA  # author's note: "default = Gaussian"
        self.label_type = cfg.MODEL.TARGET_TYPE

        self.render_transforms = self._compose(render_transforms)
        self.normal_transforms = self._compose(normal_transforms)
        self.depth_transforms = self._compose(depth_transforms)

        # The number of entries directly under the root is used as the number
        # of subject folders; this assumes they are named 0000_OBJ, 0001_OBJ, ...
        num_subjects = len(glob.glob(os.path.join(self.data_root, "*")))

        self.render_files = self._collect_files("RENDER", num_subjects)
        self.normal_files = self._collect_files("RENDER_NORMAL", num_subjects)
        self.depth_files = self._collect_files("RENDER_DEPTH", num_subjects)

    @staticmethod
    def _compose(transform_list):
        """Wrap a list of transforms in Compose; None means identity."""
        # Compose(None) would raise TypeError on the first call, so treat
        # a missing list as "no transforms".
        return transforms.Compose(transform_list if transform_list is not None else [])

    def _collect_files(self, modality, num_subjects):
        """Return the sorted files in <root>/<NNNN>_OBJ/<modality>/<NNNN>/ for all subjects."""
        files = []
        for n in range(num_subjects):
            idx = "{:04d}".format(n)
            # os.path.join works whether or not DATASET.ROOT ends with '/'.
            pattern = os.path.join(self.data_root, idx + "_OBJ", modality, idx, "*")
            files += glob.glob(pattern)
        return sorted(files)

    @staticmethod
    def _load_rgb(path):
        """Open an image, convert it to RGB and close the underlying file."""
        with Image.open(path) as img:
            # convert() returns a fully loaded new image, so closing the file is safe.
            return img.convert("RGB")

    def __len__(self):
        return len(self.render_files)

    def __getitem__(self, index):
        """Return {"A": render, "B": normal, "C": depth} after transforms.

        The three file lists are sorted independently and indexed modulo
        their own length. Pairing is only correct when every modality has
        the same number of files with matching names.
        """
        render_img = self._load_rgb(self.render_files[index % len(self.render_files)])
        normal_img = self._load_rgb(self.normal_files[index % len(self.normal_files)])
        depth_img = self._load_rgb(self.depth_files[index % len(self.depth_files)])

        render_img = self.render_transforms(render_img)
        normal_img = self.normal_transforms(normal_img)
        depth_img = self.depth_transforms(depth_img)

        return {"A": render_img, "B": normal_img, "C": depth_img}


if __name__ == "__main__":
    pass
짚고 넘어가야 할 것
3자리 수, 4자리 수로 0 채우기: "{:03d}".format(n), "{:04d}".format(n) (예: "{:04d}".format(7) → "0007")
리스트 합치기: append가 아니라 + (또는 extend)로 더하기. append는 리스트 전체를 원소 하나로 넣어서 중첩 리스트가 된다.
'Base Line > python 기초 코드' 카테고리의 다른 글
pcl + open3d (0) 2022.06.06 pillow 이미지 rotation, flip (1) 2022.04.18 Depth_Render.py (0) 2022.03.16 open3d extract depth map from mesh (2) 2022.03.14 path & os (0) 2022.03.09