Bringing Old Photos Back to Life模型代码分析1(数据载入部分)

您所在的位置:网站首页 jq生成随机数 Bringing Old Photos Back to Life模型代码分析1(数据载入部分)

Bringing Old Photos Back to Life模型代码分析1(数据载入部分)

2023-02-14 00:27| 来源: 网络整理| 查看: 265

Bringing Old Photos Back to Life模型代码分析1(数据载入部分) 原创

Aclous 2022-01-05 11:12:09 博主文章分类:深度学习 ©著作权

文章标签 python 人工智能 图像处理 深度学习 计算机视觉 文章分类 Java 后端开发

©著作权归作者所有:来自51CTO博客作者Aclous的原创作品,请联系作者获取转载授权,否则将追究法律责任

(1) Bringing Old Photos Back to Life原理和测试        

(2)

Bringing Old Photos Back to Life模型代码分析1(数据载入部分) 

Bringing Old Photos Back to Life模型代码分析2(模型部分)

(3) Bringing Old Photos Back to Life数据集及其训练

这一部分是关于数据预处理部分

文件在Global/data下,如图所示

Bringing Old Photos Back to Life模型代码分析1(数据载入部分)_人工智能

 base_dataset.py

# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import torch.utils.data as data from PIL import Image import torchvision.transforms as transforms import numpy as np import random # class BaseDataset(data.Dataset): def __init__(self): super(BaseDataset, self).__init__() def name(self): return 'BaseDataset' def initialize(self, opt): pass #获取裁剪参数 # 这个函数是根据用户指定的方式resize或者crop出合适大小的输入尺寸。 # size:输入图片的尺寸 def get_params(opt, size): w, h = size new_h = h new_w = w if opt.resize_or_crop == 'resize_and_crop': # opt.loadSize为自己输入的尺寸,将图像缩放到这个大小 new_h = new_w = opt.loadSize # 将宽和高设置为同样大小 if opt.resize_or_crop == 'scale_width_and_crop': # we scale the shorter side into 256 if w 0.5 # 随机数是否大于0.5,flip是bool型变量,此行代码意思为随机生成True或者False return {'crop_pos': (x, y), 'flip': flip} # 最终的返回值,在data.aligned_dataset 45行,当作params传入了下方get_transform()函数 # 图像变换 def get_transform(opt, params, method=Image.BICUBIC, normalize=True): transform_list = [] #重设置大小 if 'resize' in opt.resize_or_crop: # # 若opt.resize_or_crop中有'resize' osize = [opt.loadSize, opt.loadSize] transform_list.append(transforms.Scale(osize, method)) # elif 'scale_width' in opt.resize_or_crop: # transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) ## Here , We want the shorter side to match 256, and Scale will finish it. #将输入的`PIL.Image`重新改变大小成给定的`size`即256 transform_list.append(transforms.Scale(256,method)) #裁剪 if 'crop' in opt.resize_or_crop: if opt.isTrain: # 使用transforms.Lambda封装其为transforms策略 transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) else: if opt.test_random_crop: transform_list.append(transforms.RandomCrop(opt.fineSize)) else: transform_list.append(transforms.CenterCrop(opt.fineSize)) ## when testing, for ablation study, choose center_crop directly. if opt.resize_or_crop == 'none': base = float(2 ** opt.n_downsample_global) if opt.netG == 'local': base *= (2 ** opt.n_local_enhancers) transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) if opt.isTrain and not opt.no_flip: transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) transform_list += [transforms.ToTensor()] if normalize: # mean和std均为0.5 transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] return transforms.Compose(transform_list) #归一化到(-1,1) def normalize(): return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) #将图片进行设置大小为base的整倍数 def __make_power_2(img, base, method=Image.BICUBIC): ow, oh = img.size h = int(round(oh / base) * base) w = int(round(ow / base) * base) if (h == oh) and (w == ow): return img return img.resize((w, h), method) #修改图片为目标大小 def __scale_width(img, target_width, method=Image.BICUBIC): ow, oh = img.size if (ow == target_width): return img w = target_width h = int(target_width * oh / ow) return img.resize((w, h), method) #对图片进行切割# 随机平移滑动裁剪 def __crop(img, pos, size): ow, oh = img.size x1, y1 = pos tw = th = size # 输入的尺寸 opt.fineSize if (ow > tw or oh > th): #Image.crop(left, up, right, below) 其中left:与左边界的距离 up:与上边界的距离 right:还是与左边界的距离 below:还是与上边界的距离 return img.crop((x1, y1, x1 + tw, y1 + th)) # 随机裁剪,因为虽然每次裁剪测大小一样,但是起始点位置不一样 return img #左右翻转 def __flip(img, flip): if flip: return img.transpose(Image.FLIP_LEFT_RIGHT) return img

