使用Rust编写游玩贪吃蛇的神经网络(一)
本文最后更新于 2024年8月6日 凌晨
这是第一篇相关博文,我们就先聊聊网络部分的实现吧!这部分的内容相对而言比较简单,我也比较好讲明白~
神经网络的结构
简单来说,神经网络就是一个控制系统,其由若干组不同形状、元素组成的矩阵,以及不同的非线性算子构成。我们说的矩阵,一般都在二维,可以用二维列表的结构保存矩阵的元素。但是在神经网络中,矩阵的维度可以上升到成千上百,目的就是从多个维度对原始的输入数据进行特征提取的操作。
非线性算子则是引入非线性。试想一下,如果没有非线性算子,若干个矩阵的乘法操作,那不是就可以等加成一个矩阵就完事了吗?在XX书中,有详细的证明过程,我就不在这里展开了。
我们总结一下,典型的神经网络,其元素就是数字。这些数字存储在数组形的数据结构里,在输入来临时,会与其进行计算。
Rust中的神经网络实现
该项目的神经网络是一个前馈神经网络,也叫多层感知机(multilayer perceptron, MLP),是典型的深度学习模型。前馈网络的目标是近似某个函数$f^{\star}$。例如,对于分类器,$y=f^{\star}(x)$将输入 x 映射到一个类别 y。前馈网络定义了一个映射 $y = f(x;\theta)$,并且学习参数 $\theta$ 的值, 使它能够得到最佳的函数近似。相关逻辑代码存放在src\nn.rs
里,网络的配置则在configs.rs
文件中的pub const NN_ARCH: [usize; 4] = [24, 16, 8, 4];
里面。这里我贴一下作者的实现设计:
1 |
|
作者设计了三种结构:节点(Nodes),层(Layer) 以及 网络(Net)。我用一张图来解释一下吧!
如图所示,整个图里的非蓝框部分的内容都称为Net
,而每一个纵列则是一个Layer
,每一个Layer
里面的圆圈则是Node
。Node
则由一个向量以及一个偏置bias
组成,这对应了MLP里面的运算流程。一个N_0
维的向量(可看作一个N_0x1
的矩阵)与网络层中的每一个节点进行相乘,并将得到的运算结果附加一个偏置bias
,即:一个节点对应的就是一个1xN_1
的矩阵以及一个偏置。本例的运算的过程如下:在第一层最上边的节点,每一个输入的元素都会与该节点所对应的矩阵进行相乘,得到一个N_1x1
的矩阵,然后对这个矩阵的每一个值都添加一个偏置bias_1
。接下来就是把第一层输出的N_1
个结果当作输入,和第二层中的节点进行计算重复上面的流程直至输出结果。
接下来就是网络的输入:他的推特上是这么解释的,但我感觉好像说得不清不楚的…
我尝试理解一下吧:前半部分的8x2,8代表的是蛇头部能看见的八个方向,后面的2代表的是两种状态食物(Food)和边界(Solid),表示距离当前头部的若干个格子中是否会有边界或者食物;后半部分的4x2,则是头部和尾部的运动方向。注意:作者的方向划分是这样子的:
在这个状态时,对应的输入信息如下:
此时,蛇头的左前方(TL)、左方(LF)左后方(BL),下方(BT),右下方(BR)和左下方(BL)都有障碍物,故均为Solid,Top方向正好有Food所以TP为F被激活,而R相关的部分由于若干个格子风平浪静,因此均没有被激活。头部此时的运动方向如果保持不变的话,将会继续保持向左方(L)前进,而尾部也依然保持和同一方向故也为L,因此HE-L和TA-L被激活。
说完了输入,我们来看看输出。这个网络的输出居然是蛇不能前进的方向!太神秘了,一般我们都会用端到端的方式输出蛇前进的方向。但根据我的研究,似乎确实最后输出的是不可以前进的方向。这一点需要后续在代码里看看,总之神经网络结构方面的代码就在这里结束啦!什么?你说网络实现的方法你不懂?没关系!我们会在下一篇中进行讲解~