Penzai:来自 Deepmind 的 JAX 模型构建和可视化工具库
前言
JAX 是继 TensorFlow 之后,谷歌近年推出的一个较新的深度学习框架,具有计算速度快、支持大规模 GPU 集群等优点。JAX 的生态近年来不断丰富,在 JAX 工具库方面已经包括有 Flax、Equinox、Keras 等不少产品。
Penzai(“盆栽”)[1][2]是近期由来自谷歌 DeepMind 的作者推出的一款新的 JAX 工具库。Penzai 可以用来对已创建的 JAX 模型进行可视化展示、反向工程、消融分析和改进,也能用来创建新的 JAX 模型。
由 Penzai 创建的模型以及模型的可视化展示是什么样子的呢?接下来【算 AI】小编就通过示例来简单介绍一下。
示例一:模型的创建及可视化展示
【算 AI】小编首先使用 Penzai 创建了一个简单的 MLP(多层感知器),然后使用 Penzai 在 Jupyter 中输出了该 MLP 模型的结构。该模型结构的输出如下。
点击输出结果中的右箭头,可以展开更多的内容,例如参数、所包含的层等,如下图所示。
更多的功能和示例
Penzai 的官方文档[2]中介绍了 Penzai 的更多功能和示例,例如:
如何基于 Penzai 和 LoRA 修改已有的模型;
如何基于 Penzai 从零开始构建 Gemma 7B 模型;
Penzai 自带的神经网络库,类似于 Flax、Haiku、Keras、Equinox 等包含的神经网络库,用于搭建、编辑神经网络模型;
更多的交互式的、彩色的模型和数据可视化功能,等等。
以下再通过示例简单介绍一下 Penzai 对于多维数组的可视化展示。
示例二:N 维数组的可视化展示
Penzai 的可视化功能可以用来展示任意维度的数组。例如,下图是一个二维数组的可视化展示:
以下是一个四维数组的可视化展示:
在上图中,3 大行和 4 大列分别用来表示第 1 和第 2 维度,每个 5 乘 6 的小方格以及其中的颜色用来表示第 3 和第 4 维度、以及该四维数组中每个元素的值。
上述的四维数组也可以这样来展示:
在上图中,3 个横向间距较宽的分组表示第 1 维度,每 4 个横向间距较小的列表示第 2 维度,每个 5 乘 6 的小方格以及其中的颜色依然表示第 3 和第 4 维度、以及该四维数组中每个元素的值。
Penzai 的安装
安装 Penzai 的过程比较简单。首先安装 JAX,同时需要确保 Python 的版本至少是 3.10,然后执行 pip install penzai 就可以了。
官方的 Getting Started 文档中的示例代码遗漏了一行 import jax 命令,运行前自己加上就可以。
其它信息
Penzai 目前还是 0.1.x 版本,未来在功能、接口等方面可能会有变化。另外,尽管 Penzai 由谷歌员工创建,但 Penzai 并非谷歌的官方产品。
Penzai 的授权协议采用的是 Apache 2.0。
参考资料
版权声明: 本文为 InfoQ 作者【算AI】的原创文章。
原文链接:【http://xie.infoq.cn/article/c2b8ee1b66616da106f41d1d4】。文章转载请联系作者。
评论