深度学习 - 浅谈 keras 的扩展性
1. 自定义 keras
keras 是一种深度学习的 API,能够快速实现你的实验。keras 也集成了很多预训练的模型,可以实现很多常规的任务,如图像分类。TensorFlow 2.0 之后 tensorflow 本身也变的很 keras 化。
另一方面,keras 表现出高度的模块化和封装性,所以有的人会觉得 keras 不易于扩展, 比如实现一种新的 Loss,新的网络层结构; 其实可以通过 keras 的基础模块进行快速的扩展,实现更新的算法。
本文就 keras 的扩展性,总结了对 layer,model 和 loss 的自定义。
2. 自定义 keras layers
layers 是 keras 中重要的组成部分,网络结构中每一个组成都要以 layers 来表现。keras 提供了很多常规的 layer,如 Convolution layers,pooling layers, activation layers, dense layers 等, 我们可以通过继承基础 layers 来扩展自定义的 layers。
2.1 base layer
layer 实了输入 tensor 和输出 tensor 的操作类,以下为 base layer 的 5 个方法,自定义 layer 只要重写这些方法就可以了。
init(): 定义自定义 layer 的一些属性
build(self, input_shape): 定义 layer 需要的权重 weights
call(self, *args, **kwargs):layer 具体的操作,会在调用自定义 layer 自动执行
get_config(self):layer 初始化的配置,是一个字典 dictionary。
compute_output_shape(self,input_shape):计算输出 tensor 的 shape
2.2 例子
3. 自定义 keras model
我们在定义完网络结构时,会把整个工作流放在keras.Model
, 进行compile()
, 然后通过fit()
进行训练过程。执行fit()
的时候,执行每个 batch size data 的时候,都会调用Model
中train_step(self, data)
当你需要自己控制训练过程的时候,可以重写Model
的train_step(self, data)
方法
4. 自定义 keras loss
keras 实现了交叉熵等常见的 loss,自定义 loss 对于使用 keras 来说是比较常见,实现各种魔改 loss,如 focal loss。
我们来看看 keras 源码中对 loss 实现
可以看出输入是 groud true y_true
和预测值y_pred
, 返回为计算 loss 的函数。自定义 loss 可以参照如此模式即可。
5. 总结
本文分享了 keras 的扩展功能,扩展功能其实也是实现 Keras 模块化的一种继承实现。
总结如下:
继承 Layer 实现自定义 layer, 记住
bulid()
call()
继续 Model 实现
train_step
定义训练过程,记住梯度计算tape.gradient(loss, trainable_vars)
,权重更新optimizer.apply_gradients
, 计算 evaluatecompiled_metrics.update_state(y, y_pred)
魔改 loss,记住 groud true
y_true
和预测值y_pred
输入,返回 loss function
版权声明: 本文为 InfoQ 作者【AIWeker】的原创文章。
原文链接:【http://xie.infoq.cn/article/f1ff758824ec324ff3501e828】。文章转载请联系作者。
评论