最近设计了一个深度神经网络,虽然使用 `logging` 模块记录 Train loss 和 Valid loss 也算方便,但作图观察 loss 的变动趋势还是要更加直观的。考虑到网上貌似没有这方面的方法,这里放一下自己的记录方式。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
| import matplotlib.pyplot as plt from matplotlib.ticker import MaxNLocator import time
class loss_tracer():
    """Collect per-epoch train/validation losses and plot them to a PNG file.

    Call the instance once per epoch with the two loss values; call
    ``plot()`` (or pass ``plot=True``) to write the figure.
    """

    def __init__(self, outpath):
        """Create an empty tracer.

        Args:
            outpath: Directory in which the figure is saved.
        """
        self.train_loss_list = []
        self.test_loss_list = []
        self.outpath = outpath
        # The timestamp is taken once, at construction, so every later plot()
        # call writes to (and overwrites) the same file name.
        timestamp = time.time()
        self.starttime = time.strftime('%Y_%m_%d_%H_%M', time.localtime(timestamp))

    def __call__(self, trainloss, testloss, plot=False, length=6, width=3):
        """Record the losses of one epoch and optionally redraw the figure.

        Args:
            trainloss: Training loss of the epoch.
            testloss: Validation loss of the epoch.
            plot: When true, call ``plot(length, width)`` after recording.
            length: Figure length passed on to ``plot``.
            width: Figure width passed on to ``plot``.
        """
        self.train_loss_list.append(trainloss)
        self.test_loss_list.append(testloss)
        if plot:
            self.plot(length, width)

    def plot(self, length=6, width=3):
        """Draw every recorded loss and save the figure as a PNG.

        The file is ``<outpath>/<starttime>_train_loss_valid_loss.png``.

        Args:
            length: First component of the matplotlib ``figsize``.
            width: Second component of the matplotlib ``figsize``.
        """
        # Local import: the surrounding file only imports matplotlib and time.
        import os

        # Create the target directory up front so savefig does not fail on a
        # path that does not exist yet.
        if self.outpath:
            os.makedirs(self.outpath, exist_ok=True)
        target = os.path.join(self.outpath, f'{self.starttime}_train_loss_valid_loss.png')

        fig = plt.figure(figsize=(length, width))
        try:
            epoch_list = range(1, len(self.train_loss_list) + 1)
            plt.plot(epoch_list, self.train_loss_list, color='blue', label='Train loss')
            plt.plot(epoch_list, self.test_loss_list, color='red', label='Valid loss')
            # Epochs are whole numbers, so keep the x ticks on integers.
            plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
            plt.legend()
            plt.tight_layout()
            plt.savefig(target)
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close(fig)
|
该方法要求已经安装 `matplotlib` 库,以下是一些具体的操作方式:
1 2 3 4 5 6 7 8 9
# Create a tracer that saves its figure into the current directory.
lt = loss_tracer('./')

# Record the losses of one epoch without drawing anything.
lt(train_loss, valid_loss)
# Record the losses of one epoch and redraw/save the figure right away.
lt(train_loss, valid_loss, plot=True)

