写点什么

用 Tensorflow.js 做了一个动漫分类的功能(二)

作者:北桥苏
  • 2023-05-15
    广东
  • 本文字数:3113 字

    阅读完需:约 10 分钟

前言:

​ 前面已经通过采集拿到了图片,并且也手动对图片做了标注。接下来就要通过 Tensorflow.js 基于 mobileNet 训练模型,最后就可以实现在采集中对图片进行自动分类了。


​ 这种功能在应用场景里就比较多了,比如图标素材站点,用户通过上传一个图标,系统会自动匹配出相似的图标,还有二手平台,用户通过上传闲置物品图片,平台自动给出分类等,这些也都是前期对海量图片进行了标注训练而得到一个损失率极低的模型。下面就通过简答的代码实现一个小的动漫分类。


环境:

Node


Http-Server


Parcel


Tensorflow


编码:

\1. 训练模型


1.1. 创建项目,安装依赖包


npm install @tensorflow/tfjs --legacy-peer-depsnpm install @tensorflow/tfjs-node-gpu --legacy-peer-deps
复制代码


1.2. 全局安装 Http-Server


npm install i http-server
复制代码


1.3. 下载 mobileNet 模型文件 (网上有下载)


1.4. 根目录下启动 Http 服务 (开启跨域),用于 mobileNet 和训练结果的模型可访问


http-server --cors -p 8080
复制代码



1.5. 创建训练执行脚本 run.js


const tf = require('@tensorflow/tfjs-node-gpu');
const getData = require('./data');const TRAIN_PATH = './动漫分类/train';const OUT_PUT = 'output';const MOBILENET_URL = 'http://127.0.0.1:8080/data/mobilenet/web_model/model.json';
(async () => { const { ds, classes } = await getData(TRAIN_PATH, OUT_PUT); console.log(ds, classes); //引入别人训练好的模型 const mobilenet = await tf.loadLayersModel(MOBILENET_URL); //查看模型结构 mobilenet.summary();
const model = tf.sequential(); //截断模型,复用了86个层 for (let i = 0; i < 86; ++i) { const layer = mobilenet.layers[i]; layer.trainable = false; model.add(layer); } //降维,摊平数据 model.add(tf.layers.flatten()); //设置全连接层 model.add(tf.layers.dense({ units: 10, activation: 'relu'//设置激活函数,用于处理非线性问题 }));
model.add(tf.layers.dense({ units: classes.length, activation: 'softmax'//用于多分类问题 })); //设置损失函数,优化器 model.compile({ loss: 'sparseCategoricalCrossentropy', optimizer: tf.train.adam(), metrics:['acc'] });
//训练模型 await model.fitDataset(ds, { epochs: 20 }); //保存模型 await model.save(`file://${process.cwd()}/${OUT_PUT}`);})();
复制代码


1.6. 创建图片与 Tensor 转换库 data.js


const fs = require('fs');const tf = require("@tensorflow/tfjs-node-gpu");
const img2x = (imgPath) => { const buffer = fs.readFileSync(imgPath); //清除数据 return tf.tidy(() => { //把图片转成tensor const imgt = tf.node.decodeImage(new Uint8Array(buffer), 3); //调整图片大小 const imgResize = tf.image.resizeBilinear(imgt, [224, 224]); //归一化 return imgResize.toFloat().sub(255 / 2).div(255 / 2).reshape([1, 224, 224, 3]); });}
const getData = async (traindir, output) => { let classes = fs.readdirSync(traindir, 'utf-8'); fs.writeFileSync(`./${output}/classes.json`, JSON.stringify(classes)); const data = []; classes.forEach((dir, dirIndex) => { fs.readdirSync(`${traindir}/${dir}`) .filter(n => n.match(/jpg$/)) .slice(0, 1000) .forEach(filename => { const imgPath = `${traindir}/${dir}/${filename}`;
data.push({ imgPath, dirIndex }); }); });
console.log(data);
//打乱训练顺序,提高准确度 tf.util.shuffle(data);
const ds = tf.data.generator(function* () { const count = data.length; const batchSize = 32; for (let start = 0; start < count; start += batchSize) { const end = Math.min(start + batchSize, count); console.log('当前批次', start); yield tf.tidy(() => { const inputs = []; const labels = []; for (let j = start; j < end; ++j) { const { imgPath, dirIndex } = data[j]; const x = img2x(imgPath); inputs.push(x); labels.push(dirIndex); } const xs = tf.concat(inputs); const ys = tf.tensor(labels); return { xs, ys }; }); } });
return { ds, classes };}
module.exports = getData;
复制代码


1.7. 运行执行文件


node run.js
复制代码



\2. 调用模型


2.1. 全局安装 parcel


npm install i parcel
复制代码


2.2. 创建页面 index.html


<script src="script.js"></script><input type="file" onchange="predict(this.files[0])"><br>
复制代码


2.3. 创建模型调用预测脚本 script.js


import * as tf from '@tensorflow/tfjs';import { img2x, file2img } from './utils';
const MODEL_PATH = 'http://127.0.0.1:8080/t7';const CLASSES = ["假面骑士","奥特曼","海贼王","火影忍者","龙珠"];

window.onload = async () => { const model = await tf.loadLayersModel(MODEL_PATH + '/output/model.json');
window.predict = async (file) => { const img = await file2img(file); document.body.appendChild(img); const pred = tf.tidy(() => { const x = img2x(img); return model.predict(x); });
const index = pred.argMax(1).dataSync()[0]; console.log(pred.argMax(1).dataSync());
let predictStr = ""; if (typeof CLASSES[index] == 'undefined') { predictStr = BRAND_CLASSES[index]; } else { predictStr = CLASSES[index]; }
setTimeout(() => { alert(`预测结果:${predictStr}`); }, 0); };};
复制代码


2.4. 创建图片 tensor 格式转换库 utils.js


import * as tf from '@tensorflow/tfjs';
export function img2x(imgEl){ return tf.tidy(() => { const input = tf.browser.fromPixels(imgEl) .toFloat() .sub(255 / 2) .div(255 / 2) .reshape([1, 224, 224, 3]); return input; });}
export function file2img(f) { return new Promise(resolve => { const reader = new FileReader(); reader.readAsDataURL(f); reader.onload = (e) => { const img = document.createElement('img'); img.src = e.target.result; img.width = 224; img.height = 224; img.onload = () => resolve(img); }; });}
复制代码


2.5. 打包项目并运行


parcel index.html
复制代码



2.6. 运行效果




注意:

\1. 模型训练过程报错


Input to reshape is a tensor with 50176 values, but the requested shape has 150528


1.1. 原因


张量 reshape 不对,实际输入元素个数与所需矩阵元素个数不一致,就是采集过来的图片有多种图片格式,而不同格式的通道不同 (jpg3 通道,png4 通道,灰色图片 1 通道),在将图片转换 tensor 时与代码里的张量形状不匹配。


1.2. 解决方法


一种方法是删除灰色或 png 图片,其二是修改代码 tf.node.decodeImage (new Uint8Array (buffer), 3)




用户头像

北桥苏

关注

公众号:ZERO开发 2023-05-08 加入

专注后端实战技术分享,不限于PHP,Python,JavaScript, Java等语言,致力于给猿友们提供有价值,有干货的内容。

评论

发布
暂无评论
用 Tensorflow.js 做了一个动漫分类的功能(二)_JavaScript_北桥苏_InfoQ写作社区