使用Rust编写游玩贪吃蛇的神经网络(一)

本文最后更新于 2024年8月6日 凌晨

Program In Run

这是第一篇相关博文,我们就先聊聊网络部分的实现吧!这部分的内容相对而言比较简单,我也比较好讲明白~

神经网络的结构

简单来说,神经网络就是一个控制系统,其由若干组不同形状、元素组成的矩阵,以及不同的非线性算子构成。我们说的矩阵,一般都在二维,可以用二维列表的结构保存矩阵的元素。但是在神经网络中,矩阵的维度可以上升到成千上百,目的就是从多个维度对原始的输入数据进行特征提取的操作。

非线性算子则是引入非线性。试想一下,如果没有非线性算子,若干个矩阵的乘法操作,那不是就可以等加成一个矩阵就完事了吗?在XX书中,有详细的证明过程,我就不在这里展开了。

我们总结一下,典型的神经网络,其元素就是数字。这些数字存储在数组形的数据结构里,在输入来临时,会与其进行计算。

Rust中的神经网络实现

该项目的神经网络是一个前馈神经网络,也叫多层感知机(multilayer perceptron, MLP),是典型的深度学习模型。前馈网络的目标是近似某个函数$f^{\star}$。例如,对于分类器,$y=f^{\star}(x)$将输入 x 映射到一个类别 y。前馈网络定义了一个映射 $y = f(x;\theta)$,并且学习参数 $\theta$ 的值, 使它能够得到最佳的函数近似。相关逻辑代码存放在src\nn.rs里,网络的配置则在configs.rs文件中的pub const NN_ARCH: [usize; 4] = [24, 16, 8, 4];里面。这里我贴一下作者的实现设计:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
//! A simple Feed-Forward Neural Network
//!
//! It can't do backpropagation
//! It can only be used for neuro-evolution


#[derive(Clone, Serialize, Deserialize)]
pub struct Net {
n_inputs: usize,
layers: Vec<Layer>,
}

#[derive(Clone, Serialize, Deserialize)]
struct Layer {
nodes: Vec<Node>,
}

#[derive(Clone, Serialize, Deserialize)]
struct Node {
weights: Vec<f64>,
bias: f64,
}

impl Net {
pub fn new(layer_sizes: &[usize]) -> Self {
...
}

pub fn merge(&self, other: &Net) -> Self {
...
}

pub fn predict(&self, inputs: Vec<f64>) -> Vec<Vec<f64>> {
...
}

pub fn mutate(&mut self) {
...
}

pub fn save(&self) {
...
}

pub fn load() -> Self {
...
}

// This is for visualization
pub fn get_bias(&self, layer_idx: usize) -> Vec<f64> {
...
}
}

impl Layer {
fn new(layer_size: usize, prev_layer_size: usize) -> Self {
...
}

fn merge(&self, other: &Layer) -> Self {
...
}

fn predict(&self, inputs: &Vec<f64>) -> Vec<f64> {
...
}

fn mutate(&mut self) {
...
}
}

作者设计了三种结构:节点(Nodes),层(Layer) 以及 网络(Net)。我用一张图来解释一下吧!

神经网络结构图

如图所示,整个图里的非蓝框部分的内容都称为Net,而每一个纵列则是一个Layer,每一个Layer里面的圆圈则是NodeNode则由一个向量以及一个偏置bias组成,这对应了MLP里面的运算流程。一个N_0维的向量(可看作一个N_0x1的矩阵)与网络层中的每一个节点进行相乘,并将得到的运算结果附加一个偏置bias,即:一个节点对应的就是一个1xN_1的矩阵以及一个偏置。本例的运算的过程如下:在第一层最上边的节点,每一个输入的元素都会与该节点所对应的矩阵进行相乘,得到一个N_1x1的矩阵,然后对这个矩阵的每一个值都添加一个偏置bias_1。接下来就是把第一层输出的N_1个结果当作输入,和第二层中的节点进行计算重复上面的流程直至输出结果。

接下来就是网络的输入:他的推特上是这么解释的,但我感觉好像说得不清不楚的…

作者推特上关于输入的解释

我尝试理解一下吧:前半部分的8x2,8代表的是蛇头部能看见的八个方向,后面的2代表的是两种状态食物(Food)和边界(Solid),表示距离当前头部的若干个格子中是否会有边界或者食物;后半部分的4x2,则是头部和尾部的运动方向。注意:作者的方向划分是这样子的:

输入示例与方向说明

在这个状态时,对应的输入信息如下:

对应输入状态图

此时,蛇头的左前方(TL)、左方(LF)左后方(BL),下方(BT),右下方(BR)和左下方(BL)都有障碍物,故均为Solid,Top方向正好有Food所以TP为F被激活,而R相关的部分由于若干个格子风平浪静,因此均没有被激活。头部此时的运动方向如果保持不变的话,将会继续保持向左方(L)前进,而尾部也依然保持和同一方向故也为L,因此HE-L和TA-L被激活。

说完了输入,我们来看看输出。这个网络的输出居然是蛇不能前进的方向!太神秘了,一般我们都会用端到端的方式输出蛇前进的方向。但根据我的研究,似乎确实最后输出的是不可以前进的方向。这一点需要后续在代码里看看,总之神经网络结构方面的代码就在这里结束啦!什么?你说网络实现的方法你不懂?没关系!我们会在下一篇中进行讲解~


使用Rust编写游玩贪吃蛇的神经网络(一)
https://cybercolyce.cn/2024/07/08/Rust-Snake-AI-NN-related/
作者
L4k3d22
发布于
2024年7月8日
许可协议