写点什么

【详细注释 + 流程讲解】基于深度学习的文本分类 TextCNN

作者:阿里云天池
  • 2024-04-02
    浙江
  • 本文字数:3237 字

    阅读完需:约 11 分钟

​前言


这篇文章用于记录阿里天池 NLP 入门赛,详细讲解了整个数据处理流程,以及如何从零构建一个模型,适合新手入门。


赛题以新闻数据为赛题数据,数据集报名后可见并可下载。赛题数据为新闻文本,并按照字符级别进行匿名处理。整合划分出 14 个候选分类类别:财经、彩票、房产、股票、家居、教育、科技、社会、时尚、时政、体育、星座、游戏、娱乐的文本数据。实质上是一个 14 分类问题。


赛题数据由以下几个部分构成:训练集 20w 条样本,测试集 A 包括 5w 条样本,测试集 B 包括 5w 条样本。


比赛地址:零基础入门 NLP - 新闻文本分类_学习赛_天池大赛-阿里云天池的赛制


数据可以通过上面的链接下载。


其中还用到了训练好的词向量文件。


词向量下载链接: 百度网盘 请输入提取码 提取码: qbpr


这篇文章中使用的模型主要是 CNN + LSTM + Attention,主要学习的是数据处理的完整流程,以及模型构建的完整流程。虽然还没有使用 Bert 等方案,不过如果看完了这篇文章,理解了整个流程之后,即使你想要使用其他模型来处理,也能更快实现。


1.为什么写篇文章


首先,这篇文章的代码全部都来源于 Datawhale 提供的开源代码,我添加了自己的笔记,帮助新手更好地理解这个代码。


1.1 Datawhale 提供的代码有哪些需要改进?

Datawhale 提供的代码里包含了数据处理,以及从 0 到 1 模型建立的完整流程。但是和前面提供的 basesline 的都不太一样,它包含了非常多数据处理的细节,模型也是由 3 个部分构成,所以看起来难度陡然上升。


其次,代码里的注释非常少,也没有讲解整个数据处理和网络的整体流程。这些对于新手来说,增加了理解的门槛。 在数据竞赛方面,我也是一个新人,花了一天的时间,仔细研究数据在一种每一个步骤的转化,对于一些难以理解的代码,在群里询问之后,也得到了 Datawhale 成员的热心解答。最终才明白了全部的代码。


1.2 我做了什么改进?

所以,为了减少对于新手的阅读难度,我添加了一些内容。


2.数据处理


2.1 数据拆分为 10 份

数据首先会经过 all_data2fold 函数,这个函数的作用是把原始的 DataFrame 数据,转换为一个 list,有 10 个元素,表示交叉验证里的 10 份,每个元素是 dict,每个 dict 包括 label 和 text。


首先根据 label 来划分数据行所在 index, 生成 label2id。


label2id 是一个 dict,key 为 label,value 是一个 list,存储的是该类对应的 index。


然后根据label2id,把每一类别的数据,划分到 10 份数据中。


​最终得到的数据fold_data是一个list,有 10 个元素,每个元素是 dict,包括 labeltext的列表:[{labels:textx}, {labels:textx}. . .]


最后,把前 9 份数据作为训练集 train_data,最后一份数据作为验证集 dev_data,并读取测试集 test_data。


2.2 定义并创建 Vacab


Vocab 的作用是:


  • 创建 词 和 index 对应的字典,这里包括 2 份字典,分别是:_id2word 和 _id2extword

  • 其中 _id2word 是从新闻得到的, 把词频小于 5 的词替换为了 UNK。对应到模型输入的 batch_inputs1

  • _id2extword 是从 word2vec.txt 中得到的,有 5976 个词。对应到模型输入的 batch_inputs2

  • 后面会有两个 embedding 层,其中 _id2word 对应的 embedding 是可学习的,_id2extword 对应的 embedding 是从文件中加载的,是固定的。

  • 创建 label 和 index 对应的字典。

  • 上面这些字典,都是基于train_data创建的。


3.模型


3.1 把文章分割为句子

上上一步得到的 3 个数据,都是一个 list,list 里的每个元素是 dict,每个 dict 包括 label 和 text。这 3 个数据会经过 get_examples 函数。 get_examples 函数里,会调用 sentence_split 函数,把每一篇文章分割成为句子。


然后,根据 vocab,把 word 转换为对应的索引,这里使用了 2 个字典,转换为 2 份索引,分别是:word_ids 和 extword_ids。最后返回的数据是一个 list,每个元素是一个 tuple: (label, 句子数量,doc)。其中 doc 又是一个 list,每个 元素是一个 tuple: (句子长度,word_ids, extword_ids)。


