写点什么

当大火的文图生成模型遇见知识图谱,AI 画像趋近于真实世界

  • 2022-11-10
    浙江
  • 本文字数:6323 字

    阅读完需:约 21 分钟

当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界

导读

用户生成内容(User Generated Content,UGC)是互联网上多模态内容的重要组成部分,UGC 数据级的不断增长促进了各大多模态内容平台的繁荣。在海量多模态数据和深度学习大模型的加持下,AI 生成内容(AI Generated Content,AIGC)呈现出爆发性增长趋势。其中,文图生成(Text-to-image Generation)任务是流行的跨模态生成任务,旨在生成与给定文本对应的图像。典型的文图模型例如 OpenAI 开发的 DALL-E 和 DALL-E2。近期,业界也训练出了更大、更新的文图生成模型,例如 Google 提出的 Parti 和 Imagen,基于扩散模型的 Stable Diffusion 等。


然而,上述模型一般不能用于处理中文的需求,而且上述模型的参数量庞大,很难被开源社区的广大用户直接用来 Fine-tune 和推理。此外,文图生成模型的训练过程对于知识的理解比较缺乏,容易生成反常识内容。本次,EasyNLP 开源框架在先前推出的基于 Transformer 的文图生成模型(看这里)基础上,进一步推出了融合丰富知识图谱知识的文图生成模型 ARTIST,能在知识图谱的指引上,生成更加符合常识的图片。我们在中文文图生成评测基准 MUGE 上评测了 ARTIST 的生成效果,其生成效果名列榜单第一。我们也向开源社区免费开放了知识增强的中文文图生成模型的 Checkpoint,以及相应 Fine-tune 和推理接口。用户可以在我们开放的 Checkpoint 基础上进行少量领域相关的微调,在不消耗大量计算资源的情况下,就能一键进行各种艺术创作。


