pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio) |
您所在的位置:网站首页 › torchscriptpt › pytorch模型运行到android手机上(仅使用pytorch+AndroidStudio) |
近期需要将pytorch模型运行到android手机上实验,在查阅网上博客后,发现大多数流程需要借助多个框架或软件,横跨多个编程语言、IDE。本文参考以下两篇博文,力求用更简洁的流程实现模型部署。 https://blog.csdn.net/xiaodidididi521/article/details/123985612 https://blog.csdn.net/m0_67391683/article/details/125401357 向两位作者表示感谢!本文进一步详细描述了实现流程。 一、pytorch模型转化pytorch模型无法直接被Android调用,需要转化为特定格式.pt。本文使用pycharm IDE完成这一步,工程目录结构如下: ![pycharm目录结构](https://img-blog.csdnimg.cn/d67266301c3f43bfa20d3585dc5fe836.png#pic_center 其中,vgg16bn_CIFAR10.pth和另一个pth文件是需要部署到手机上的模型,models.py是自己定义的网络结构。在此默认读者熟悉pytorch,对models.py不做赘述。 执行以下代码实现转换: import torch.utils.data.distributed '定义转化后的模型名称' model_ori_pt ='model_ori.pt' model_pruned_pt ='model_pruned.pt' '加载pytorch模型' model_ori = torch.load('vgg16bn_CIFAR10.pth') model_pruned = torch.load('vgg16bn_CIFAR10_pruned.pth') '模型在cpu上运行' device = torch.device('cpu') model_ori.to(device) model_pruned.to(device) model_ori.eval() model_pruned.eval() '定义输入图片的大小' input_tensor = torch.rand(1, 3, 32, 32) '转化模型并存储' mobile_ori = torch.jit.trace(model_ori, input_tensor) model_pruned = torch.jit.trace(model_pruned, input_tensor) mobile_ori.save(model_ori_pt) model_pruned.save(model_pruned_pt)请注意,让模型在cpu上,或cuda上执行eval()均可,但要保证模型与input_tensor在同一设备上,否则将运行出错。运行后,会得到model_ori.pt与model_pruned.pt两个文件,即可以用于android上的文件。此时目录结构如下: 首先,需要在本地安装Android Studio,安装流程建议参照: https://m.runoob.com/android/android-studio-install.html?ivk_sa=1024320u 然后打开Android Studio新建Empy Activity 点击Next。 点击Finsh。SDK建议选择7.0以往的安卓版本。**首次新建工程底部会长时间出现加载进度条,请耐心等待加载完成。**接下来,我们需要有一部手机调试工程,本文使用Android Studio自带的模拟器。首先点击顶部工具栏的Device Manager。
首先,新建assets文件夹,请不要直接新建,需右键app->Folder->Assets Folder。 http://www.cs.toronto.edu/~kriz/cifar.html 然后在gradle Scripts 文件夹中的build.gradle(Module :app)文件中的depencies里添加: implementation 'org.pytorch:pytorch_android:1.12.1' implementation 'org.pytorch:pytorch_android_torchvision:1.12.1'请注意**1.12.1是本文使用的pytorch版本,读者应该为对应的版本号。**然后点击工具栏下的sync now,再耐心等待运行按钮变绿。 然后右键java里的com.example.工程名 文件夹,New->Java Class。本文新建的类名是CIfarClassed,类内代码: package com.example.工程名; public class CifarClassed { public static String[] IMAGENET_CLASSES = new String[]{ "ddd", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck", }; }最后打开java->com.example.工程名->MainActivity,删除原代码,用以下代码替代: package com.example.dnna; import android.content.Context; import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.os.Bundle; import android.util.Log; import android.widget.ImageView; import android.widget.TextView; import org.pytorch.IValue; import org.pytorch.Module; import org.pytorch.Tensor; import org.pytorch.torchvision.TensorImageUtils; import org.pytorch.MemoryFormat; import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import androidx.appcompat.app.AppCompatActivity; import com.example.dnna.CifarClassed; public class MainActivity extends AppCompatActivity { @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); setContentView(R.layout.activity_main); Bitmap bitmap = null; Module module_ori = null; Module module_pruned = null; try { // creating bitmap from packaged into app android asset 'image.jpg', // app/src/main/assets/image.jpg bitmap = BitmapFactory.decodeStream(getAssets().open("x.png")); // loading serialized torchscript module from packaged into app android asset model.pt, // app/src/model/assets/model.pt module_ori = Module.load(assetFilePath(this, "model_ori.pt")); module_pruned = Module.load(assetFilePath(this, "model——pruned.pt")); } catch (IOException e) { Log.e("PytorchHelloWorld", "Error reading assets", e); finish(); } // showing image on UI ImageView imageView = findViewById(R.id.image); imageView.setImageBitmap(bitmap); // preparing input tensor final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap, TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST); // running the model long startTime_ori = System.currentTimeMillis(); final Tensor outputTensor_ori = module_ori.forward(IValue.from(inputTensor)).toTensor(); long endTime_ori = System.currentTimeMillis(); long InferenceTimeOri=endTime_ori - startTime_ori; long startTime_pruned = System.currentTimeMillis(); final Tensor outputTensor_pruned = module_pruned.forward(IValue.from(inputTensor)).toTensor(); long endTime_pruned = System.currentTimeMillis(); long InferenceTimePruned=endTime_pruned - startTime_pruned; // getting tensor content as java array of floats final float[] scores = outputTensor_ori.getDataAsFloatArray(); // searching for the index with maximum score float maxScore = -Float.MAX_VALUE; int maxScoreIdx = -1; for (int i = 0; i maxScore) { maxScore = scores[i]; maxScoreIdx = i; } } System.out.println(maxScoreIdx); String className = CifarClassed.IMAGENET_CLASSES[maxScoreIdx]; // showing className on UI TextView textView = findViewById(R.id.text); String tex="推理结果:"+className+"\n原始模型推理时间:"+InferenceTimeOri+"ms"+"\n剪枝模型推理时间:"+InferenceTimePruned+"ms"; textView.setText(tex); } /** * Copies specified asset to the file in /files app directory and returns this file absolute path. * * @return absolute file path */ public static String assetFilePath(Context context, String assetName) throws IOException { File file = new File(context.getFilesDir(), assetName); if (file.exists() && file.length() > 0) { return file.getAbsolutePath(); } try (InputStream is = context.getAssets().open(assetName)) { try (OutputStream os = new FileOutputStream(file)) { byte[] buffer = new byte[4 * 1024]; int read; while ((read = is.read(buffer)) != -1) { os.write(buffer, 0, read); } os.flush(); } return file.getAbsolutePath(); } } }运行效果如下图: 本文的主要流程是: 使用pytorch转化模型新建Android Studio工程与虚拟机修改Android Studio工程代码本人目前希望提升自己的博客撰写水平,如读者在实现过程中遇到困难,或在阅读本文时感到困惑,欢迎留言或添加我的QQ:1106295085。我将在周日下午回复,并积极修改本文。 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |