写点什么

带你了解 TensorFlow pb 模型常用处理方法

  • 2022 年 8 月 12 日
    中国香港
  • 本文字数:2844 字

    阅读完需:约 9 分钟

带你了解TensorFlow pb模型常用处理方法

本文分享自华为云社区《TensorFlow pb模型修改和优化》,作者:luchangli。


TensorFlow 模型训练完成后,通常会通过 frozen 过程保存一个最终的 pb 模型。保存的 pb 模型是以 GraphDef 数据结构保存的,可以序列化保存为二 pb 进制模型或者文本 pbtxt 模型。GraphDef 本质上是一个 DAG 有向无环图,里面主要是存放了一个算子 node list,每个算子具有名称,attr 等内容,以及通过 input 包含了 node 之间的连接关系。


整个 GraphDef 的输入节点是以 Placeholder 节点来标识的,模型参数权重通常是以 Const 节点来保存的。不同于 onnx,GraphDef 没有对输出进行标识,好处是可以通过 node_name:idx 来引用获取任意一个节点的输出,缺点是一般需要通过 netron 手动打开查看模型输出,或者通过代码分析没有输出节点的 node 作为模型输出节点。下面简单介绍下 pb 模型常用的一些处理方法。

pb 模型保存


# write pb modelwith tf.io.gfile.GFile(model_path, "wb") as f: f.write(graph_def.SerializeToString())# write pbtxt modeltf.io.write_graph(graph_def, os.path.dirname(model_path), os.path.basename(model_path))
复制代码

创建 node


from tensorflow.core.framework import attr_value_pb2from tensorflow.core.framework import node_def_pb2from tensorflow.python.framework import tensor_utilpld_node = node_def_pb2.NodeDef()pld_node.name = namepld_node.op = "Placeholder"shape = tf.TensorShape([None, 3, 256, 256])pld_node.attr["shape"].CopyFrom(attr_value_pb2.AttrValue(shape=shape.as_proto()))dtype = tf.dtypes.as_dtype("float32")pld_node.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(type=dtype.as_datatype_enum))# other commonly used settingnode.input.extend(in_node_names)node.attr["value"].CopyFrom(    attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto( np_array, np_array.type, np_array.shape)))
复制代码

构建模型和保存


import tensorflow as tfimport numpy as nptf.compat.v1.disable_eager_execution()tf.compat.v1.reset_default_graph()m = 200k = 256n = 128a_shape = [m, k]b_shape = [k, n]np.random.seed(0)input_np = np.random.uniform(low=0.0, high=1.0, size=a_shape).astype("float32")kernel_np = np.random.uniform(low=0.0, high=1.0, size=b_shape).astype("float32")# 构建模型pld1 = tf.compat.v1.placeholder(dtype="float32", shape=a_shape, name="input1")kernel = tf.constant(kernel_np, dtype="float32")feed_dict = {pld1: input_np}result_tf = tf.raw_ops.MatMul(a=pld1, b=kernel, transpose_a=False, transpose_b=False)with tf.compat.v1.Session() as sess:    results = sess.run(result_tf, feed_dict=feed_dict) print("results:", results)# 保存模型dump_model_name = "matmul_graph.pb"graph = tf.compat.v1.get_default_graph()graph_def = graph.as_graph_def()with tf.io.gfile.GFile(dump_model_name, "wb") as f: f.write(graph_def.SerializeToString())
复制代码


当然一般用其他方式而不是 raw_ops 构建模型。

pb 模型读取


from google.protobuf import text_formatgraph_def = tf.compat.v1.GraphDef()# read pb modelwith tf.io.gfile.GFile(model_path, "rb") as f: graph_def.ParseFromString(f.read())# read pbtxt modelwith open(model_path, "r") as pf: text_format.Parse(pf.read(), graph_def)
复制代码

node 信息打印


常用信息:


node.namenode.opnode.inputnode.device# please ref https://www.tensorflow.org/api_docs/python/tf/compat/v1/AttrValuenode.attr[attr_name].f # b, i, tensor, etc.# graph_def中node遍历:for node in graph_def.node: ##
复制代码


对于 node 的 input,一般用 node_name:idx 如 node_name:0 来表示输入来自上一个算子的第 idx 个输出。:0 省略则是默认为第 0 个输出。名称前面加^符号是控制边。这个 input 是一个字符串 list,这里面的顺序也对应这个 node 的各个输入的顺序。

创建 GraphDef 和添加 node


graph_def_n = tf.compat.v1.GraphDef()for node in graph_def_o.node: node_n = node_def_pb2.NodeDef() node_n.CopyFrom(node) graph_def_n.node.extend([node_n])# you probably need copy other value like version, etc. from old graphgraph_def_n.version = graph_def_o.versiongraph_def_n.library.CopyFrom(graph_def_o.library)graph_def_n.versions.CopyFrom(graph_def_o.versions)
复制代码


返回 graph_def_n


noonnx 模型往 graph 里面添加节点的 topo 排序要求

设置占位符的形状


参考前面创建 node 部分,通过修改 Placeholder 的 shape 属性。

模型形状推导


需要导入模型到 tf:tf.import_graph_def(graph_def, name='')。当然需要先设置正确的 pld 的 shape。


然后获取 node 的输出 tensor:graph.get_tensor_by_name(node_name + “:0”)。


最后可以从 tensor 里面获取 shape 和 dtype。

pb 模型图优化


思路一般比较简单:


1,子图连接关系匹配,比如要匹配 conv2d+bn+relu 这个 pattern 连接关系。由于每个 node 只保存其输入的 node 连接关系,要进行 DFS/BFS 遍历图一般需要每个 node 的输入输出,这可以首先读取所有的 node 连接关系并根据 input 信息同时创建一个 output 信息 map。


2,子图替换,先创建新的算子,再把旧的算子替换为新的算子。这个需要创建新的 node 或者直接修改原来的 node。旧的不要的算子可以创建个新图拷贝时丢弃,新的 node 可以直接 extend 到 graph_def。


3,如果替换为 TF 内置的算子,算子定义可以参考 tensorflow raw_ops 中的定义,但是有些属性(例如数据类型 attr “T”)没有列出来:https://www.tensorflow.org/api_docs/python/tf/raw_ops

当然也可以替换为自定义算子,这就需要用户开发和注册自定义算子:https://www.tensorflow.org/guide/create_op


如上所述,TensorFlow 的 pb 模型修改优化可以直接使用 python 代码实现,极大简化开发过程。当然 TensorFlow 也可以注册 grappler 和 post rewrite 图优化 pass 在 C++层面进行图优化,后者除了可以用于推理,也可以用于训练优化。

saved model 与 pb 模型的相互转换


可以参考:tensorflow 模型导出总结 - 知乎


saved model 保存的是一整个训练图,并且参数没有冻结。而只用于模型推理 serving 并不需要完整的训练图,并且参数不冻结无法进行转 TensorRT 等极致优化。当然也可以 saved_model->frozen pb->saved model 来同时利用两者的优点。

pb 转哆


使用 tf2onnx 库GitHub - onnx/tensorflow-onnx:将TensorFlow,Keras,Tensorflow.js和Tflite模型转换为ONNX


#!/bin/bashgraphdef=input_model.pbinputs=Placeholder_1:0,Placeholder_2:0outputs=output0:0,output1:0output=${graphdef}.onnxpython -m tf2onnx.convert \    --graphdef ${graphdef} \    --output ${output} \    --inputs ${inputs} \    --outputs ${outputs}\    --opset 12
复制代码


点击关注,第一时间了解华为云新鲜技术~

发布于: 刚刚阅读数: 3
用户头像

提供全面深入的云计算技术干货 2020.07.14 加入

华为云开发者社区,提供全面深入的云计算前景分析、丰富的技术干货、程序样例,分享华为云前沿资讯动态,方便开发者快速成长与发展,欢迎提问、互动,多方位了解云计算! 传送门:https://bbs.huaweicloud.com/

评论

发布
暂无评论
带你了解TensorFlow pb模型常用处理方法_人工智能_华为云开发者联盟_InfoQ写作社区