EasyNLP(https://github.com/alibaba/EasyNLP)是阿⾥云机器学习 PAI 团队基于 PyTorch 开发的易⽤且丰富的中⽂NLP 算法框架,⽀持常⽤的中⽂预训练模型和⼤模型落地技术,并且提供了从训练到部署的⼀站式 NLP 开发体验。EasyNLP 提供了简洁的接⼝供⽤户开发 NLP 模型,包括 NLP 应⽤ AppZoo 和预训练 ModelZoo,同时提供技术帮助⽤户⾼效的落地超⼤预训练模型到业务。由于跨模态理解需求的不断增加,EasyNLP 也⽀持各种跨模态模型,特别是中⽂领域的跨模态模型,推向开源社区,希望能够服务更多的 NLP 和多模态算法开发者和研 究者,也希望和社区⼀起推动 NLP /多模态技术的发展和模型落地。


本⽂简要介绍 ARTIST 的技术解读,以及如何在 EasyNLP 框架中使⽤ARTIST 模型。

ARTIST 模型详解

ARTIST 模型的构建基于 Transformer 模型 ,将文图生成任务分为两个阶段进行,第一阶段是通过 VQGAN 模型对图像进行矢量量化,即对于输入的图像,通过编码器将图像编码为定长的离散序列,解码阶段是以离散序列作为输入,输出重构图。第二阶段是将文本序列和编码后的图像序列作为输入,利用 GPT 模型学习以文本序列为条件的图像序列生成。为了增强模型先验,我们设计了一个 Word Lattice Fusion Layer,将知识图谱中的的实体知识引入模型,辅助图像中对应实体的生成,从而使得生成的图像的实体信息更加精准。下图是 ARTIST 模型的系统框图,以下从文图生成总体流程和知识注入两方面介绍本方案。



第一阶段:基于 VQGAN 的图像矢量量化

在 VQGAN 的训练阶段,我们利用数据中的图片,以图像重构为任务目标,训练一个图像词典的 codebook,其中,这一 codebook 保存每个 image token 的向量表示。实际操作中,对于一张图片,通过 CNN Encoder 编码后得到中间特征向量,再对特征向量中的每个编码位置寻找 codebook 中距离最近的表示,从而将图像转换成由 codebook 中的 imaga token 表示的离散序列。第二阶段中,GPT 模型会以文本为条件生成图像序列,该序列输入到 VQGAN Decoder,从而重构出一张图像。

第二阶段:以文本序列为输入利用 GPT 生成图像序列

为了将知识图谱中的知识融入到文图生成模型中,我们首先通过 TransE 对中文知识图谱 CN-DBpedia 进行了训练,得到了知识图谱中的实体表示。在 GPT 模型训练阶段,对于文本输入,首先识别出所有的实体,然后将已经训练好的实体表示和 token embedding 进行结合,增强实体表示。但是,由于每个文本 token 可能属于多个实体,如果将多个实体的表示全都引入模型,可能会造成知识噪声问题。所以我们设计了实体表示交互模块,通过计算每个实体表示和 token embedding 的交互,为所有实体表示加权,有选择地进行知识注入。特别地,我们计算每个实体表征对对于当前 token embedding 的重要性,通过内积进行衡量,然后将实体表示的加权平均值注入到当前 token embedding 中,计算过程如下:



得到知识注入的 token embedding 后,我们通过构建具有 layer norm 的 self-attention 网络,构建基于 Transformer 的 GPT 模型,过程如下:



在 GPT 模型的训练阶段,将文本序列和图像序列拼接作为输入,假设文本序列为 w, 生成图像的 imaga token 表示的离散序列概率如下所示:



最后,模型通过最大化图像部分的负对数似然来训练,得到模型参数的值。

ARTIST 模型效果

标准数据集评测结果

我们在多个中文数据集上评估了 ARTIST 模型的效果,这些数据集的统计数据如下所示:



在 Baseline 方面,我们考虑两种情况:zero-shot learning 和标准 fine-tuning。我们将 40 亿参数的中文 CogView 模型作为 zero-shot learner,我们也考虑两个模型规模和 ARTIST 模型规模相当的模型,分别为开源的 DALL-E 模型和 OFA 模型。实验数据如下所示:



从上可以看出,我们的模型在参数量很小的情况(202M)下也能获得较好的图文生成效果。为了衡量注入知识的有效性,我们进一步进行了相关评测,将知识模块移除,实验效果如下:



上述结果可以清楚地看出知识注入的作用。

案例分析

为了更加直接地比较不同场景下,ARTIST 和 baseline 模型生成图像质量对比,我们展示了电商商品场景和自然风光场景下各个模型生成图像的效果,如下图:


电商场景效果对比


自然风光场景效果对比


上图可以看出 ARTIST 生成图像质量的优越性。我们进一步比较我们先前公开的模型(看这里)和具有丰富知识的 ARTIST 模型的效果。在第一个示例“手工古风复原款发钗汉服配饰宫廷发簪珍珠头饰发冠”中,原始生成的结果主要突出了珍珠发冠这个物体。在 ARTIST 模型中,“古风”等词的知识注入过程使得模型生成结果会更偏向于古代中国的珍珠发簪。


第二个示例为“一颗绿色的花椰菜在生长”。由于模型在训练时对“花椰菜”物体样式掌握不够,当不包含知识注入模块时,模型根据“绿色”和“菜”的提示生成了有大片绿叶的单株植物。在 ARTIST 模型中,生成的物体更接近于形如花椰菜的椭圆形的植物。

ARTIST 模型在 MUGE 榜单的评测结果

MUGE(Multimodal Understanding and Generation Evaluation,链接)是业界首个大规模中文多模态评测基准,其中包括基于文本的图像生成任务。我们使用本次推出的 ARTIST 模型在中文 MUGE 评测榜单上验证了前述文图生成模型的效果。从下图可见,ARTIST 模型生成的图像在 FID 指标(Frechet Inception Distance,值越低表示生成图像质量越好)上超越了榜单上的其他结果。

ARTIST 模型的实现

在 EasyNLP 框架中,我们在模型层构建了 ARTIST 模型的 Backbone,其主要是 GPT,输入分别是 token id 和包含的实体的 embedding,输出是图片各个 patch 对应的离散序列。其核⼼代码如下所示:

# in easynlp/appzoo/text2image_generation/model.py
# initself.transformer = GPT_knowl(self.config)
# forwardx = inputs['image']c = inputs['text']words_emb = inputs['words_emb']
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)# one step to produce the logits_, z_indices = self.encode_to_z(x) c_indices = c
cz_indices = torch.cat((c_indices, a_indices), dim=1)
# make the predictionlogits, _ = self.transformer(cz_indices[:, :-1], words_emb, flag=True)# cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)logits = logits[:, c_indices.shape[1]-1:]
复制代码

