如何为TensorLayerX添加支持多计算后端的神经网络层

您所在的位置:网站首页 后端dao层 如何为TensorLayerX添加支持多计算后端的神经网络层

如何为TensorLayerX添加支持多计算后端的神经网络层

2023-04-04 08:31| 来源: 网络整理| 查看: 265

在TensorLayerX这篇文章的最后,我们介绍了如何自定义一个神经网络层。自定义神经网络层仅需要继承Module后重构init和forward方法。TensorLayerX是一款支持多框架的深度学习编程库,如何构建一个支持多框架的神经网络层呢?我们先对TensorLayerX的代码进行简单分析。

1.TensorLayerX代码结构分析TensorLayerX代码结构

首先我们对TensorLayerX整个工程进行分析。比较核心的有backend、dataflow、nn、model

backend:这个文件夹的文件是基于四个后端(tensorflow、mindspore、pytorch、paddle)提供的中低阶接口(如torch.nn.function、torch.add这类)进行统一封装,完成算子的四个后端实现。

dataflow:实现了数据处理的接口。

model:实现了模型训练的封装。

nn:实现了深度学习模型构建组件,包括初始化器,神经网络层,core里是神经网络基类的实现。nn是用户常用接口,调用的backend的算子。

因此,我们可以知道,要实现支持多计算后端的神经网络层,需要先在backend里分别实现四个后端的算子。然后在nn内调用backend算子。下面我们以实现线性层为例。

2.后端实现

我们看到在backend里,同一个计算后端分两个文件,如tensorflow_nn、tensorflow_backend。这两个文件将算子分类了。其中tensorflow_nn里封装的是nn相关操作,tensorflow_backend里封装的是算子基础操作。

现在我们以线性层为例分别实现四个后端。

TensorFlow后端:

import tensorflow as tf def linear(input, weight, bias = None): output = tf.matmul(input, weight, transpose_b=True) if bias: output = output + bias return output

MindSpore后端:

import mindspore.ops as P def linear(input, weight, bias = None): matmul = P.MatMul(transpose_b=True) output = matmul(input, weight) if bias: bias_add = P.BiasAdd() output = bias_add(output, bias) return output

PaddlePaddle后端:

import paddle.nn.functional as F def linear(input, weight, bias = None): return F.linear(input, weight, bias)

PyTorch后端:

import torch def linear(input, weight, bias = None): return torch.nn.functional.linear(input, weight, bias)3.前端实现

有了linear这个算子后,我们在前端就能统一调用。

import tensorlayerx as tlx from tensorlayerx.nn.core import Module class Linear(Module): def __init__( self, out_features, act=None, W_init='truncated_normal', b_init='constant', in_features=None, name=None, # 'linear', ): super(Linear, self).__init__(name, act=act) self.out_features = out_features self.W_init = self.str_to_init(W_init) self.b_init = self.str_to_init(b_init) self.in_features = in_features def build(self, inputs_shape): if self.in_features is None and len(inputs_shape)


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3