# Spatial Transform Networks

Tags: OCR, STN

## 前言

自LeNet-5的结构被提出之后，其“卷积+池化+全连接”的结构被广泛的应用到各处，但是这并不代表其结构没有了优化空间。传统的池化方式（Max Pooling/Average Pooling）所带来卷积网络的位移不变性和旋转不变性只是局部的和固定的。而且池化并不擅长处理其它形式的仿射变换。

Spatial Transformer Network（STN）的提出动机源于对池化的改进，即与其让网络抽象的学习位移不变性和旋转不变性，不如设计一个显示的模块，让网络线性的学习这些不变性，甚至将其范围扩展到所有仿射变换乃至非放射变换。更加通俗的将，STN可以学习一种变换，这种变换可以将进行了仿射变换的目标进行矫正。这也为什么我把STN放在了OCR这一章，因为在OCR场景中，仿射变换是一种最为常见的变化情况。

基于这个动机，作者设计了Spatial Transformer（ST），ST具有显示学习仿射变换的能力，并且ST是**可导**的，因此可以直接整合进卷积网络中进行端到端的训练，插入ST的卷积网络叫做STN。

下面根据一份STN的keras源码：<https://github.com/oarriaga/spatial_transformer_networks>详解STN的算法细节。

## 1. ST

ST由三个模块组成：

1. Localisation Network：该模块学习仿射变换矩阵（附件A）；
2. Parameterised Sampling Grid：根据Localisation Network得到仿射变换矩阵，得到输出Feature Map和输入Feature Map之间的位置映射关系；
3. Differentiable Image Sampling：计算输出Feature Map的每个像素点的值。

STM的结构见图1：

#### 图1：STM的框架图