# Draw and save the figure on demand from everything recorded so far.
lt.plot()
|
需要注意的一些地方:
- 频繁地进行绘制会导致性能开销过大,如果一个 epoch 用时并不长,不建议设置 `plot=True`。反之,如果一个 epoch 需要的时间较长,那么在每个 epoch 后储存新的 loss 图也是很好的选择。
- 该图的名称为 `年月日时分_train_loss_valid_loss.png`,如果有需求可自行修改。
- 该图默认长度为 `6`,宽度为 `3`,有需求可通过参数 `length` 和 `width` 自行修改。
简单的实战示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
| import matplotlib.pyplot as plt from matplotlib.ticker import MaxNLocator import numpy as np import torch import torch.nn as nn import time
class loss_tracer():
    """Collect per-epoch train/validation losses and plot them to a PNG file.

    Call the instance once per epoch with the two loss values; call
    ``plot()`` (or pass ``plot=True``) to write the figure.
    """

    def __init__(self, outpath):
        """Create an empty tracer.

        Args:
            outpath: Directory in which the figure is saved.
        """
        self.train_loss_list = []
        self.test_loss_list = []
        self.outpath = outpath
        # The timestamp is taken once, at construction, so every later plot()
        # call writes to (and overwrites) the same file name.
        timestamp = time.time()
        self.starttime = time.strftime('%Y_%m_%d_%H_%M', time.localtime(timestamp))

    def __call__(self, trainloss, testloss, plot=False, length=6, width=3):
        """Record the losses of one epoch and optionally redraw the figure.

        Args:
            trainloss: Training loss of the epoch.
            testloss: Validation loss of the epoch.
            plot: When true, call ``plot(length, width)`` after recording.
            length: Figure length passed on to ``plot``.
            width: Figure width passed on to ``plot``.
        """
        self.train_loss_list.append(trainloss)
        self.test_loss_list.append(testloss)
        if plot:
            self.plot(length, width)

    def plot(self, length=6, width=3):
        """Draw every recorded loss and save the figure as a PNG.

        The file is ``<outpath>/<starttime>_train_loss_valid_loss.png``.

        Args:
            length: First component of the matplotlib ``figsize``.
            width: Second component of the matplotlib ``figsize``.
        """
        # Local import: the surrounding file does not import os at top level.
        import os

        # Create the target directory up front so savefig does not fail on a
        # path that does not exist yet.
        if self.outpath:
            os.makedirs(self.outpath, exist_ok=True)
        target = os.path.join(self.outpath, f'{self.starttime}_train_loss_valid_loss.png')

        fig = plt.figure(figsize=(length, width))
        try:
            epoch_list = range(1, len(self.train_loss_list) + 1)
            plt.plot(epoch_list, self.train_loss_list, color='blue', label='Train loss')
            plt.plot(epoch_list, self.test_loss_list, color='red', label='Valid loss')
            # Epochs are whole numbers, so keep the x ticks on integers.
            plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
            plt.legend()
            plt.tight_layout()
            plt.savefig(target)
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close(fig)
# Synthetic regression data: c = 0.7 * a + 0.6 * b + 0.5 + uniform noise in [0, 1).
np.random.seed(114514)
# NumPy's seed does not cover torch, so seed torch as well; otherwise the
# random initialisation of the linear layer differs from run to run.
torch.manual_seed(114514)

a_train = np.random.randn(400)
b_train = np.random.randn(400)
c_train = a_train*0.7 + b_train*0.6 + 0.5 + np.random.rand(400)

a_valid = np.random.randn(100)
b_valid = np.random.randn(100)
c_valid = a_valid*0.7 + b_valid*0.6 + 0.5 + np.random.rand(100)

a_train = torch.from_numpy(a_train)
b_train = torch.from_numpy(b_train)
c_train = torch.from_numpy(c_train)
a_valid = torch.from_numpy(a_valid)
b_valid = torch.from_numpy(b_valid)
c_valid = torch.from_numpy(c_valid)

# Each row is one sample made of the two input features (a, b).
train_data = torch.stack((a_train, b_train), dim=1)
valid_data = torch.stack((a_valid, b_valid), dim=1)

# Cast and reshape once, outside the loop: these tensors never change between
# epochs. Targets get a trailing dimension to match the (N, 1) network output.
train_inputs = train_data.float()
train_targets = c_train.float().unsqueeze(1)
valid_inputs = valid_data.float()
valid_targets = c_valid.float().unsqueeze(1)

net = nn.Sequential(nn.Linear(2, 1))
opt = torch.optim.SGD(net.parameters(), lr=0.1)
lf = nn.MSELoss()
lt = loss_tracer('./')

