深度学习代码一般有一个固定的架构,所以在阅读过程中常遵循一个固定的顺序,故我在该文章中将简单总结阅读深度学习代码的经验并以MSHTrans为例子来进行阅读。

深度学习代码常见文件及代码架构

在进入github仓库后,我们常见到多个文件,哪个文件重要、哪个文件不重要、首先阅读哪个文件、哪个文件只是简单的配置文件。初学者正是因为这些问题没有搞清楚,所以在开始学习的时候常常抓瞎,进而影响学习效率。
MSHTrans仓库是一个很好的例子,因为MSHTrans的代码内只包括最基本、最重要的文件。

  1. README.md: 项目的说明文件,类似说明书,打开仓库时最先看到的文件,也是最先阅读的文件。里面包括了代码的作用、文献、使用方式。
  2. requirements.txt: 模型使用环境内的包,使用”pip install -r requirements.txt”安装依赖包
  3. main.py: 启动模型的代码,该代码一般名称为”main.py”,当然也有其他名称,例如”go.py”。
  4. common, experimental_results, networks, scripts: 剩余文件夹主要为模型相关的具体代码和实验相关的部分,这一部分并无过多通性,故需要具体问题具体分析。但主要作用是可以从文件夹名称上看出。
    阅读顺序一般如下:README.md -> main.py(启动模型的代码) -> 具体代码文件。接下来让我们先从README开始阅读:
    README中已经提及了自己的代码架构:
1
2
3
4
5
6
├── common/               # Codes for performance evaluation and data loading
│   ├── evaluation/       # Codes for performance evaluation
├── networks/             # Codes for networks
├── scripts/              # Running demos
├── main.py               # main function
├── requirements.txt      # Requirements

同时包括了模型所使用的数据:

  • Download data from Google Drive: Download Link
  • Unzip and move data to data folder (defined in parameter --data-root)
    我们下载完数据、安装完Requirements后,发现还有个Quick Start的栏目:
  • You can run the codes with scripts in ./scripts/scripts.sh:
1
2
# Example command
bash ./scripts/scripts.sh

这是启动模型的代码,给只想使用该模型而不希望深入研究具体代码的人一个简化版启动方式,该代码使用代码行启动,将工作环境设置于项目所在的文件,然后运行以上代码即可。
我们可以进入该文件(./scripts/scripts.sh)具体研究这一步做了什么:

1
2
3
4
5
python ./main.py --dataset-id SWaT --device 0  
python ./main.py --dataset-id WADI --device 0  
python ./main.py --dataset-id SMAP --device 0  
python ./main.py --dataset-id SMD --device 0  
python ./main.py --dataset-id MSL --device 0  --stride 1

该代码使用python运行了main.py的文件,并指定了dataset-id, device, stride等参数情况,dataset-id指代模型所使用的数据集id,device指定了使用的gpu(或者cpu),stride则是模型具体的参数。
于是,我们的注意力再一次集中于main.py。

主函数(main function)

main.py内即为主函数的文件,其用于指定模型的参数并启动模型和评估模型。内部存在3种主要的内容:环境变量的设置、模块的导入、参数设置以及模型代码。

  1. 环境变量的设置: os.environ使用指定环境变量,是一些最基本的参数,如”可使用的gpu”、”训练日志的记录层级”等。sys.path.append()则用于加入main.py可以使用的环境变量,即加入main.py可以访问的函数的列表,main.py为最外层文件夹,按道理无法直接使用import fft来导入fft.py,sys.path.append(./networks)则可以直接导入。
  2. 模块的导入: 导入主函数中使用的函数。
  3. 参数设置: 参数设置是在python课程中不会教授的环节,但是这在深度学习中非常重要。参数的设置一般使用argparse模块来进行设置,这让参数设置可以用cmd,也就是上面”scripts.sh”中所见的代码。我们以其中一个比较重要的参数为例子讲解parser:这表示增加一个参数–data-root,类型为str, 默认使用D:/data/MSHTrans_data作为路径,其代表着模型所使用的测试数据的路径,默认为D:/data/MSHTrans_data下
1
parser.add_argument("--data-root", type=str, default="D:/data/MSHTrans_data", help="dataset root")
  1. 模型代码:模型代码是最核心,也是最复杂的部分,接下来将从该部分入手,进行模型代码的阅读。

核心代码

核心代码纷繁复杂并没有一个统一的框架,但是大致可以按照流程进行划分:

  1. 数据预处理代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
pp = data_preprocess.preprocessor(model_root=args["model_root"])
data_dict = pp.normalize(data_dict, method=args["normalize"])
...
window_dict = data_preprocess.generate_multi_windows(
data_dict,
entity = entity,
window_size = args["window_size"],
stride = args["stride"],
)
windows = window_dict[entity]

train_windows = windows["train_windows"]
test_windows = windows["test_windows"]

train_time_series = torch.from_numpy(data_dict[entity]["train"])[: , :].float()
test_time_series = torch.from_numpy(data_dict[entity]["test"])[: , :].float()



train_loader, _, test_loader = get_dataloaders(train_windows, test_windows, train_time_series, test_time_series, batch_size=args["batch_size"])

这部分代码用于预处理读取的数据,不多加赘述。
2. 模型初始化

1
2
3
4
model = MSHTrans(
args,
device
).to(device)
  1. 模型训练
1
model.train(args, train_loader)
  1. 模型测试
1
test_anomaly_score, loss, pred = model.predict_prob(test_loader)
  1. 保存模型与结果
1
2
3
4
5
6
7
8
9
10
store_entity(
args,
entity,
train_anomaly_score,
test_anomaly_score,
windows["test_label"][ : , -1],
time_tracker=tt.get_data(),
)
...
eval_results_list.append(eval_results_single)

以上步骤为经典的深度学习实践流程,代码精简,这是因为主函数中不应该放置过多的与算法相关的代码,主函数应该是类似一个目录,告知我们每一步做了什么。而我们要做的就是按照这个目录去查询每个部分的具体代码。

模型架构

使用Ctrl+left click可以转到具体函数的源代码处。MSHTrans位于“networks/MSHTrans.py”中。
深度学习模型(以pytorch为例)的框架主要包括”__init__“以及”forward”两个主要函数,前者用于初始化模型,后者用于对于模型的正向传播,即模型数据处理的流程。
因此,阅读模型架构应从这两个主要的函数开始:

forward

模型初始化init就是其中每个小模块的初始化,forward做的就是将每个小模块串联起来。
优先阅读forward,因为更符合一般人的思考方式——按照数据流动的方向进行。
这里MSHTrans模型主要使用了编码器和解码器两部分,编码器将时间序列转化为多尺度融合logits,然后传入解码器重新转化为时间序列。这就是基于重构的异常检测方法。

1
2
3
4
5
def forward(self, x, hyper_graph_indicies, fused_hypergraph):
fused_logits = self.encoder(x, hyper_graph_indicies)
predict_logits = self.decoder(x, fused_logits, fused_hypergraph)

return predict_logits

在代码阅读过程中,我们将大量使用debug来查看数据的流动和转化过程。这会大大减少阅读时间。
下一部分将更加关注于具体的模块部分,但是模块部分就需要具体问题具体分析,所以如果对于MSHTrans模型感兴趣可以继续阅读,若仅仅想要学习代码阅读技巧,则可以跳过。


Cover image icon by Dewi Sari from Flaticon