# 读取训练loss趋势并且对比绝对差值 (read training-loss trends and compare their absolute difference)

#python #re #numpy #matplotlib #ai

import re
import matplotlib.pyplot as plt
import numpy as np
def read_loss(file_path):
    """Collect the ``'loss'`` values that are logged together with a learning rate.

    A line contributes a value only when it contains ``'loss': <number>``
    immediately followed by ``, 'learning_rate'``. Only the first such
    occurrence on a line is used.

    Args:
        file_path: Log file to scan. It is decoded as UTF-8 and undecodable
            bytes are ignored.

    Returns:
        List of the matched loss values as floats, in file order. The list is
        empty when nothing matches.
    """
    # Signed decimal number with an optional exponent, e.g. 1.5, .25, -3e-05.
    number = r'[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?'
    loss_re = re.compile("'loss': " + '(' + number + ')' + ", 'learning_rate'")
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as log:
        hits = (loss_re.search(text) for text in log)
        return [float(hit.group(1)) for hit in hits if hit]
def read_epoch(file_path):
    """Return every ``'epoch': <number>}`` value found in a training log.

    Each line is searched for the text ``'epoch': `` immediately followed by a
    number and a closing ``}``. The first match on a line is collected as a
    float.

    Args:
        file_path: Log file to scan. It is decoded as UTF-8 and undecodable
            bytes are ignored.

    Returns:
        List of the matched epoch values as floats, in file order. The list
        is empty when nothing matches.
    """
    # Signed decimal number with an optional exponent.
    float_pattern = r'[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?'
    # Escape the literal anchors and compile once instead of rebuilding the regex per line.
    regex = re.compile(re.escape("'epoch': ") + f'({float_pattern})' + re.escape("}"))
    numbers = []
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as file:
        for line in file:
            match = regex.search(line)
            if match:
                numbers.append(float(match.group(1)))
    return numbers
def av_time(file_path):
    """Return the mean of all ``'time cost': <number>`` values in a log.

    Each line is searched for ``'time cost': `` followed by a number. The
    first such number on a line is collected.

    Args:
        file_path: Log file to scan. It is decoded as UTF-8 and undecodable
            bytes are ignored.

    Returns:
        Arithmetic mean of the collected values as a Python float.
        Returns ``float('nan')`` when the file has no ``'time cost'`` entries.
    """
    # Signed decimal number with an optional exponent.
    float_pattern = r'[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?'
    regex = re.compile(re.escape("'time cost': ") + f'({float_pattern})')
    numbers = []
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as file:
        for line in file:
            match = regex.search(line)
            if match:
                numbers.append(float(match.group(1)))
    if not numbers:
        # No data: report NaN explicitly instead of numpy's "Mean of empty slice" warning.
        return float('nan')
    # ndarray.mean() takes an optional axis, not the data, so call it with no arguments.
    return float(np.array(numbers).mean())
# Training logs being compared. Forward slashes work on both Windows and POSIX.
A800_LOG = "./train_7B_a800.log"
NPU_LOG = "./train_7B_npu.log"

loss_value_a800 = read_loss(A800_LOG)
loss_value_npu = read_loss(NPU_LOG)
epoch = read_epoch(A800_LOG)
if epoch:
    # The last 'epoch' entry comes from the final summary line. Drop it so the
    # epoch list lines up with the per-step loss values.
    epoch.pop()

# The two runs may have logged a different number of steps (e.g. one run was
# stopped early). Plot only the common prefix so every array has the same length.
n_points = min(len(epoch), len(loss_value_a800), len(loss_value_npu))
if not len(epoch) == len(loss_value_a800) == len(loss_value_npu):
    print(
        f"warning: length mismatch (epoch={len(epoch)}, "
        f"a800={len(loss_value_a800)}, npu={len(loss_value_npu)}); "
        f"using the first {n_points} points"
    )
epoch = epoch[:n_points]
a800 = np.array(loss_value_a800[:n_points])
npu = np.array(loss_value_npu[:n_points])

# First figure: loss curves of both runs.
plt.figure()
plt.plot(epoch, a800, color='b', label='A800')
plt.plot(epoch, npu, color='g', label='npu')
plt.title('Loss Trend')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()  # without this the 'A800'/'npu' labels are never displayed

# Second figure: absolute per-step loss difference between the runs.
plt.figure()
lossDiff = np.abs(a800 - npu)
plt.plot(epoch, lossDiff, color='r', label='diff')
plt.title('Absolute Loss Diff Trend')
plt.xlabel('Epochs')
plt.ylabel('Absolute Loss Diff')
plt.ylim(top=0.1)  # cap the y-axis at 0.1 so small differences stay visible
plt.legend()

# Additional info: mean step time of each run and the share of steps whose
# absolute loss difference is below 0.001.
a800_t = av_time(A800_LOG)
npu_t = av_time(NPU_LOG)
print(f'a800 mean step time: {a800_t}')
print(f'npu mean step time: {npu_t}')
proportion = np.sum(lossDiff < 0.001) / len(lossDiff) if len(lossDiff) else float('nan')
print(f"小于0.001的比例: {proportion}")

plt.show()