Neural Architecture Search with Reinforecement Learning
tags: Reinforcement Learning, CNN, RNN, AutoML
前言
CNN和RNN是目前主流的CNN框架,这些网络均是由人为手动设计,然而这些设计是非常困难以及依靠经验的。作者在这篇文章中提出了使用强化学习(Reinforcement Learning)学习一个CNN(后面简称NAS-CNN)或者一个RNN cell(后面简称NAS-RNN),并通过最大化网络在验证集上的精度期望来优化网络,在CIFAR-10数据集上,NAS-CNN的错误率已经逼近当时最好的DenseNet{{"huang2017densely"|cite}},在TreeBank数据集上,NAS-RNN要优于LSTM。
1. 背景介绍
文章提出了Neural Architecture Search(NAS),算法的主要目的是使用强化学习寻找最优网络,包括一个图像分类网络的卷积部分(表示层)和RNN的一个类似于LSTM的cell。由于现在的神经网络一般采用堆叠block的方式搭建而成,这种堆叠的超参数可以通过一个序列来表示。而这种序列的表示方式正是RNN所擅长的工作。
所以,NAS会使用一个RNN构成的控制器(controller)以概率随机采样一个网络结构,接着在CIFAR-10上训练这个网络并得到其在验证集上的精度,然后在使用更新控制器的参数,如此循环执行直到模型收敛,如图1所示。
2. NAS详细介绍
2.1 NAS-CNN
首先我们考虑最简单的CNN,即只有卷积层构成。那么这种类型的网络是很容易用控制器来表示的。即将控制器分成段,每一段由若干个输出,每个输出表示CNN的一个超参数,例如Filter的高,Filter的宽,横向步长,纵向步长以及Filter的数量,如图2所示。
了解了控制器的结构以及控制器如何生成一个卷积网络,唯一剩下的也是最终要的便是如何更新控制器的参数。
控制器每生成一个网络可以看做一个action,记做,其中是要预测的超参数的数量。当模型收敛时其在验证集上的精度是。我们使用来作为强化学习的奖励信号,也就是说通过调整参数来最大化的期望,表示为:
由于是不可导的,所以我们需要一种可以更新的策略,NAS中采用的是Williams等人提出的REINFORCE rule :
上式近似等价于:
其中是每个batch中网络的数量。
上式是梯度的无偏估计,但是往往方差比较大,为了减小方差算法中使用的是下面的更新值:
基线b是以前架构精度的指数移动平均值。
上面得到的控制器的搜索空间是不包含跳跃连接(skip connection)的,所以不能产生类似于ResNet或者Inception之类的网络。NAS-CNN是通过在上面的控制器中添加注意力机制来添加跳跃连接的,如图3。
在第层,我们添加个anchor来确定是否需要在该层和之前的某一层添加跳跃连接,这个anchor是通过两层的隐节点状态和sigmoid激活函数来完成判断的,具体的讲:
其中是第层隐层节点的状态,。,和是可学习的参数,跳跃连接的添加并不会影响更新策略。
由于添加了跳跃连接,而由训练得到的参数可能会产生许多问题,例如某个层和其它所有层都没有产生连接等等,所以有几个问题我们需要注意:
如果一个层和其之前的所有层都没有跳跃连接,那么这层将作为输入层;
如果一个层和其之后的所有层都没有跳跃连接,那么这层将作为输出层,并和所有输出层拼接之后作为分类器的输入;
如果输入层拼接了多个尺寸的输入,则通过将小尺寸输入加值为0的padding的方式进行尺寸统一。
除了卷积和跳跃连接,例如池化,BN,Dropout等策略也可以通过相同的方式添加到控制器中,只不过这时候需要引入更多的策略相关参数了。
经过训练之后,在CIFAR-10上得到的卷积网络如图4所示。
从图4我们可以发现NAS-CNN和DenseNet有很多相通的地方:
都是密集连接;
Feature Map的个数都比较少;
Feature Map之间都是采用拼接的方式进行连接。
在生成NAS-CNN的实验中,使用的是CIFAR-10数据集。网络中加入了BN和跳跃连接。卷积核的高的范围是,宽的范围也是,个数的范围是。步长分为固定为1和在中两种情况。控制器使用的是含有35个隐层节点的LSTM。
2.2 NAS-RNN
在这篇文章中,作者采用强化学习的方法同样生成了RNN中类似于LSTM或者GRU的一个Cell。控制器的参数更新方法和1.2节类似,这里我们主要介绍如何使用一个RNN控制器来描述一个RNN cell。
传统RNN的的输入是和,输出是,计算方式是。LSTM的输入是,以及单元状态,输出是和,LSTM的处理可以看做一个将,和作为叶子节点的树结构,如图5所示。
和LSTM一样,NAS-RNN也需要输入一个并输出一个,并在控制器的最后两个单元中控制如何使用以及如何计算。
如图6所示,在这个树结构中有两个叶子节点和一个中间节点,这种两个叶子节点的情况简称为base2,而图4的LSTM则是base4。叶子节点的索引是0,1,中间节点的索引是2,如图6左侧部分。也就是说控制器需要预测3个block,每个block包含一个操作(加,点乘等)和一个激活函数(ReLU,sigmoid,tanh等)。在3个block之后接的是一个Cell inject,用于控制的使用,最后是一个Cell indices,确定哪些树用于计算。
详细分析一下图6:
控制器为索引为0的树预测的操作和激活函数分别是Add和tanh,意味着;
控制器为索引为1的树预测的操作和激活函数分别是ElemMult和ReLU,意味着;
控制器为Cell Indices的第二个元素的预测值为0,Cell Inject的预测值是add和ReLU,意味着值需要更新为,注意这里不需要额外的参数。
控制器为索引为2的树预测的操作和激活函数分别是ElemMult和Sigmoid,意味着,因为是最大的树的索引,所以;
控制器为Cell Indices的第一个元素的预测值是1,意思是要使用索引为1的树在使用激活函数的值,即。
上面例子是使用“base 2”的超参作为例子进行讲解的,在实际中使用的是base 8,得到图7两个RNN单元。左侧是不包含max和sin的搜索空间,右侧是包含max和sin的搜索空间(控制器并没有选择sin)。
在生成NAS-RNN的实验中,使用的是Penn TreeBank数据集。操作的范围是[add, elem_mult],激活函数的范围是[identity,tanh,sigmoid,relu]。
2. 总结
在如何使用强化学习方面,谷歌一直是领头羊,除了他们具有很多机构难以匹敌的硬件资源之外,更重要的是他们拥有扎实的技术积累。本文开创性的使用了强化学习进行模型结构的探索,提出了NAS-CNN和NAS-RNN两个架构,两个算法的共同点都是使用一个RNN作为控制器来描述生成的网络架构,并使用生成架构在验证集上的表现并结合强化学习算法来训练控制器的参数。
本文的创新性可以打满分,不止是其算法足够新颖,更重要的是他们开辟的使用强化学习来学习网络架构可能在未来几年引网络模型自动生成的方向,尤其是在硬件资源不再那么昂贵的时候。文章的探讨还比较基础,留下了大量的待开发空间为科研工作者所探索,期待未来几年出现更高效,更精确的模型的提出。
最后更新于