for epoch in range(100):
    # One full-batch gradient step on the training set.
    opt.zero_grad()
    train_outputs = net(train_inputs)
    train_loss = lf(train_outputs, train_targets)
    train_loss.backward()
    opt.step()

    # Validation needs no gradients.
    with torch.no_grad():
        valid_outputs = net(valid_inputs)
        valid_loss = lf(valid_outputs, valid_targets)

    # Record plain Python floats; the figure is drawn once after training.
    lt(train_loss.item(), valid_loss.item())

lt.plot()
|
最后生成的图片实例:
事实证明,只要数据的模式足够简单且密集,想要过拟合其实也挺困难的。
最后提供一种上述函数的变体,可以实现每 x 个 epoch 绘图 1 次的功能:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
| import matplotlib.pyplot as plt from matplotlib.ticker import MaxNLocator import time
class loss_tracer():
    """Collect per-epoch losses and redraw the loss figure every few epochs.

    Call the instance once per epoch with the two loss values. The figure is
    redrawn whenever the number of recorded epochs is a multiple of
    ``plotepoch`` (unless ``plot=False`` is passed), or whenever ``plot()``
    is called directly.
    """

    def __init__(self, outpath, plotepoch=5):
        """Create an empty tracer.

        Args:
            outpath: Directory in which the figure is saved.
            plotepoch: Redraw the figure every ``plotepoch`` recorded epochs.

        Raises:
            ValueError: If ``plotepoch`` is zero.
        """
        # Fail here rather than with a ZeroDivisionError on the first recorded
        # epoch, in the middle of training.
        if plotepoch == 0:
            raise ValueError('plotepoch must not be zero')
        self.train_loss_list = []
        self.test_loss_list = []
        self.plotepoch = plotepoch
        self.outpath = outpath
        # The timestamp is taken once, at construction, so every later plot()
        # call writes to (and overwrites) the same file name.
        timestamp = time.time()
        self.starttime = time.strftime('%Y_%m_%d_%H_%M', time.localtime(timestamp))

    def __call__(self, trainloss, testloss, plot=True, length=6, width=3):
        """Record the losses of one epoch and redraw the figure when due.

        Args:
            trainloss: Training loss of the epoch.
            testloss: Validation loss of the epoch.
            plot: When false, never redraw from this call.
            length: Figure length passed on to ``plot``.
            width: Figure width passed on to ``plot``.
        """
        self.train_loss_list.append(trainloss)
        self.test_loss_list.append(testloss)
        losslist_len = len(self.train_loss_list)
        # Check the flag first so plot=False never evaluates the modulo.
        if plot and losslist_len % self.plotepoch == 0:
            self.plot(length, width)

    def plot(self, length=6, width=3):
        """Draw every recorded loss and save the figure as a PNG.

        The file is ``<outpath>/<starttime>_train_loss_valid_loss.png``.

        Args:
            length: First component of the matplotlib ``figsize``.
            width: Second component of the matplotlib ``figsize``.
        """
        # Local import: the surrounding file only imports matplotlib and time.
        import os

        # Create the target directory up front so savefig does not fail on a
        # path that does not exist yet.
        if self.outpath:
            os.makedirs(self.outpath, exist_ok=True)
        target = os.path.join(self.outpath, f'{self.starttime}_train_loss_valid_loss.png')

        fig = plt.figure(figsize=(length, width))
        try:
            epoch_list = range(1, len(self.train_loss_list) + 1)
            plt.plot(epoch_list, self.train_loss_list, color='blue', label='Train loss')
            plt.plot(epoch_list, self.test_loss_list, color='red', label='Valid loss')
            # Epochs are whole numbers, so keep the x ticks on integers.
            plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
            plt.legend()
            plt.tight_layout()
            plt.savefig(target)
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close(fig)
|
操作方法和之前一致,如果在新添 loss 的时候不需要绘图,则在调用实例时指定 `plot=False` 即可。