​在迭代训练时,调用 data_iter 函数,生成每一批的 batch_data。在 data_iter 函数里,会调用 batch_slice 函数生成每一个 batch。拿到 batch_data 后,每个数据的格式仍然是上图中所示的格式,下面,调用 batch2tensor 函数。


3.2 生成训练数据

batch2tensor 函数最后返回的数据是:(batch_inputs1, batch_inputs2, batch_masks), batch_labels。形状都是(batch_size, doc_len, sent_len)。doc_len 表示每篇新闻有几句话,sent_len 表示每句话有多少个单词。


batch_masks 在有单词的位置,值为 1,其他地方为 0,用于后面计算 Attention,把那些没有单词的位置的 attention 改为 0。


batch_inputs1, batch_inputs2, batch_masks,形状是(batch_size, doc_len, sent_len),转换为(batch_size * doc_len, sent_len)。


3.3 网络部分

下面,终于来到网络部分。模型结构图如下:


3.3.1 WordCNNEncoderWordCNNEncoder 网络结构示意图如下:​


  1. Embedding

batch_inputs1, batch_inputs2 都输入到 WordCNNEncoder。WordCNNEncoder 包括两个 embedding 层,分别对应 batch_inputs1,embedding 层是可学习的,得到 word_embed;batch_inputs2,读取的是外部训练好的词向量,因此是不可学习的,得到 extword_embed。所以会分别得到两个词向量,将 2 个词向量相加,得到最终的词向量 batch_embed,形状是(batch_size * doc_len, sent_len, 100),然后添加一个维度,变为(batch_size * doc_len, 1, sent_len, 100),对应 Pytorch 里图像的(B, C, H, W)。


  1. CNN

然后,分别定义 3 个卷积核,output channel 都是 100 维。


第一个卷积核大小为[2,100],得到的输出是(batch_size * doc_len, 100, sent_len-2+1, 1),定义一个池化层大小为[sent_len-2+1, 1],最终得到输出经过 squeeze()的形状是(batch_size * doc_len, 100)。


同理,第 2 个卷积核大小为[3,100],第 3 个卷积核大小为[4,100]。卷积+池化得到的输出形状也是(batch_size * doc_len, 100)。


最后,将这 3 个向量在第 2 个维度上做拼接,得到输出的形状是(batch_size * doc_len, 300)。


3.3.2 shape 转换

把上一步得到的数据的形状,转换为(batch_size , doc_len, 300)名字是 sent_reps。然后,对 mask 进行处理。


batch_masks 的形状是(batch_size , doc_len, 300),表示单词的 mask,经过 sent_masks = batch_masks.bool().any(2).float()得到句子的 mask。含义是:在最后一个维度,判断是否有单词,只要有 1 个单词,那么整句话的 mask 就是 1,sent_masks 的维度是:(batch_size , doc_len)。


3.3.3 SentEncoder

SentEncoder 网络结构示意图如下:


​SentEncoder 包含了 2 层的双向 LSTM,输入数据 sent_reps 的形状是(batch_size , doc_len, 300),LSTM 的 hidden_size 为 256,由于是双向的,经过 LSTM 后的数据维度是(batch_size , doc_len, 512),然后和 mask 按位置相乘,把没有单词的句子的位置改为 0,最后输出的数据 sent_hiddens,维度依然是(batch_size , doc_len, 512)。


3.3.4 Attention

接着,经过 Attention。Attention 的输入是 sent_hiddens 和 sent_masks。在 Attention 里,sent_hiddens 首先经过线性变化得到 key,维度不变,依然是(batch_size , doc_len, 512)。


然后 key 和 query 相乘,得到 outputs。query 的维度是 512,因此 output 的维度是(batch_size , doc_len),这个就是我们需要的 attention,表示分配到每个句子的权重。下一步需要对这个 attetion 做 softmax,并使用 sent_masks,把没有单词的句子的权重置为-1e32,得到 masked_attn_scores。


最后把 masked_attn_scores 和 key 相乘,得到 batch_outputs,形状是(batch_size, 512)。


3.3.5 FC

最后经过 FC 层,得到分类概率的向量。


4.完整代码+注释


4.1 数据处理导入包

import random​import numpy as npimport torchimport logginglogging.basicConfig(level=logging.INFO, format='%(asctime)-15s %(levelname)s: %(message)s')


 查看本文全部内容,欢迎访问天池技术圈官方地址:【详细注释+流程讲解】基于深度学习的文本分类 TextCNN_天池技术圈-阿里云天池



用户头像

还未添加个人签名 2024-03-12 加入

还未添加个人简介

评论

发布
暂无评论
【详细注释+流程讲解】基于深度学习的文本分类 TextCNN_机器学习_阿里云天池_InfoQ写作社区