在数据预处理过程中,我们需要获得当前样本的输入文本和实体 embedding,从而计算得到 words_emb:

# in easynlp/appzoo/text2image_generation/data.py
# preprocess word_matrixwords_mat = np.zeros([self.entity_num, self.text_len], dtype=np.int)if len(lex_id) > 0: ents = lex_id.split(' ')[:self.entity_num] pos_s = [int(x) for x in pos_s.split(' ')] pos_e = [int(x) for x in pos_e.split(' ')] ent_pos_s = pos_s[token_len:token_len+self.entity_num] ent_pos_e = pos_e[token_len:token_len+self.entity_num]
for i, ent in enumerate(ents): words_mat[i, ent_pos_s[i]:ent_pos_e[i]+1] = entencoding['words_mat'] = words_mat
# in batch_fnwords_mat = torch.LongTensor([example['words_mat'] for example in batch])words_emb = self.embed(words_mat)
复制代码

ARTIST 模型使⽤教程

以下我们简要介绍如何在 EasyNLP 框架使⽤ARTIST 模型。

安装 EasyNLP

⽤户可以直接参考GitHubhttps://github.com/alibaba/EasyNLP)上的说明安装 EasyNLP 算法框架。

数据准备

  1. 准备自己的数据,将 image 编码为 base64 形式:ARTIST 在具体领域应用需要 finetune, 需要用户准备下游任务的训练与验证数据,为 tsv 文件。这⼀⽂件包含以制表符\t 分隔的三列(idx, text, imgbase64),第一列是文本编号,第二列是文本,第三列是对应图片的 base64 编码。样例如下:

64b4109e34a0c3e7310588c00fc9e157	韩国可爱日系袜子女中筒袜春秋薄款纯棉学院风街头卡通兔子长袜潮	iVBORw0KGgoAAAAN...MAAAAASUVORK5CYII=
复制代码

下列⽂件已经完成预处理,可⽤于训练和测试:

https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_train.tsvhttps://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_val.tsvhttps://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/T2I_test.tsv
复制代码
  1. 将输入数据与 lattice、entity 位置信息拼接到一起:输出格式为以制表符\t 分隔的几列(idx, text, lex_ids, pos_s, pos_e, seq_len, [Optional] imgbase64)

# 下载entity to entity_id映射表wget wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/entity2id.txt
python examples/text2image_generation/preprocess_data_knowl.py \ --input_file ./tmp/T2I_train.tsv \ --entity_map_file ./tmp/entity2id.txt \ --output_file ./tmp/T2I_knowl_train.tsv
python examples/text2image_generation/preprocess_data_knowl.py \ --input_file ./tmp/T2I_val.tsv \ --entity_map_file ./tmp/entity2id.txt \ --output_file ./tmp/T2I_knowl_val.tsv
python examples/text2image_generation/preprocess_data_knowl.py \ --input_file ./tmp/T2I_test.tsv \ --entity_map_file ./tmp/entity2id.txt \ --output_file ./tmp/T2I_knowl_test.tsv
复制代码

ARTIST 文图生成微调和预测示例

在文图生成任务中,我们对 ARTIST 进行微调,之后用于微调后对模型进行预测。相关示例代码如下:

