写点什么

如何将训练好的 Python 模型给 JavaScript 使用?

作者:北桥苏
  • 2023-05-15
    广东
  • 本文字数:4369 字

    阅读完需:约 14 分钟

前言


  从前面的 Tensorflow 环境搭建到目标检测模型迁移学习,已经完成了一个简答的扑克牌检测器,不管是从图片还是视频都能从画面中识别出有扑克的目标,并标识出扑克点数。但是,我想在想让他放在浏览器上可能实际使用,那么要如何让 Tensorflow 模型转换成 web 格式的呢?接下来将从实践的角度详细介绍一下部署方法!



环境


  • Windows10

  • Anaconda3

  • TensorFlow.js converter


converter 介绍


  converter 全名是 TensorFlow.js Converter,他可以将 TensorFlow GraphDef 模型(通过 Python API 创建的,可以先理解为 Python 模型) 转换成 Tensorflow.js 可读取的模型格式(json 格式), 用于在浏览器上对指定数据进行推算。



converter 安装


  为了不影响前面目标检测训练环境,这里我用 conda 创建了一个新的 Python 虚拟环境,Python 版本 3.6.8。在安装转换器的时候,如果当前环境没有 Tensorflow,默认会安装与 TF 相关的依赖,只需要进入指定虚拟环境,输入以下命令。


pip install tensorflowjs
复制代码



converter 用法


tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve ./saved_model ./web_model
复制代码


\1. 产生的文件(生成的 web 格式模型)


转换器命令执行后生产两种文件,分别是 model.json (数据流图和权重清单)和 group1-shard*of* (二进制权重文件)


\2. 输入的必要条件(命令参数和选项[带--为选项])


converter 转换指令后面主要携带四个参数,分别是输入模型的格式,输出模型的格式,输入模型的路径,输出模型的路径,更多帮助信息可以通过以下命令查看,另附命令分解图。


tensorflowjs_converter --help
复制代码



2.1. --input_format


要转换的模型的格式,SavedModel 为 tf_saved_model, frozen model 为 tf_frozen_model, session bundle 为 tf_session_bundle, TensorFlow Hub module 为 tf_hub,Keras HDF5 为 keras。


2.2. --output_format


输出模型的格式, 分别有 tfjs_graph_model (tensorflow.js 图模型,保存后的 web 模型没有了再训练能力,适合 SavedModel 输入格式转换),tfjs_layers_model(tensorflow.js 层模型,具有有限的 Keras 功能,不适合 TensorFlow SavedModels 转换)。


2.3. input_path


saved model, session bundle 或 frozen model 的完整的路径,或 TensorFlow Hub 模块的路径。


2.4. output_path


输出文件的保存路径。


2.5. --saved_model_tags


只对 SavedModel 转换用的选项:输入需要加载的 MetaGraphDef 相对应的 tag,多个 tag 请用逗号分隔。默认为 serve


2.6. --signature_name


对 TensorFlow Hub module 和 SavedModel 转换用的选项:对应要加载的签名,默认为default


2.7. --output_node_names


输出节点的名字,每个名字用逗号分离。


\3. 常用的两组命令行


1. covert ``from` `saved_model` `tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve ./saved_model ./web_model` `2. convert ``from` `frozen_model``tensorflowjs_converter --input_format=tf_frozen_model --output_node_names=``'num_detections,detection_boxes,detection_scores,detection_classes'` `./frozen_inference_graph.pb ./web_modelk
复制代码


开始实践


\1. 找到通过 export_inference_graph.py 导出的模型


导出的模型在项目的 inference_graph 文件夹(models\research\object_detection)里,frozen_inference_graph.pb 是 tf_frozen_model 输入格式需要的,而 saved_model 文件夹就是 tf_saved_model 格式。在当前目录下新建 web_model 目录,用于存储转换后的 web 格式的模型。



\2. 开始转换


在当前虚拟环境下,进入到 inference_graph 目录下,输入以下命令,之后就会在 web_model 生成一个 json 文件和多个权重文件。


tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve ./saved_model ./web_model
复制代码



\3. 浏览器端部署


3.1. 创建一个前端项目,将 web_model 放入其中。



3.2.编写代码


<!doctype html>``<head>`` ``<link rel=``"stylesheet"` `href=``"tfjs-examples.css"` `/>`` ``<style>`` ``canvas {outline: orange 2px solid; margin: 10px 0;}`` ``</style>``</head>` `<body>`` ``<div ``class``=``"tfjs-example-container centered-container"``>``  ``<section ``class``=``'title-area'``>``   ``<h1>赌圣2023</h1>``  ``</section>``  ``<p ``class``=``'section-head'``>模型描述</p>``  ``<p>我看你怎么出老千!</p>``  ``<p ``class``=``'section-head'``>模型状态</p>``  ``<div id=``"status"``>加载模型中...</div>``  ``<div>``   ``<p ``class``=``'section-head'``>效果展示</p>``   ``<p></button><input type=``"file"` `accept=``"image/*"` `id=``"test"``/></p>``   ``<canvas id=``"data-canvas"` `width=``"300"` `height=``"1100"``></canvas>``  ``</div>`` ``</div>` `</body>` `<script src=``"https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"``></script>` `<script>`` ``const` `canvas = document.getElementById(``'data-canvas'``);`` ``const` `status = document.getElementById(``'status'``);`` ``const` `testModel = document.getElementById(``'test'``);` ` ``const` `BOUNDING_BOX_LINE_WIDTH = 3;`` ``const` `BOUNDING_BOX_STYLE1 = ``'rgb(0,0,255)'``;`` ``const` `BOUNDING_BOX_STYLE2 = ``'rgb(0,255,0)'``;` ` ``async function init() {` `  ``const` `LOCAL_MODEL_PATH = ``'./web_model/model.json'``;` `  ``// 将本地模型保存到浏览器``  ``// tf.sequential().save` `  ``// 加载本地模型``  ``let` `model;``  ``try` `{``   ``model = await tf.loadGraphModel(LOCAL_MODEL_PATH);``   ``testModel.disabled = ``false``;``   ``status.textContent = ``'成功加载本地模型!请亮出你的牌吧'``;``   ` `   ``// 默认扑克牌``   ``runAndVisualizeInference(``'./cam_image39.jpg'``, model)``   ` `  ``} ``catch` `(err) {``   ``console.log(``'加载本地模型错误:'``, err);``   ``status.textContent = ``'加载本地模型失败'``;``  ``}` `  ``testModel.addEventListener(``'change'``, (e) => {``   ``runAndVisualizeInference(e, model)``  ``});``}` `async function runAndVisualizeInference(e, model) {` ` ``if` `(``typeof` `e === ``'string'``) {``  ``await ``new` `Promise((resolve, reject) => {``   ``// 图片显示在canvas中``   ``var` `img = ``new` `Image;``   ``img.src = e;``   ``img.onload = function () { ``// 必须onload之后再画``    ``let` `w = 500;``    ``let` `h = img.height/img.width*500;``    ``canvas.width = w;``    ``canvas.height = h;``    ``var` `ctx = canvas.getContext(``'2d'``);``    ``ctx.drawImage(img,0,0,w,h);``    ``resolve();``   ``}``  ``})`` ``} ``else` `{` `  ``// 上传图片并显示在canvas中``  ``var` `file = e.target.files[0]; ``  ``if` `(!/image\/\w+/.test(file.type)) {``   ``alert(``"请确保文件为图像类型"``);``   ``return` `false``;``  ``}``  ``var` `reader = ``new` `FileReader();``  ``reader.readAsDataURL(file); ``// 转化成base64数据类型``  ``await ``new` `Promise((resolve, reject) => {``   ``reader.onload = function (e) {``    ``// 图片显示在canvas中``    ``var` `img = ``new` `Image;``    ``img.src = ``this``.result;``    ``img.onload = function () { ``// 必须onload之后再画``     ``let` `w = 500;``     ``let` `h = img.height/img.width*500;``     ``canvas.width = w;``     ``canvas.height = h;``     ``var` `ctx = canvas.getContext(``'2d'``);``     ``ctx.drawImage(img,0,0,w,h);``     ``resolve();``    ``}``   ``}``  ``})`` ``}` ` ``// 模型输入处理`` ``let` `image = tf.browser.fromPixels(canvas);`` ``const` `t4d = image.expandDims(0);` ` ``const` `outputDim = [``  ``'num_detections'``, ``'detection_boxes'``, ``'detection_scores'``,``  ``'detection_classes'`` ``];`` ` ` ``const` `labelMap = {``  ``1: ``'九点'``,``  ``2: ``'十点'``,``  ``3: ``'Jack'``,``  ``4: ``'Queen'``,``  ``5: ``'King'``,``  ``6: ``'Ace'`` ``}`` ` ` ``let` `modelOut = {}, boxes = [], w = canvas.width, h = canvas.height;`` ``console.log(model)`` ` ` ``for` `(``const` `dim of outputDim) {``  ``let` `tensor = await model.executeAsync({``   ``'image_tensor'``: t4d``  ``}, `${dim}:0`);``  ``modelOut[dim] = await tensor.data();`` ``}`` ``console.log(modelOut)`` ` ` ``for` `(``let` `i=0; i<modelOut[``'detection_scores'``].length; i++) {``  ``const` `score = modelOut[``'detection_scores'``][i];`` ` `  ``if` `(score < 0.5) ``break``; ``// 置信度过滤`` ` `  ``boxes.push({``   ``ymin: modelOut[``'detection_boxes'``][i*4]*h,``   ``xmin: modelOut[``'detection_boxes'``][i*4+1]*w,``   ``ymax: modelOut[``'detection_boxes'``][i*4+2]*h,``   ``xmax: modelOut[``'detection_boxes'``][i*4+3]*w,``   ``label: labelMap[modelOut[``'detection_classes'``][i]],``  ``})`` ``}`` ` ` ``console.log(boxes)` ` ``// 可视化检测框`` ``drawBoundingBoxes(canvas, boxes);` ` ``// 张量运行内存清除`` ``tf.dispose([image, modelOut]);``}` `function drawBoundingBoxes(canvas, predictBoundingBoxArr) {`` ``for` `(``const` `box of predictBoundingBoxArr) {``  ``let` `left = box.xmin;``  ``let` `right = box.xmax;``  ``let` `top = box.ymin;``  ``let` `bottom = box.ymax;` `  ``const` `ctx = canvas.getContext(``'2d'``);``  ``ctx.beginPath();``  ``ctx.strokeStyle = box.label===``'ZERO_DEV'``?BOUNDING_BOX_STYLE1:BOUNDING_BOX_STYLE2;``  ``ctx.lineWidth = BOUNDING_BOX_LINE_WIDTH;``  ``ctx.moveTo(left, top);``  ``ctx.lineTo(right, top);``  ``ctx.lineTo(right, bottom);``  ``ctx.lineTo(left, bottom);``  ``ctx.lineTo(left, top);``  ``ctx.stroke();` `  ``ctx.font = ``'24px Arial bold'``;``  ``ctx.fillStyle = box.label===``'zfc'``?BOUNDING_BOX_STYLE2:BOUNDING_BOX_STYLE1;``  ``ctx.fillText(box.label, left+8, top+8);`` ``}``}` `init();` `</script>
复制代码


3.3. 运行结果





用户头像

北桥苏

关注

公众号:ZERO开发 2023-05-08 加入

专注后端实战技术分享,不限于PHP,Python,JavaScript, Java等语言,致力于给猿友们提供有价值,有干货的内容。

评论

发布
暂无评论
如何将训练好的Python模型给JavaScript使用?_Python_北桥苏_InfoQ写作社区