向量相似度评估方法

您所在的位置:网站首页 求相似度的方法是什么 向量相似度评估方法

向量相似度评估方法

2023-12-28 21:47| 来源: 网络整理| 查看: 265

原文链接:向量相似度评估方法

相似度在工作中的使用可以说是相当频繁,今天就带大家介绍pytorch中四种常用的向量相似度评估思路:

CosineSimilarityDotProductSimilarityBiLinearSimilarityMultiHeadedSimilarity 1 余弦相似度

余弦相似度相信大家都很熟悉了。用两个向量夹角的余弦值作为衡量两个个体间差异的大小。余弦值越接近1,就表明夹角越接近0度,也就是两个向量越相似。

import torch import torch.nn as nn import math class CosineSimilarity(nn.Module): def forward(self, tensor_1, tensor_2): normalized_tensor_1 = tensor_1 / tensor_1.norm(dim=-1, keepdim=True) normalized_tensor_2 = tensor_2 / tensor_2.norm(dim=-1, keepdim=True) return (normalized_tensor_1 * normalized_tensor_2).sum(dim=-1) 2 DotProductSimilarity

这个相似度函数计算每对向量之间的点积,并用可选的缩放来减少输出的方差,以调整结果的输出。

class DotProductSimilarity(nn.Module): def __init__(self, scale_output=False): super(DotProductSimilarity, self).__init__() self.scale_output = scale_output def forward(self, tensor_1, tensor_2): result = (tensor_1 * tensor_2).sum(dim=-1) if self.scale_output: # TODO why allennlp do multiplication at here ? result /= math.sqrt(tensor_1.size(-1)) return result

余弦法和点积法都是最常用的数学方法,在复杂场景下我们可以将神经网络的思路加入到计算相似度的方法中去。

3 BiLinearSimilarity

此相似度函数执行两个输入向量的双线性变换,就是加入了神经网络线性层。这个函数有一个权重矩阵“W”和一个偏差“b”,以及两个向量之间的相似度,计算公式为: x T W y + b x^TWy+b xTWy+b

计算后可使用激活函数,默认为不激活。

class BiLinearSimilarity(nn.Module): def __init__(self, tensor_1_dim, tensor_2_dim, activation=None): super(BiLinearSimilarity, self).__init__() self.weight_matrix = nn.Parameter(torch.Tensor(tensor_1_dim, tensor_2_dim)) self.bias = nn.Parameter(torch.Tensor(1)) self.activation = activation self.reset_parameters() def reset_parameters(self): nn.init.xavier_uniform_(self.weight_matrix) self.bias.data.fill_(0) def forward(self, tensor_1, tensor_2): intermediate = torch.matmul(tensor_1, self.weight_matrix) result = (intermediate * tensor_2).sum(dim=-1) + self.bias if self.activation is not None: result = self.activation(result) return result

根据此思路,我们可以演变出三线性变换,计算公式为: W T [ x , y , x ∗ y ] + b W^T[x,y,x*y]+b WT[x,y,x∗y]+b

只是在原基础上将各个特征及特征之间的关系都变为了输入,感兴趣的朋友们可以自行动手实现。

4 MultiHeadedSimilarity

这个相似度函数借用了transformer多“头”的思路来计算相似度。我们将输入张量投影到多个新张量中,并分别计算每个投影张量的相似度。

class MultiHeadedSimilarity(nn.Module): def __init__(self, num_heads, tensor_1_dim, tensor_1_projected_dim=None, tensor_2_dim=None, tensor_2_projected_dim=None, internal_similarity=DotProductSimilarity()): super(MultiHeadedSimilarity, self).__init__() self.num_heads = num_heads self.internal_similarity = internal_similarity tensor_1_projected_dim = tensor_1_projected_dim or tensor_1_dim tensor_2_dim = tensor_2_dim or tensor_1_dim tensor_2_projected_dim = tensor_2_projected_dim or tensor_2_dim if tensor_1_projected_dim % num_heads != 0: raise ValueError("Projected dimension not divisible by number of heads: %d, %d" % (tensor_1_projected_dim, num_heads)) if tensor_2_projected_dim % num_heads != 0: raise ValueError("Projected dimension not divisible by number of heads: %d, %d" % (tensor_2_projected_dim, num_heads)) self.tensor_1_projection = nn.Parameter(torch.Tensor(tensor_1_dim, tensor_1_projected_dim)) self.tensor_2_projection = nn.Parameter(torch.Tensor(tensor_2_dim, tensor_2_projected_dim)) self.reset_parameters() def reset_parameters(self): torch.nn.init.xavier_uniform_(self.tensor_1_projection) torch.nn.init.xavier_uniform_(self.tensor_2_projection) def forward(self, tensor_1, tensor_2): projected_tensor_1 = torch.matmul(tensor_1, self.tensor_1_projection) projected_tensor_2 = torch.matmul(tensor_2, self.tensor_2_projection) last_dim_size = projected_tensor_1.size(-1) // self.num_heads new_shape = list(projected_tensor_1.size())[:-1] + [self.num_heads, last_dim_size] split_tensor_1 = projected_tensor_1.view(*new_shape) last_dim_size = projected_tensor_2.size(-1) // self.num_heads new_shape = list(projected_tensor_2.size())[:-1] + [self.num_heads, last_dim_size] split_tensor_2 = projected_tensor_2.view(*new_shape) return self.internal_similarity(split_tensor_1, split_tensor_2) 总结

复杂的做法只是在向量的基础上进行了更多的线性变化和线性变化组合。其实我们自己完全可以根据业务场景自创计算方法,因为神经网络的好处就是在于我们可以随意自行搭建。

原文链接:向量相似度评估方法



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3