# 下载entity_id与entity_vector的映射表wget -P ./tmp https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/artist_text2image/entity2vec.pt
# finetunepython -m torch.distributed.launch $DISTRIBUTED_ARGS examples/text2image_generation/main_knowl.py \ --mode=train \ --worker_gpu=1 \ --tables=./tmp/T2I_knowl_train.tsv,./tmp/T2I_knowl_val.tsv \ --input_schema=idx:str:1,text:str:1,lex_id:str:1,pos_s:str:1,pos_e:str:1,token_len:str:1,imgbase64:str:1, \ --first_sequence=text \ --second_sequence=imgbase64 \ --checkpoint_dir=./tmp/artist_model_finetune \ --learning_rate=4e-5 \ --epoch_num=2 \ --random_seed=42 \ --logging_steps=100 \ --save_checkpoint_steps=200 \ --sequence_length=288 \ --micro_batch_size=8 \ --app_name=text2image_generation \ --user_defined_parameters=' pretrain_model_name_or_path=alibaba-pai/pai-artist-knowl-base-zh entity_emb_path=./tmp/entity2vec.pt size=256 text_len=32 img_len=256 img_vocab_size=16384 '
# predictpython -m torch.distributed.launch $DISTRIBUTED_ARGS examples/text2image_generation/main_knowl.py \ --mode=predict \ --worker_gpu=1 \ --tables=./tmp/T2I_knowl_test.tsv \ --input_schema=idx:str:1,text:str:1,lex_id:str:1,pos_s:str:1,pos_e:str:1,token_len:str:1, \ --first_sequence=text \ --outputs=./tmp/T2I_outputs_knowl.tsv \ --output_schema=idx,text,gen_imgbase64 \ --checkpoint_dir=./tmp/artist_model_finetune \ --sequence_length=288 \ --micro_batch_size=8 \ --app_name=text2image_generation \ --user_defined_parameters=' entity_emb_path=./tmp/entity2vec.pt size=256 text_len=32 img_len=256 img_vocab_size=16384 max_generated_num=4 '
复制代码

在阿里云机器学习平台 PAI 上使用 Transformer 实现文图生成

PAI-DSW(Data Science Workshop)是阿里云机器学习平台 PAI 开发的云上 IDE,面向不同水平的开发者,提供了交互式的编程环境(文档)。在 DSW Gallery 中,提供了各种 Notebook 示例,方便用户轻松上手 DSW,搭建各种机器学习应用。我们也在 DSW Gallery 中上架了使用 Transformer 模型进行中文文图生成的 Sample Notebook(见下图),欢迎大家体验!



未来展望

在这一期的工作中,我们在 EasyNLP 框架中扩展了基于 Transformer 的中文文图生成功能,同时开放了模型的 Checkpoint,方便开源社区用户在资源有限情况下进行少量领域相关的微调,进行各种艺术创作。在未来,我们计划在 EasyNLP 框架中推出更多相关模型,敬请期待。我们也将在 EasyNLP 框架中集成更多 SOTA 模型(特别是中文模型),来支持各种 NLP 和多模态任务。此外,阿里云机器学习 PAI 团队也在持续推进中文多模态模型的自研工作,欢迎用户持续关注我们,也欢迎加入我们的开源社区,共建中文 NLP 和多模态算法库!

Github 地址:https://github.com/alibaba/EasyNLP

Reference

  1. Chengyu Wang, Minghui Qiu, Taolin Zhang, Tingting Liu, Lei Li, Jianing Wang, Ming Wang, Jun Huang, Wei Lin. EasyNLP: A Comprehensive and Easy-to-use Toolkit for Natural Language Processing. EMNLP 2022

  2. Tingting Liu*, Chengyu Wang*, Xiangru Zhu, Lei Li, Minghui Qiu, Ming Gao, Yanghua Xiao, Jun Huang. ARTIST: A Transformer-based Chinese Text-to-Image Synthesizer Digesting Linguistic and World Knowledge. EMNLP 2022

  3. Aditya Ramesh, Mikhail Pavlov, Gabriel Goh, Scott Gray, Chelsea Voss, Alec Radford, Mark Chen, Ilya Sutskever. Zero-Shot Text-to-Image Generation. ICML 2021: 8821-8831

阿里灵杰回顾

用户头像

还未添加个人签名 2020-10-15 加入

分享阿里云计算平台的大数据和AI方向的技术创新和趋势、实战案例、经验总结。

评论

发布
暂无评论
当大火的文图生成模型遇见知识图谱,AI画像趋近于真实世界_深度学习_阿里云大数据AI技术_InfoQ写作社区