![](https://4188449087-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-M9CzxVHWUyXHERFQI87%2Fsync%2Fd595e2ecc7b9149b244bba9cab636bfd68fccf66.png?generation=1591520340108759\&alt=media)

ST使用的插值方法属于“后向插值”的一种，即给定输出Feature Map上的一个点$$G\_i = (x^t\_i, y^t\_i)$$，我们某种变化呢反向找到其在输入Feature Map中对应的位置$$(x^s\_i, y^s\_i)$$，如果$$(x^s\_i, y^s\_i)$$为整数，则输出Feature Map在$$(x^t\_i, y^t\_i)$$处的值和输入Feature Map在$$(x^t\_i, y^t\_i)$$处的值相同，否则需要通过插值的方法得到输出Feature Map在$$(x^t\_i, y^t\_i)$$处的值。

说了后向插值，当然还有一种插值方式叫做前向插值，例如我们在[Mask R-CNN](https://senliuy.gitbooks.io/advanced-deep-learning/content/chapter1/mask-r-cnn.html)中介绍的插值方法。

### 1.1 Localisation Network

Localisation Network是一个小型的卷积网络$$\Theta = f\_{loc}(U)$$，其输入是Feature Map （$$U\in R^{W\times H\times C}$$），输出是仿射矩阵$$\Theta$$ 的六个值。因此输出层是一个有六个节点回归器。

$$
\theta =
\left\[
\begin{matrix}
\theta\_{11} & \theta\_{12} & \theta\_{13} \\
\theta\_{21} & \theta\_{22} & \theta\_{23}
\end{matrix}
\right]
$$

下面的是源码中给出的Localisation Network的结构。

```python
locnet = Sequential()
locnet.add(MaxPooling2D(pool_size=(2,2), input_shape=input_shape))
locnet.add(Conv2D(20, (5, 5)))
locnet.add(MaxPooling2D(pool_size=(2,2)))
locnet.add(Conv2D(20, (5, 5)))

locnet.add(Flatten())
locnet.add(Dense(50))
locnet.add(Activation('relu'))
locnet.add(Dense(6, weights=weights))
```

### 1.2 Parameterised Sampling Grid

Parameterised Sampling Grid利用Localisation Network产生的$$\Theta$$进行仿射变换，即由输出Feature Map上的某一位置$$G\_i = (x^t\_i, y^t\_i)$$根据变换参数$$\theta$$ 得到输入Feature Map的某一位置$$(x^s\_i, y^s\_i)$$：

$$
\left(\begin{matrix}x\_i^s \y\_i^s\end{matrix} \right)
\= \mathcal{T}*\theta(G\_i)
\= \Theta\left(\begin{matrix}x\_i^t\y\_i^t\1\end{matrix}\right)
\= \left\[\begin{matrix}\theta*{11} & \theta\_{12} & \theta\_{13} \\
\theta\_{21} & \theta\_{22} & \theta\_{23}\end{matrix}\right]
\left(\begin{matrix}x\_i^t\y\_i^t\1\end{matrix}\right)
$$

图2展示了ST中的一次仿射变换（b）和直接映射的区别。

#### 图2: ST的仿射变换和普通卷积的直接映射

![](https://4188449087-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-M9CzxVHWUyXHERFQI87%2Fsync%2F5ef9c3bcb36f756a9dc9949b878f1d8dde7a96a0.png?generation=1591520339581566\&alt=media)

这里需要注意两点：\
1\. $$\Theta$$可以是一个更通用的矩阵，并不局限于仿射变换，甚至不局限于6个值；\
2\. 映射得到的$$(x^s\_i, y^s\_i)$$一般不是整数，因此不能$$(x^t\_i, y^t\_i)$$不能使用$$(x^s\_i, y^s\_i)$$的值，而是根据它进行插值，也就是我们下一节要讲的东西。

### 1.3 Differentiable Image Sampling

如果$$(x^s\_i, y^s\_i)$$为一整数，那么输出Feature Map的$$(x^t\_i, y^t\_i)$$处的值便可以从输入Feature Map上直接映射过去。然而在的1.2节我们讲到，$$(x^s\_i, y^s\_i)$$往往不是整数，这时我们需要进行插值才能确定输出其值，在这个过程叫做一次插值，或者一次采样（Sampling）。插值过程可以用下式表示：

$$
V\_{i}^c = \sum^H\_n \sum^W\_m U^c\_{nm} k(x\_i^s-m;\Phi\_x) k(y\_i^s -m; \Phi\_y)
,\quad where\quad \forall i\in\[1,...,H'W'],\forall c\in\[1,...,C]
$$

在上式中，函数$$f()$$表示插值函数，本文将以双线性插值为例进行解析，$$\Phi$$为$$f()$$中的参数，$$U^c\_{nm}$$为输入Feature Map上点$$(n, m, c)$$处的值，$$V\_i^c$$便是插值后输出Feature Map的$$(x^t\_i, y^t\_i)$$处的值。

$$H',W'$$分别为输出Feature Map的高和宽。当$$H'=H$$并且$$W'=W$$时，则ST是正常的仿射变换，当$$H'=H/2$$并且$$W'=W/2$$时, 此时ST可以起到和池化类似的降采样的功能。

以双线性插值为例，插值过程即为：

$$
V\_{i}^c = \sum^H\_n \sum^W\_m U^c\_{nm} max(0, 1 - |x\_i^s-m|) max(0,1-|y\_i^s -m|)
$$

上式可以这么理解：遍历整个输入Feature Map，如果遍历到的点$$(n,m)$$距离大于1，即$$|x\_i^s-m|>1$$，那么$$max(0, 1 - |x\_i^s-m|)=0$$（n处同理），即只有距离$$(x^s\_i, y^s\_i)$$最近的四个点参与计算。且距离与权重成反比，也就是距离越小，权值越大，也就是双线性插值的过程，如图3。其中$$(x^s\_i, y^s\_i)=(1.2, 2.3)$$, $$U\_{12} = 1, U\_{13} = 2, U\_{22} = 3, U\_{23}=4$$，则

$$
V\_{(x\_i^s, y\_i^s) } =
0.8 \times 0.7 \times 1

* 0.2 \times 0.7 \times 2
* 0.8 \times 0.3 \times 3
* 0.2 \times 0.3 \times 4 = 1.8
  $$

#### 图3：STN中的双线性插值示例

![](https://4188449087-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-M9CzxVHWUyXHERFQI87%2Fsync%2F87902d76f6660ce342094c3cc51109034c025e7b.png?generation=1591520339639999\&alt=media)

上式中的几个值都是可偏导的:

$$
\frac{\partial V\_i^c}{\partial U\_{nm}^c} = \sum^H\_n \sum^W\_m max(0, 1 - |x\_i^s-m|) max(0,1-|y\_i^s -m|)
$$

$$
\frac{\partial V\_i^c}{\partial x\_{i}^s} = \sum^H\_n \sum^W\_m U^c\_{nm} max(0,1-|y\_i^s -m|) \left{
\begin{array}{}
0 & \text{if};|m-x\_i^s|>1\\
1 & \text{if};m\geq x\_i^s\\
-1 & \text{if};m< x\_i^s
\end{array}
\right.
$$

$$
\frac{\partial V\_i^c}{\partial y\_{i}^s} = \sum^H\_n \sum^W\_m U^c\_{nm} max(0,1-|x\_i^s -n|) \left{
\begin{array}{}
0 & \text{if};|n-y\_i^s|>1\\
1 & \text{if};n\geq y\_i^s\\
-1 & \text{if};n< y\_i^s
\end{array}
\right.
$$

在对$$\theta$$ 求导为:

$$
\frac{\partial V\_i^c}{\partial \theta} =
\left(
\begin{matrix}
\frac{\partial V\_i^c}{\partial x\_i^s}\cdot\frac{\partial x\_i^s}{\partial \theta} \\
\frac{\partial V\_i^c}{\partial y\_i^s}\cdot\frac{\partial y\_i^s}{\partial \theta}
\end{matrix}
\right)
$$

ST的可导带来的好处是其可以和整个卷积网络一起端到端的训练，能够以layer的形式直接插入到卷积网络中。

## 2. STN

1.3节中介绍过，将ST插入到卷积网络中便得到了STN，在插入ST的时候，需要注意以下几点：

1. 在输入图像之后接一个ST是最常见的操作，也是最容易理解的，即自动图像矫正；
2. 理论上讲ST是可以以任意数量插入到网络中的任意位置，ST可以起到裁剪的作用，是一种高级的Attention机制。但多个ST无疑增加了网络的深度，其带来的收益价值值得讨论；
3. STM虽然可以起到降采样的作用，但一般不这么使用，因为基于ST的降采样产生了对其的问题；
4. 可以在同一个卷积网络中并行使用多个ST，但是一般ST和图像中的对象是$$1:1$$的关系，因此并不是具有非常广泛的通用性。

## 3. STN的应用场景

### 3.1 并行ST

在这个场景中，输入是两张有仿射变换的MNIST的图片，然后直接输出这两个图片的数字的和（是一个19类的分类任务，不是两个10分类任务），如图3右侧图。

#### 图4：并行ST

![](https://4188449087-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-M9CzxVHWUyXHERFQI87%2Fsync%2F3fc2ab03d57d06f5762688456f8f23245b1e0c7f.png?generation=1591520339954330\&alt=media)

具体的将，给定两张图像，初始化两个ST，将两个ST分别作用于两张图片，得到四个Feature Map，将这个四个通道的图片作为FCN的输入，预测0-18间的一个整数值。图3左边的实验结果显示了两个并行ST的效果明显强于单个ST。

在鸟类分类的任务上，作者并行使用了两个和四个ST，得到了图5的实验结果：

#### 图5：STN用于鸟类分类

![](https://4188449087-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-M9CzxVHWUyXHERFQI87%2Fsync%2F87db9d74e54e6ead9eee8826a5b82b7d98eccd2d.png?generation=1591520339418957\&alt=media)

在这里STN可以理解为一种Attention机制，即不同的Feature Map注意小鸟的不同部分，例如上面一排明显可以看出红色Feature Map比较注意小鸟的头部，而绿色则比较注重小鸟的身体。

### 3.2 STN用于半监督学习的co-localisation

在co-localisation中，给出一组图片，这些图片中包含一些公共部分，但是这组公共部分在什么地方，长什么样子我们都不知道，我们的任务时定位这些公共部分。

STN解决这个任务的方案是STN在图片m检测到的部分与在图片n中检测到的部分的相似性应该小于STN在n中随机采样的部分，如图6。

#### 图6：STN用于半监督学习的co-localisation

![](https://4188449087-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-M9CzxVHWUyXHERFQI87%2Fsync%2F44b0c7e64c35fe6080e66a046977cd57a46ed91b.png?generation=1591520339822910\&alt=media)

Loss使用的是Hinge损失函数:

$$
\sum\_n^N \sum\_{n\neq m}^M \text{max}(0, ||e(I\_n^{\mathcal{T}})-e(I\_m^{\mathcal{T}})||^2\_2 - ||e(I\_n^{\mathcal{T}})-e(I\_n^{rand})||^2\_2 +\alpha)
$$

其中$$I\_n^\mathcal{T}$$和$$I\_m^\mathcal{T}$$是STN裁剪得到的图像，$$I\_n^{rand}$$是随机采样的图像，$$e()$$是编码函数，$$\alpha$$是hinge loss的margin，即裕度，相当于净赚多少。

### 3.3 高维ST

STN也可扩展到三维，此时的放射变换矩阵是的3行4列的，仿射变换表示为:

$$
\left(
\begin{matrix}
x\_i^s \\
y\_i^s \\
z\_i^s
\end{matrix}
\right) =
\left\[
\begin{matrix}
\theta\_{11} & \theta\_{12} & \theta\_{13} & \theta\_{14} \\
\theta\_{21} & \theta\_{22} & \theta\_{23} & \theta\_{24} \\
\theta\_{31} & \theta\_{32} & \theta\_{33} & \theta\_{34}
\end{matrix}
\right] =
\left(
\begin{matrix}
x\_i^t \\
y\_i^t \\
z\_i^t \\
1
\end{matrix}
\right)
$$

此时Localisation Network需要回归预测12个值，插值则是使用的三线性插值。

STN的另外一个有趣的方向是通过将图像在一个维度上展开，将3维物体压缩到二维，如图7。

#### 图7：STN用于高维映射

![](https://4188449087-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-M9CzxVHWUyXHERFQI87%2Fsync%2Fc323ed2fc713d6040a925cdd8530447e6632a4f3.png?generation=1591520339951524\&alt=media)

## 附件A：仿射变换矩阵

仿射变换(Affline Transformation)是一种二维坐标到二维坐标的线性变化，其保持了二维图形的平直性（straightness，即变换后直线依旧是直线，不会变成曲线）和平行性（parallelness，平行线依旧平行，不会相交）。仿射变换可以由一系列原子变换构成，其中包括：平移（Translation），缩放（Scale），翻转（Flip），旋转（Rotation）和剪切（Crop）。仿射变换可以用下面公式表示：

$$
\left\[
\begin{matrix}
x'\\
y'\\
1
\end{matrix}
\right] =
\left\[
\begin{matrix}
\theta\_{11} & \theta\_{12} & \theta\_{13}\\
\theta\_{21} & \theta\_{22} & \theta\_{23}\\
0 & 0 & 1
\end{matrix}
\right]
\left\[
\begin{matrix}
x\\
y\\
1
\end{matrix}
\right]
$$

图8是一些常见的仿射变换的形式及其对应的仿射变换矩阵。

#### 图8：常见仿射变换

![](https://4188449087-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-M9CzxVHWUyXHERFQI87%2Fsync%2F71f4099da3372cbca55d9c06bd91e9dcf93497f1.png?generation=1591520339681815\&alt=media)


---

# Agent Instructions: Querying This Documentation

If you need additional information that is not directly available in this page, you can query the documentation dynamically by asking a question.

Perform an HTTP GET request on the current page URL with the `ask` query parameter:

```
GET https://senliuy.gitbook.io/advanced-deep-learning/di-wu-zhang-ff1a-guang-xue-zi-fu-shi-bie/wen-zi-shi-bie/spatial-transform-networks.md?ask=<question>
```

The question should be specific, self-contained, and written in natural language.
The response will contain a direct answer to the question and relevant excerpts and sources from the documentation.

Use this mechanism when the answer is not explicitly present in the current page, you need clarification or additional context, or you want to retrieve related documentation sections.
