I recently designed a deep neural network. Recording the Train loss and Valid loss with the logging module is convenient enough, but plotting the curves makes the loss trend much easier to see at a glance. Since I couldn't find a ready-made recipe for this online, here is the way I record and plot them.
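For reference, the logging-based recording mentioned above looks roughly like this. It is only a minimal sketch: the log file name, format, and the stand-in loop and loss values are my own assumptions for illustration.

import logging

logging.basicConfig(filename='train.log', level=logging.INFO,
                    format='%(asctime)s %(message)s')  # assumed file name and format

for epoch in range(3):  # stand-in loop; replace with the real training loop
    train_loss, valid_loss = 0.5, 0.6  # stand-in values; replace with the real losses
    logging.info(f'epoch {epoch + 1}: Train loss {train_loss:.4f}, Valid loss {valid_loss:.4f}')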

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import time

class loss_tracer():

    def __init__(self, outpath):

        self.train_loss_list = []
        self.test_loss_list = []
        self.outpath = outpath
        timestamp = time.time()
        self.starttime = time.strftime('%Y_%m_%d_%H_%M', time.localtime(timestamp))

    def __call__(self, trainloss, testloss, plot=False, length=6, width=3):

        self.train_loss_list.append(trainloss)
        self.test_loss_list.append(testloss)
        if plot:
            self.plot(length, width)

    def plot(self, length=6, width=3):

        plt.figure(figsize=(length, width))
        epoch_list = range(1, len(self.train_loss_list) + 1)
        plt.plot(epoch_list, self.train_loss_list, color='blue', label='Train loss')
        plt.plot(epoch_list, self.test_loss_list, color='red', label='Valid loss')
        plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
        plt.legend()
        plt.tight_layout()
        plt.savefig(f'{self.outpath}/{self.starttime}_train_loss_valid_loss.png')
        plt.close()

This approach only requires that the matplotlib library is installed. Here is how to use it:

# Create a tracer instance
lt = loss_tracer('./')  # set your own output directory for the figure

# Store the losses after each computation
lt(train_loss, valid_loss)  # do not pass in tensors (see the note after this block)
lt(train_loss, valid_loss, plot=True)  # store and then redraw the figure

# Plot from the losses stored so far
lt.plot()
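About the "do not pass in tensors" comment: if train_loss and valid_loss are PyTorch tensors, convert them to plain Python floats before storing them, as the full example further down does; otherwise the lists keep whole tensors (and potentially their computation graphs) alive in memory. For instance:

lt(train_loss.item(), valid_loss.item())  # .item() turns a scalar tensor into a plain float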

A few things to note:

  • Frequent plotting adds noticeable overhead. If an epoch does not take long, I do not recommend setting plot=True; conversely, if each epoch takes a long time, saving an updated loss figure after every epoch is a good trade-off.
  • The figure is saved as year_month_day_hour_minute_train_loss_valid_loss.png; rename it as you like if needed.
  • The figure defaults to a length of 6 and a width of 3; both can be changed through the length and width parameters, as shown in the sketch right after this list.
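As a concrete illustration of the notes above, here is a rough usage sketch. The loss values are stand-ins rather than output from a real model: the epochs are cheap, so it only stores the losses during the loop and draws a single 8 x 4 figure at the end instead of the default 6 x 3.

lt = loss_tracer('./')

for epoch in range(20):
    train_loss, valid_loss = 1.0 / (epoch + 1), 1.2 / (epoch + 1)  # stand-in loss values
    lt(train_loss, valid_loss)  # epochs are cheap here, so skip plot=True and just store

lt.plot(length=8, width=4)  # one final figure with a custom size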

A simple end-to-end example:

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
import torch
import torch.nn as nn
import time

class loss_tracer():

    def __init__(self, outpath):

        self.train_loss_list = []
        self.test_loss_list = []
        self.outpath = outpath
        timestamp = time.time()
        self.starttime = time.strftime('%Y_%m_%d_%H_%M', time.localtime(timestamp))

    def __call__(self, trainloss, testloss, plot=False, length=6, width=3):

        self.train_loss_list.append(trainloss)
        self.test_loss_list.append(testloss)
        if plot:
            self.plot(length, width)

    def plot(self, length=6, width=3):

        plt.figure(figsize=(length, width))
        epoch_list = range(1, len(self.train_loss_list) + 1)
        plt.plot(epoch_list, self.train_loss_list, color='blue', label='Train loss')
        plt.plot(epoch_list, self.test_loss_list, color='red', label='Valid loss')
        plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
        plt.legend()
        plt.tight_layout()
        plt.savefig(f'{self.outpath}/{self.starttime}_train_loss_valid_loss.png')
        plt.close()

np.random.seed(114514)
a_train = np.random.randn(400)
b_train = np.random.randn(400)
c_train = a_train*0.7 + b_train*0.6 + 0.5 + np.random.rand(400)

a_valid = np.random.randn(100)
b_valid = np.random.randn(100)
c_valid = a_valid*0.7 + b_valid*0.6 + 0.5 + np.random.rand(100)

a_train = torch.from_numpy(a_train)
b_train = torch.from_numpy(b_train)
c_train = torch.from_numpy(c_train)
a_valid = torch.from_numpy(a_valid)
b_valid = torch.from_numpy(b_valid)
c_valid = torch.from_numpy(c_valid)

train_data = torch.stack((a_train, b_train), dim=1)
valid_data = torch.stack((a_valid, b_valid), dim=1)

net = nn.Sequential(nn.Linear(2, 1))
opt = torch.optim.SGD(net.parameters(), lr=0.1)
lf = nn.MSELoss()
lt = loss_tracer('./')  # save the loss figure in the directory the script is run from

# train for 100 epochs
for epoch in range(100):

    opt.zero_grad()

    train_outputs = net(train_data.float())
    train_loss = lf(train_outputs, c_train.float().unsqueeze(1))

    train_loss.backward()
    opt.step()

    with torch.no_grad():
        valid_outputs = net(valid_data.float())
        valid_loss = lf(valid_outputs, c_valid.float().unsqueeze(1))

    lt(train_loss.item(), valid_loss.item())

lt.plot()

An example of the resulting figure:

As it turns out, when the pattern in the data is simple enough and the data are dense enough, overfitting is actually quite hard to achieve.

Finally, here is a variant of the class above that plots once every x epochs:

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import time

class loss_tracer():

    def __init__(self, outpath, plotepoch=5):

        self.train_loss_list = []
        self.test_loss_list = []
        self.plotepoch = plotepoch
        self.outpath = outpath
        timestamp = time.time()
        self.starttime = time.strftime('%Y_%m_%d_%H_%M', time.localtime(timestamp))

    def __call__(self, trainloss, testloss, plot=True, length=6, width=3):

        self.train_loss_list.append(trainloss)
        self.test_loss_list.append(testloss)

        losslist_len = len(self.train_loss_list)
        if losslist_len % self.plotepoch == 0 and plot:
            self.plot(length, width)

    def plot(self, length=6, width=3):

        plt.figure(figsize=(length, width))
        epoch_list = range(1, len(self.train_loss_list) + 1)
        plt.plot(epoch_list, self.train_loss_list, color='blue', label='Train loss')
        plt.plot(epoch_list, self.test_loss_list, color='red', label='Valid loss')
        plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
        plt.legend()
        plt.tight_layout()
        plt.savefig(f'{self.outpath}/{self.starttime}_train_loss_valid_loss.png')
        plt.close()

Usage is the same as before; if you do not need to plot when adding a new loss, simply pass plot=False when calling the instance. A short usage sketch follows.
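A rough usage sketch of this variant (the loss values are stand-ins; the figure is refreshed once every 10 epochs):

lt = loss_tracer('./', plotepoch=10)  # redraw the figure once every 10 stored losses

for epoch in range(50):
    train_loss, valid_loss = 1.0 / (epoch + 1), 1.2 / (epoch + 1)  # stand-in loss values
    lt(train_loss, valid_loss)  # plot defaults to True, so every 10th call saves an updated figure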