multi-task learning for text recognition with joint CTC-attention
前言
这篇文章带来一份OCR源码(PyTorch)的梳理,代码地址为:https://github.com/bityigoss/mtl-text-recognition。这份源码使用了CTC loss和Attention Loss的多任务模型,也可以说是基于CRNN和Image Caption两个任务的多任务模型。
1. MTL详解

1.1 网络概览
MTL的网络结构的后半部分如图1所示,在它之前是一个由CNN组成的特征提取网络,最终得到的Feature Map会以列为时间片为单位输入到RNN中,也就是输入到图像的。RNN之后有两个Head,一个是CTC,另外一个是Attention Decoder,他们两个共同组成网络的损失函数。如果只考虑左侧CTC的话,那么它就是一个标准的CRNN模型。
1.2 代码梳理
1.2.1 执行脚本
要梳理MTL的代码流程,我们先要知道网络的一些超参,在源码的README中,多任务模型的调用方式如下(源文件README有误):
前面四项是用来控制读取数据的超参。5-7个比较直观,第8个--Transformation是用来控制是否使用STN,第9个--FeatureExtraction是提取图像特征的网络结构,第10个--SequenceModeling是图1中‘Shared Encoder’的结构。--Prediction是预测的时候选择图1中的CTC的分支或者是Attention Decoder分支。--mtl是选择模型的训练方式,是选择一个任务进行训练还是训练多任务模型。--without_prediction是指模型加载的方式,是否需要预测模块。
除了上面列出的,在train.py或者mtl_train.py文件中还有很多可以调整的超参,例如优化方式中涉及的学习策略,学习率;数据处理方式的图像尺寸等。
作者在这里面有个错误,如果要执行多任务模型,需要执行python mtl_train.py而不是上面给出的python train.py。因为在作者的代码中,train.py调用的是model.py,而mtl_train.py则调用的是mtl_model.py。
1.2.2 网络模型
在上面提到了执行多任务训练要调用mtl_model.py文件,下面我们来梳理一下这个文件。这个文件的第一个重点在源码的第19-24行,在这里我们确认模型是否要使用STN。STN的实现定义在./modules/transformation.py文件中。
第26-36用于选择模型的特征提取网络,卷积网络的具体细节定义在./modules/feature_extraction.py文件中:
第38-46行是网络的RNN部分,作者只给出了双向LSTM的支持
第48-53行则是给出多任务模型的CTC输出和Attention输出
Attention预测的计算方式在./modules/prediction.py文件中,它采用了Image Caption式的生成模型来进行预测,核心代码在第43-66行
我们先看训练时的分支,在预测第个时间片的输出时,我们首先对Ground Truth的前个时间片的内容进行one-hot编码,之后通过attention_cell得到当前时间片的输出。遍历的时间片的总个数超参由--batch_max_length定义的,源码中的值是25,也就是说Attention Decoder分支最大支持25个字符的识别。
解码过程和训练过程略有不同,它使用的预测的结果(最后三行)作为上个时间片的输入编码,而训练的时候使用的是GroundTruth。
AttentioCell的核心代码在prediction.py文件的81-91行,它是一个基于单向LSTM的生成器,它的输入有三个,其中prev_hidden是上一个时间片的隐层状态,batch_H是双向LSTM的输出,也就是图1中的,char_onehots是GroundTruth的前个时间片(训练时)或者之前的预测结果(测试时)的one-hot编码。
1.2.3 损失函数
介绍完MTL的网络结构,我们再回到mtl_train.py看一下网络的损失函数。如果我们在训练模型的时候选择了多任务训练,那么在文件的第98-100行我们可以得到两个任务具体的损失函数,其中CTC使用的自然是CTC损失函数,而Attention Decoder使用的则是交叉熵损失函数。
在源码的第155-165行则是使用上面定义的损失韩式来分别计算两个任务的损失,并最终采用加权的方式得到最终的损失函数。权值参数通过--ctc_weight超参设置,源码中的值是0.2。
1.2.4 测试,推理
从mtl_infer.py或者mtl_test.py中我们可以看出,在测试或者推理的时候,MTL可以选择从使用CTC的输出或者是Attention Decoder的输出。
2.总结
这份代码的质量比较高,算法的思路,代码的框架,注释,命名规则,接口设置都比较规范,也许是TensorFlow的代码看多了,这份读起来思路非常流畅,也很简单。目前OCR文字识别的算法流程也大概分成两个流派,一个是基于CTC的一维文字识别,另外一个是基于Image Caption的二维或者扭曲文字的识别,这份代码基于这两个方向构建了一个多任务模型,虽然不是很具有创新性,但是在工程上还是很有借鉴价值的。
最后更新于
这有帮助吗?