图像质量评价指标: MMD ( maximum

您所在的位置:网站首页 discrepancy和disparity 图像质量评价指标: MMD ( maximum

图像质量评价指标: MMD ( maximum

2023-11-24 16:12| 来源: 网络整理| 查看: 265

MMD:maximum mean discrepancy。最大平均差异, 用于判断两个分布p和q是否相同。它的基本假设是:如果对于所有以分布生成的样本空间为输入的函数f,如果两个分布生成的足够多的样本在f上的对应的像的均值都相等,那么那么可以认为这两个分布是同一个分布。现在一般用于度量两个分布之间的相似性。

Keras 2.2.4 tensorflow 1.9.0

import torch import matplotlib import os import argparse import numpy as np from PIL import Image from torch.autograd import Variable from keras.applications.inception_v3 import InceptionV3 os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3' # 只显示 Error def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): ''' 将源域数据和目标域数据转化为核矩阵,即上文中的K Params: source: 源域数据(n * len(x)) target: 目标域数据(m * len(y)) kernel_mul: kernel_num: 取不同高斯核的数量 fix_sigma: 不同高斯核的sigma值 Return: sum(kernel_val): 多个核矩阵之和 ''' n_samples = int(source.size()[0])+int(target.size()[0])# 求矩阵的行数,一般source和target的尺度是一样的,这样便于计算 total = torch.cat([source, target], dim=0)#将source,target按列方向合并 #将total复制(n+m)份 total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) #将total的每一行都复制成(n+m)行,即每个数据都扩展成(n+m)份 total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) #求任意两个数据之间的和,得到的矩阵中坐标(i,j)代表total中第i行数据和第j行数据之间的l2 distance(i==j时为0) L2_distance = ((total0-total1)**2).sum(2) #调整高斯核函数的sigma值 if fix_sigma: bandwidth = fix_sigma else: bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples) #以fix_sigma为中值,以kernel_mul为倍数取kernel_num个bandwidth值(比如fix_sigma为1时,得到[0.25,0.5,1,2,4] bandwidth /= kernel_mul ** (kernel_num // 2) bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)] #高斯核函数的数学表达式 kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list] #得到最终的核矩阵 return sum(kernel_val)#/len(kernel_val) def mmd_rbf(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): ''' 计算源域数据和目标域数据的MMD距离 Params: source: 源域数据(n * len(x)) target: 目标域数据(m * len(y)) kernel_mul: kernel_num: 取不同高斯核的数量 fix_sigma: 不同高斯核的sigma值 Return: loss: MMD loss ''' batch_size = int(source.size()[0]) #一般默认为源域和目标域的batchsize相同 kernels = guassian_kernel(source, target, kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma) #根据式(3)将核矩阵分成4部分 XX = kernels[:batch_size, :batch_size] YY = kernels[batch_size:, batch_size:] XY = kernels[:batch_size, batch_size:] YX = kernels[batch_size:, :batch_size] loss = torch.mean(XX + YY - XY -YX) return loss#因为一般都是n==m,所以L矩阵一般不加入计算 def data_list(dirPath): # read img generatedImgs = [] realImgs = [] for root, dirs, files in os.walk(dirPath): for filename in sorted(files): # 判断该文件是否是目标文件 if "generated" in filename: generatedPath = root + '/' + filename generatedImgs.append(readImg(generatedPath)) # 对比图片路径 realPath = root + '/' + filename.replace('generated', 'real') realImgs.append(readImg(realPath)) return generatedImgs, realImgs def readImg(imgPath): img = Image.open(imgPath) # RGB # img.show() # PIL转numpy类型 img = np.array(img).astype(np.float) return img/255 if __name__ == '__main__': ### 参数设定 parser = argparse.ArgumentParser() parser.add_argument('--dataset_dir', type=str, default=r'D:\Project\pix2pix-master\results', help='results') parser.add_argument('--name', type=str, default='faces', help='name of dataset') opt = parser.parse_args() # 数据集 dirPath = os.path.join(opt.dataset_dir, opt.name) generatedImgs, realImgs = data_list(dirPath) size = len(generatedImgs) print("数据集:", size) X = torch.Tensor(generatedImgs) Y = torch.Tensor(realImgs) print('shape: ', X.shape, Y.shape) # prepare the inception v3 model model = InceptionV3(include_top=False, pooling='avg') X, Y = model.predict(X), model.predict(Y) X, Y = Variable(torch.Tensor(X)), Variable(torch.Tensor(Y)) mmd = mmd_rbf(X, Y) print("mmd: ", mmd)


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3