Tags: OCR, STN
前言
自LeNet-5的结构被提出之后,其“卷积+池化+全连接”的结构被广泛的应用到各处,但是这并不代表其结构没有了优化空间。传统的池化方式(Max Pooling/Average Pooling)所带来卷积网络的位移不变性和旋转不变性只是局部的和固定的。而且池化并不擅长处理其它形式的仿射变换。
Spatial Transformer Network(STN)的提出动机源于对池化的改进,即与其让网络抽象的学习位移不变性和旋转不变性,不如设计一个显示的模块,让网络线性的学习这些不变性,甚至将其范围扩展到所有仿射变换乃至非放射变换。更加通俗的将,STN可以学习一种变换,这种变换可以将进行了仿射变换的目标进行矫正。这也为什么我把STN放在了OCR这一章,因为在OCR场景中,仿射变换是一种最为常见的变化情况。
基于这个动机,作者设计了Spatial Transformer(ST),ST具有显示学习仿射变换的能力,并且ST是可导的,因此可以直接整合进卷积网络中进行端到端的训练,插入ST的卷积网络叫做STN。
下面根据一份STN的keras源码:https://github.com/oarriaga/spatial_transformer_networks详解STN的算法细节。
1. ST
ST由三个模块组成:
Localisation Network:该模块学习仿射变换矩阵(附件A);
Parameterised Sampling Grid:根据Localisation Network得到仿射变换矩阵,得到输出Feature Map和输入Feature Map之间的位置映射关系;
Differentiable Image Sampling:计算输出Feature Map的每个像素点的值。
STM的结构见图1:
图1:STM的框架图
ST使用的插值方法属于“后向插值”的一种,即给定输出Feature Map上的一个点Gi=(xit,yit),我们某种变化呢反向找到其在输入Feature Map中对应的位置(xis,yis),如果(xis,yis)为整数,则输出Feature Map在(xit,yit)处的值和输入Feature Map在(xit,yit)处的值相同,否则需要通过插值的方法得到输出Feature Map在(xit,yit)处的值。
说了后向插值,当然还有一种插值方式叫做前向插值,例如我们在Mask R-CNN中介绍的插值方法。
1.1 Localisation Network
Localisation Network是一个小型的卷积网络Θ=floc(U),其输入是Feature Map (U∈RW×H×C),输出是仿射矩阵Θ 的六个值。因此输出层是一个有六个节点回归器。
θ=[θ11θ21θ12θ22θ13θ23] 下面的是源码中给出的Localisation Network的结构。
locnet = Sequential()
locnet.add(MaxPooling2D(pool_size=(2,2), input_shape=input_shape))
locnet.add(Conv2D(20, (5, 5)))
locnet.add(MaxPooling2D(pool_size=(2,2)))
locnet.add(Conv2D(20, (5, 5)))
locnet.add(Flatten())
locnet.add(Dense(50))
locnet.add(Activation('relu'))
locnet.add(Dense(6, weights=weights))
1.2 Parameterised Sampling Grid
Parameterised Sampling Grid利用Localisation Network产生的Θ进行仿射变换,即由输出Feature Map上的某一位置Gi=(xit,yit)根据变换参数θ 得到输入Feature Map的某一位置(xis,yis):
(xisyis)=Tθ(Gi)=Θxityit1=[θ11θ21θ12θ22θ13θ23]xityit1 图2展示了ST中的一次仿射变换(b)和直接映射的区别。
图2: ST的仿射变换和普通卷积的直接映射
这里需要注意两点:
1. Θ可以是一个更通用的矩阵,并不局限于仿射变换,甚至不局限于6个值;
2. 映射得到的(xis,yis)一般不是整数,因此不能(xit,yit)不能使用(xis,yis)的值,而是根据它进行插值,也就是我们下一节要讲的东西。
1.3 Differentiable Image Sampling
如果(xis,yis)为一整数,那么输出Feature Map的(xit,yit)处的值便可以从输入Feature Map上直接映射过去。然而在的1.2节我们讲到,(xis,yis)往往不是整数,这时我们需要进行插值才能确定输出其值,在这个过程叫做一次插值,或者一次采样(Sampling)。插值过程可以用下式表示:
Vic=n∑Hm∑WUnmck(xis−m;Φx)k(yis−m;Φy),where∀i∈[1,...,H′W′],∀c∈[1,...,C] 在上式中,函数f()表示插值函数,本文将以双线性插值为例进行解析,Φ为f()中的参数,Unmc为输入Feature Map上点(n,m,c)处的值,Vic便是插值后输出Feature Map的(xit,yit)处的值。
H′,W′分别为输出Feature Map的高和宽。当H′=H并且W′=W时,则ST是正常的仿射变换,当H′=H/2并且W′=W/2时, 此时ST可以起到和池化类似的降采样的功能。
以双线性插值为例,插值过程即为:
Vic=n∑Hm∑WUnmcmax(0,1−∣xis−m∣)max(0,1−∣yis−m∣) 上式可以这么理解:遍历整个输入Feature Map,如果遍历到的点(n,m)距离大于1,即∣xis−m∣>1,那么max(0,1−∣xis−m∣)=0(n处同理),即只有距离(xis,yis)最近的四个点参与计算。且距离与权重成反比,也就是距离越小,权值越大,也就是双线性插值的过程,如图3。其中(xis,yis)=(1.2,2.3), U12=1,U13=2,U22=3,U23=4,则
V(xis,yis)=0.8×0.7×1+0.2×0.7×2+0.8×0.3×3+0.2×0.3×4=1.8 图3:STN中的双线性插值示例
上式中的几个值都是可偏导的:
∂Unmc∂Vic=n∑Hm∑Wmax(0,1−∣xis−m∣)max(0,1−∣yis−m∣) ∂xis∂Vic=n∑Hm∑WUnmcmax(0,1−∣yis−m∣)⎩⎨⎧01−1if∣m−xis∣>1ifm≥xisifm<xis ∂yis∂Vic=n∑Hm∑WUnmcmax(0,1−∣xis−n∣)⎩⎨⎧01−1if∣n−yis∣>1ifn≥yisifn<yis 在对θ 求导为:
∂θ∂Vic=(∂xis∂Vic⋅∂θ∂xis∂yis∂Vic⋅∂θ∂yis) ST的可导带来的好处是其可以和整个卷积网络一起端到端的训练,能够以layer的形式直接插入到卷积网络中。
2. STN
1.3节中介绍过,将ST插入到卷积网络中便得到了STN,在插入ST的时候,需要注意以下几点:
在输入图像之后接一个ST是最常见的操作,也是最容易理解的,即自动图像矫正;
理论上讲ST是可以以任意数量插入到网络中的任意位置,ST可以起到裁剪的作用,是一种高级的Attention机制。但多个ST无疑增加了网络的深度,其带来的收益价值值得讨论;
STM虽然可以起到降采样的作用,但一般不这么使用,因为基于ST的降采样产生了对其的问题;
可以在同一个卷积网络中并行使用多个ST,但是一般ST和图像中的对象是1:1的关系,因此并不是具有非常广泛的通用性。
3. STN的应用场景
3.1 并行ST
在这个场景中,输入是两张有仿射变换的MNIST的图片,然后直接输出这两个图片的数字的和(是一个19类的分类任务,不是两个10分类任务),如图3右侧图。
图4:并行ST
具体的将,给定两张图像,初始化两个ST,将两个ST分别作用于两张图片,得到四个Feature Map,将这个四个通道的图片作为FCN的输入,预测0-18间的一个整数值。图3左边的实验结果显示了两个并行ST的效果明显强于单个ST。
在鸟类分类的任务上,作者并行使用了两个和四个ST,得到了图5的实验结果:
图5:STN用于鸟类分类
在这里STN可以理解为一种Attention机制,即不同的Feature Map注意小鸟的不同部分,例如上面一排明显可以看出红色Feature Map比较注意小鸟的头部,而绿色则比较注重小鸟的身体。
3.2 STN用于半监督学习的co-localisation
在co-localisation中,给出一组图片,这些图片中包含一些公共部分,但是这组公共部分在什么地方,长什么样子我们都不知道,我们的任务时定位这些公共部分。
STN解决这个任务的方案是STN在图片m检测到的部分与在图片n中检测到的部分的相似性应该小于STN在n中随机采样的部分,如图6。
图6:STN用于半监督学习的co-localisation
Loss使用的是Hinge损失函数:
n∑Nn=m∑Mmax(0,∣∣e(InT)−e(ImT)∣∣22−∣∣e(InT)−e(Inrand)∣∣22+α) 其中InT和ImT是STN裁剪得到的图像,Inrand是随机采样的图像,e()是编码函数,α是hinge loss的margin,即裕度,相当于净赚多少。
3.3 高维ST
STN也可扩展到三维,此时的放射变换矩阵是的3行4列的,仿射变换表示为:
xisyiszis=θ11θ21θ31θ12θ22θ32θ13θ23θ33θ14θ24θ34=xityitzit1 此时Localisation Network需要回归预测12个值,插值则是使用的三线性插值。
STN的另外一个有趣的方向是通过将图像在一个维度上展开,将3维物体压缩到二维,如图7。
图7:STN用于高维映射
附件A:仿射变换矩阵
仿射变换(Affline Transformation)是一种二维坐标到二维坐标的线性变化,其保持了二维图形的平直性(straightness,即变换后直线依旧是直线,不会变成曲线)和平行性(parallelness,平行线依旧平行,不会相交)。仿射变换可以由一系列原子变换构成,其中包括:平移(Translation),缩放(Scale),翻转(Flip),旋转(Rotation)和剪切(Crop)。仿射变换可以用下面公式表示:
x′y′1=θ11θ210θ12θ220θ13θ231xy1 图8是一些常见的仿射变换的形式及其对应的仿射变换矩阵。
图8:常见仿射变换