Create_Bigfile.py

# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import os import struct from PIL import Image IMG_EXTENSIONS = [ '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', ] # 判断文件夹中是否有以上类型图片,没有则返回0 def is_image_file(filename): #如果不都为空、0、false,则any()返回true return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) #创建图片数据集,存在列表中并返回 def make_dataset(dir): images = [] assert os.path.isdir(dir), '%s is not a valid directory' % dir # os.walk(top[, topdown=True[, onerror=None[, followlinks=False]]]) 通过在目录树中游走输出在目录中的文件名,top返回三项(root,dirs,files),分别代表: # 当前正在遍历的这个文件夹的本身的地址; list类型,内容是该文件夹中所有的目录的名字(不包括子目录); list类型,内容是该文件夹中所有的文件(不包括子目录) for root, _, fnames in sorted(os.walk(dir)): for fname in fnames: if is_image_file(fname): #print(fname) #拼接出图片的地址,并加入到images列表 path = os.path.join(root, fname) images.append(path) return images ### Modify these 3 lines in your own environment #需要修改以下三个变量: #变量一:存放待训练数据集文件夹的父目录 indir="/home/ziyuwan/workspace/data/temp_old" #变量二:待训练数据的文件夹,共有三个目标文件夹,分别为 : VOC数据集(用于生成假老照片)、真实黑白老照片、真实彩色老照片 target_folders=['VOC','Real_L_old','Real_RGB_old'] #变量三:输出生成结果的文件夹路径 out_dir ="/home/ziyuwan/workspace/data/temp_old" ### if os.path.exists(out_dir) is False: os.makedirs(out_dir) #遍历存放数据集的文件夹 for target_folder in target_folders: #拼接生成存放数据集文件夹的路径 curr_indir = os.path.join(indir, target_folder) #生成的大文件路径(含问文件名) curr_out_file = os.path.join(os.path.join(out_dir, '%s.bigfile' % (target_folder))) image_lists = make_dataset(curr_indir) image_lists.sort() with open(curr_out_file, 'wb') as wfid: # write total image number wfid.write(struct.pack('i', len(image_lists))) for i, img_path in enumerate(image_lists): # write file name first img_name = os.path.basename(img_path) img_name_bytes = img_name.encode('utf-8') wfid.write(struct.pack('i', len(img_name_bytes))) wfid.write(img_name_bytes) # # # write image data in with open(img_path, 'rb') as img_fid: img_bytes = img_fid.read() wfid.write(struct.pack('i', len(img_bytes))) wfid.write(img_bytes) if i % 1000 == 0: print('write %d images done' % i)

custom_dataset_data_loader.py

# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import torch.utils.data import random from data.base_data_loader import BaseDataLoader from data import online_dataset_for_old_photos as dts_ray_bigfile #根据训练的模型模块不同,返回对应的数据集 def CreateDataset(opt): dataset = None # 训练A或者B时,使用的数据集为非成对数据集 if opt.training_dataset=='domain_A' or opt.training_dataset=='domain_B': dataset = dts_ray_bigfile.UnPairOldPhotos_SR() #当训练mapping时,载入成对数据集 if opt.training_dataset=='mapping': if opt.random_hole: dataset = dts_ray_bigfile.PairOldPhotos_with_hole() else: dataset = dts_ray_bigfile.PairOldPhotos() print("dataset [%s] was created" % (dataset.name())) # 打印数据集名字为‘ dataset.initialize(opt) # 初始化数据集参数 return dataset # 返回创建好的数据集 ##创建数据载入器# 加载数据集 class CustomDatasetDataLoader(BaseDataLoader): def name(self): return 'CustomDatasetDataLoader' def initialize(self, opt): BaseDataLoader.initialize(self, opt) ## 初始化参数 #创建数据集 self.dataset = CreateDataset(opt) #创建数据载入器 self.dataloader = torch.utils.data.DataLoader( ## 加载创建好的数据集,并自定义相关参数 self.dataset, batch_size=opt.batchSize, shuffle=not opt.serial_batches, num_workers=int(opt.nThreads), drop_last=True) def load_data(self): return self.dataloader # 返回数据集 def __len__(self): return min(len(self.dataset), self.opt.max_dataset_size)# 返回加载的数据集长度和一个epoch容许的加载最大容量

data_loader.py

# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. #创建数据载入器 ########################################################################## # 创建数据集加载主函数 ######################################################################## def CreateDataLoader(opt): from data.custom_dataset_data_loader import CustomDatasetDataLoader data_loader = CustomDatasetDataLoader() print(data_loader.name()) # 返回的名字为“CustomDatasetDataLoader” data_loader.initialize(opt) # # 初始化参数 return data_loader

image_foder.py

# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import torch.utils.data as data from PIL import Image import os IMG_EXTENSIONS = [ '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' ] def is_image_file(filename): ### any()函数用于判断给定的可迭代参数iterable是否全部为False,则返回False,如果有一个为True,则返回True。 # 元素除了是0、空、FALSE外都算TRUE。 # 函数等价于: # def any(iterable): # for element in iterable: # if element: # return True # return False return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) # 制作数据集:获得数据集的图片路径列表 def make_dataset(dir): # dir为数据集文件夹路径 images = []# 创建空列表 assert os.path.isdir(dir), '%s is not a valid directory' % dir # 确认路径存在 ### os.walk() 方法是一个简单易用的文件、目录遍历器,可以帮助我们高效的处理文件、目录方面的事情 # top -- 是你所要遍历的目录的地址, 返回的是一个三元组(root,dirs,files)。 # root 所指的是当前正在遍历的这个文件夹的本身的地址,和输入的os.walk(dir)种的dir一致 # dirs 是一个 list ,内容是该文件夹中所有的 目录 的名字(不包括子目录),若无则为[] # files 同样是 list , 内容是该文件夹中所有的 文件 的名字(不包括子目录),若无则为[] for root, _, fnames in sorted(os.walk(dir)): # fnames为文件中读取的照片文件 for fname in fnames: if is_image_file(fname): path = os.path.join(root, fname)# 将文件夹路径dir 和 图片名称fname 结合起来 images.append(path) # 将图片路径存放到image列表里 return images # 返回图片路径列表 def default_loader(path): return Image.open(path).convert('RGB') class ImageFolder(data.Dataset): def __init__(self, root, transform=None, return_paths=False, loader=default_loader): imgs = make_dataset(root) # imgs为root目录下图片路径列表 if len(imgs) == 0: # 图片数量 = 0 报错 raise(RuntimeError("Found 0 images in: " + root + "\n" "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) self.root = root self.imgs = imgs self.transform = transform self.return_paths = return_paths self.loader = loader def __getitem__(self, index): path = self.imgs[index] # 获取指定图片路径 img = self.loader(path) # 加载图片 if self.transform is not None: img = self.transform(img) # 图片进行变换 if self.return_paths: return img, path # 返回图片和路径 else: return img # 仅返回图片 def __len__(self): return len(self.imgs) # 返回指定目录下图片数量

Load_Bigfile.py

# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import io import os import struct from PIL import Image #载入打包好的数据 class BigFileMemoryLoader(object): def __load_bigfile(self): print('start load bigfile (%0.02f GB) into memory' % (os.path.getsize(self.file_path)/1024/1024/1024)) with open(self.file_path, 'rb') as fid: self.img_num = struct.unpack('i', fid.read(4))[0] self.img_names = [] self.img_bytes = [] print('find total %d images' % self.img_num) for i in range(self.img_num): img_name_len = struct.unpack('i', fid.read(4))[0] img_name = fid.read(img_name_len).decode('utf-8') self.img_names.append(img_name) img_bytes_len = struct.unpack('i', fid.read(4))[0] self.img_bytes.append(fid.read(img_bytes_len)) if i % 5000 == 0: print('load %d images done' % i) print('load all %d images done' % self.img_num) #初始化 def __init__(self, file_path): super(BigFileMemoryLoader, self).__init__() self.file_path = file_path self.__load_bigfile() #返回图片名字和图片 def __getitem__(self, index): try: img = Image.open(io.BytesIO(self.img_bytes[index])).convert('RGB') return self.img_names[index], img except Exception: print('Image read error for index %d: %s' % (index, self.img_names[index])) return self.__getitem__((index+1)%self.img_num) #图片数目 def __len__(self): return self.img_num

online_dataset_for_old_photos.py

# Copyright (c) Microsoft Corporation. # Licensed under the MIT License. import os.path import io import zipfile from data.base_dataset import BaseDataset, get_params, get_transform, normalize from data.image_folder import make_dataset from PIL import Image import torchvision.transforms as transforms import numpy as np from data.Load_Bigfile import BigFileMemoryLoader import random import cv2 from io import BytesIO #图片转矩阵 def pil_to_np(img_PIL): '''Converts image in PIL format to np.array. From W x H x C [0...255] to C x W x H [0..1] ''' ar = np.array(img_PIL) if len(ar.shape) == 3: ar = ar.transpose(2, 0, 1) else: ar = ar[None, ...] return ar.astype(np.float32) / 255. #矩阵转图片 def np_to_pil(img_np): '''Converts image in np.array format to PIL image. From C x W x H [0..1] to W x H x C [0...255] ''' ar = np.clip(img_np * 255, 0, 255).astype(np.uint8) if img_np.shape[0] == 1: ar = ar[0] else: ar = ar.transpose(1, 2, 0) return Image.fromarray(ar) ## #以下合成噪声图片 ## def synthesize_salt_pepper(image,amount,salt_vs_pepper): ## Give PIL, return the noisy PIL img_pil=pil_to_np(image) out = img_pil.copy() p = amount q = salt_vs_pepper flipped = np.random.choice([True, False], size=img_pil.shape, p=[p, 1 - p]) salted = np.random.choice([True, False], size=img_pil.shape, p=[q, 1 - q]) peppered = ~salted out[flipped & salted] = 1 out[flipped & peppered] = 0. noisy = np.clip(out, 0, 1).astype(np.float32) return np_to_pil(noisy) def synthesize_gaussian(image,std_l,std_r): ## Give PIL, return the noisy PIL img_pil=pil_to_np(image) mean=0 std=random.uniform(std_l/255.,std_r/255.) gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape) noisy=img_pil+gauss noisy=np.clip(noisy,0,1).astype(np.float32) return np_to_pil(noisy) def synthesize_speckle(image,std_l,std_r): ## Give PIL, return the noisy PIL img_pil=pil_to_np(image) mean=0 std=random.uniform(std_l/255.,std_r/255.) gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape) noisy=img_pil+gauss*img_pil noisy=np.clip(noisy,0,1).astype(np.float32) return np_to_pil(noisy) #图片缩小 def synthesize_low_resolution(img): w,h=img.size new_w=random.randint(int(w/2),w) new_h=random.randint(int(h/2),h) img=img.resize((new_w,new_h),Image.BICUBIC) if random.uniform(0,1)


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3