Python绘制loss曲线和准确率曲线实例代码

目录
  • 引言
  • 一、数据读取与存储部分
  • 二、绘制 loss 曲线
  • 三、绘制准确率曲线
  • 总结

引言

使用 python 绘制网络训练过程中的的 loss 曲线以及准确率变化曲线,这里的主要思想就时先把想要的损失值以及准确率值保存下来,保存到 .txt 文件中,待网络训练结束,我们再拿这存储的数据绘制各种曲线。

其大致步骤为:数据读取与存储 - > loss曲线绘制 - > 准确率曲线绘制

一、数据读取与存储部分

我们首先要得到训练时的数据,以损失值为例,网络每迭代一次都会产生相应的 loss,那么我们就把每一次的损失值都存储下来,存储到列表,保存到 .txt 文件中。保存的文件如下图所示:

[1.3817585706710815, 1.8422836065292358, 1.1619832515716553, 0.5217241644859314, 0.5221078991889954, 1.3544578552246094, 1.3334463834762573, 1.3866571187973022, 0.7603049278259277]

上图为部分损失值,根据迭代次数而异,要是迭代了1万次,这里就会有1万个损失值。
而准确率值是每一个 epoch 产生一个值,要是训练100个epoch,就有100个准确率值。

(那么问题来了,这里的损失值是怎么保存到文件中的呢? 很少有人讲这个,也有一些小伙伴们来咨询,这里就统一记录一下,包括损失值和准确率值。)

首先,找到网络训练代码,就是项目中的 main.py,或者 train.py ,在文件里先找到训练部分,里面经常会有这样一行代码:

for epoch in range(resume_epoch, num_epochs):   # 就是这一行
	####
	...
	loss = criterion(outputs, labels.long())              # 损失样例
	...
    epoch_acc = running_corrects.double() / trainval_sizes[phase]    # 准确率样例
    ...
    ###

从这一行开始就是训练部分了,往下会找到类似的这两句代码,就是损失值和准确率值了。

这时候将以下代码加入源代码就可以了:

train_loss = []
train_acc = []
for epoch in range(resume_epoch, num_epochs):          # 就是这一行
	###
	...
	loss = criterion(outputs, labels.long())           # 损失样例
	train_loss.append(loss.item())                     # 损失加入到列表中
	...
	epoch_acc = running_corrects.double() / trainval_sizes[phase]    # 准确率样例
	train_acc.append(epoch_acc.item())                 # 准确率加入到列表中
	...
with open("./train_loss.txt", 'w') as train_los:
    train_los.write(str(train_loss))

with open("./train_acc.txt", 'w') as train_ac:
     train_ac.write(str(train_acc))

这样就算完成了损失值和准确率值的数据存储了!

二、绘制 loss 曲线

主要需要 numpy 库和 matplotlib 库,如果不会安装可以自行百度,很简单。

首先,将 .txt 文件中的存储的数据读取进来,以下是读取函数:

import numpy as np

# 读取存储为txt文件的数据
def data_read(dir_path):
    with open(dir_path, "r") as f:
        raw_data = f.read()
        data = raw_data[1:-1].split(", ")   # [-1:1]是为了去除文件中的前后中括号"[]"

    return np.asfarray(data, float)

然后,就是绘制 loss 曲线部分:

if __name__ == "__main__":

	train_loss_path = r"E:\relate_code\Gaitpart-master\train_loss.txt"   # 存储文件路径

	y_train_loss = data_read(train_loss_path)        # loss值,即y轴
	x_train_loss = range(len(y_train_loss))			 # loss的数量,即x轴

	plt.figure()

    # 去除顶部和右边框框
    ax = plt.axes()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.xlabel('iters')    # x轴标签
    plt.ylabel('loss')     # y轴标签

	# 以x_train_loss为横坐标,y_train_loss为纵坐标,曲线宽度为1,实线,增加标签,训练损失,
	# 默认颜色,如果想更改颜色,可以增加参数color='red',这是红色。
    plt.plot(x_train_loss, y_train_loss, linewidth=1, linestyle="solid", label="train loss")
    plt.legend()
    plt.title('Loss curve')
    plt.show()

这样就算把损失图像画出来了!如下:

三、绘制准确率曲线

有了上面的基础,这就简单很多了。
只是有一点要记住,上面的x轴是迭代次数,这里的是训练轮次 epoch。

if __name__ == "__main__":

	train_acc_path = r"E:\relate_code\Gaitpart-master\train_acc.txt"   # 存储文件路径

	y_train_acc = data_read(train_acc_path)       # 训练准确率值,即y轴
	x_train_acc = range(len(y_train_acc))			 # 训练阶段准确率的数量,即x轴

	plt.figure()

    # 去除顶部和右边框框
    ax = plt.axes()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.xlabel('epochs')    # x轴标签
    plt.ylabel('accuracy')     # y轴标签

	# 以x_train_acc为横坐标,y_train_acc为纵坐标,曲线宽度为1,实线,增加标签,训练损失,
	# 增加参数color='red',这是红色。
    plt.plot(x_train_acc, y_train_acc, color='red',linewidth=1, linestyle="solid", label="train acc")
    plt.legend()
    plt.title('Accuracy curve')
    plt.show()

这样就把准确率变化曲线画出来了!如下:

以下是完整代码,以绘制准确率曲线为例,并且将x轴换成了iters,和损失曲线保持一致,供参考:

import numpy as np
import matplotlib.pyplot as plt

# 读取存储为txt文件的数据
def data_read(dir_path):
    with open(dir_path, "r") as f:
        raw_data = f.read()
        data = raw_data[1:-1].split(", ")

    return np.asfarray(data, float)

# 不同长度数据,统一为一个标准,倍乘x轴
def multiple_equal(x, y):
    x_len = len(x)
    y_len = len(y)
    times = x_len/y_len
    y_times = [i * times for i in y]
    return y_times

if __name__ == "__main__":

    train_loss_path = r"E:\relate_code\Gaitpart-master\file_txt\train_loss.txt"
    train_acc_path = r"E:\relate_code\Gaitpart-master\train_acc.txt"

    y_train_loss = data_read(train_loss_path)
    y_train_acc = data_read(train_acc_path)

    x_train_loss = range(len(y_train_loss))
    x_train_acc = multiple_equal(x_train_loss, range(len(y_train_acc)))

    plt.figure()

    # 去除顶部和右边框框
    ax = plt.axes()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.xlabel('iters')
    plt.ylabel('accuracy')

    # plt.plot(x_train_loss, y_train_loss, linewidth=1, linestyle="solid", label="train loss")
    plt.plot(x_train_acc, y_train_acc,  color='red', linestyle="solid", label="train accuracy")
    plt.legend()

    plt.title('Accuracy curve')
    plt.show()

总结

到此这篇关于Python绘制loss曲线和准确率曲线的文章就介绍到这了,更多相关Python绘制loss曲线 准确率曲线内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

(0)

相关推荐

  • python matlibplot绘制多条曲线图

    这里我利用的是matplotlib.pyplot.plot的工具来绘制折线图,这里先给出一个段代码和结果图: # -*- coding: UTF-8 -*- import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt #这里导入你自己的数据 #...... #...... #x_axix,train_pn_dis这些都是长度相同的list() #开始画图 sub_axix = filter(lambda

  • 利用python绘制数据曲线图的实现

    "在举国上下万众一心.众志成城做好新冠肺炎疫情防控工作的特殊时刻,我们不能亲临主战场,但我们能坚持在大战中坚定信心.不负韶华." 1.爬取新闻保存为json文件,并将绘图所需数据保存至数据库 数据库表结构: 代码部分: import pymysql import re import sys,urllib,json from urllib import request from datetime import datetime import pandas as pd Today=date

  • python绘制多个曲线的折线图

    这篇文章利用的是matplotlib.pyplot.plot的工具来绘制折线图,这里先给出一个段代码和结果图: # -*- coding: UTF-8 -*- import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt #这里导入你自己的数据 #...... #...... #x_axix,train_pn_dis这些都是长度相同的list() #开始画图 sub_axix = filter(lambda

  • python 实现将多条曲线画在一幅图上的方法

    如下所示: # -*- coding: utf-8 -*- """ Created on Thu Jun 07 09:17:40 2018 @author: yjp """ import matplotlib.pyplot as plt import numpy as np from matplotlib.ticker import MultipleLocator, FormatStrFormatter y0 = [] y1 = [] y2 =

  • 如何通过python画loss曲线的方法

    1. 首先导入一些python画图的包,读取txt文件,假设我现在有两个模型训练结果的records.txt文件 import numpy as np import matplotlib.pyplot as plt import pylab as pl from mpl_toolkits.axes_grid1.inset_locator import inset_axes data1_loss =np.loadtxt("valid_RCSCA_records.txt") data2_l

  • Python绘制全球疫情变化地图的实例代码

    目前全球疫情仍然比较严重,为了能清晰地看到疫情爆发以来至现在全球疫情的变化趋势,我绘制了一张疫情变化地图. 废话不多说,先上图 下面就来重点介绍下上面这张图的绘制过程,主要分为以下三个步骤: 数据收集 数据处理 画图 下面一个一个来说. 数据收集 这是万里长城的第一步,俗话说"巧妇难为无米之炊",既然是变化图,当然需要每个国家.每天的现有确诊病例数.好在现在各大网站都有疫情相关的专题页,我们可以直接抓数据.以网易为例 我们选择 XHR,重新刷新下网页可以看到有几个接口,其中 list-

  • Python绘制loss曲线和准确率曲线实例代码

    目录 引言 一.数据读取与存储部分 二.绘制 loss 曲线 三.绘制准确率曲线 总结 引言 使用 python 绘制网络训练过程中的的 loss 曲线以及准确率变化曲线,这里的主要思想就时先把想要的损失值以及准确率值保存下来,保存到 .txt 文件中,待网络训练结束,我们再拿这存储的数据绘制各种曲线. 其大致步骤为:数据读取与存储 - > loss曲线绘制 - > 准确率曲线绘制 一.数据读取与存储部分 我们首先要得到训练时的数据,以损失值为例,网络每迭代一次都会产生相应的 loss,那么我

  • python绘制直方图和密度图的实例

    对于pandas的dataframe,绘制直方图方法如下: //pdf是pandas的dataframe, delta_time是其中一列 //xlim是x轴的范围,bins是分桶个数 pdf.delta_time.plot(kind='hist', xlim=(-50,300), bins=500) 对于pandas的dataframe,绘制概率密度图方法如下: //pdf是pandas的dataframe, delta_time是其中一列 pdf.delta_time.dropna().pl

  • python 把数据 json格式输出的实例代码

    有个要求需要在python的标准输出时候显示json格式数据,如果缩进显示查看数据效果会很好,这里使用json的包会有很多操作 import json date = {u'versions': [{u'status': u'CURRENT', u'id': u'v2.3', u'links': [{u'href': u'http://controller:9292/v2/', u'rel': u'self'}]}, {u'status': u'SUPPORTED', u'id': u'v2.2'

  • python 实现自动远程登陆scp文件实例代码

     python 实现自动远程登陆scp文件实例代码 实现实例代码: #!/usr/bin/expect if {$argc!=3} { send_user "Usage: $argv0 {path1} {path2} {Password}\n\n" exit } set path1 [lindex $argv 0] set path2 [lindex $argv 1] set Password [lindex $argv 2] spawn scp ${path1} ${path2} e

  • python将ansible配置转为json格式实例代码

    python将ansible配置转为json格式实例代码 ansible的配置文件举例如下,这种配置文件不利于在前端的展现,因此,我们用一段简单的代码将ansible的配置文件转为json格式的: [webserver] 192.168.204.70 192.168.204.71 [dbserver] 192.168.204.72 192.168.204.73 192.168.204.75 [proxy] 192.168.204.76 192.168.204.77 192.168.204.78

  • Python+tkinter模拟“记住我”自动登录实例代码

    本文分享的代码主要是通过Python+tkinter模拟"记住我"自动登录的功能,具体介绍如下. 基本思路:如果某次登录成功,则创建临时文件记录有关信息,每次启动程序时尝试自动获取上次登录成功的信息并自动编写.本文主要演示思路,可根据实际系统中的需要进行改写,例如读取数据库并验证用户名和密码是否正确.对用户名和密码进行本地加密存储等等. import tkinter import tkinter.messagebox import os import os.path # 获取Windo

  • python与sqlite3实现解密chrome cookie实例代码

    本文研究的主要问题:有一个解密chrome cookie的事情,google出了代码,却不能正常执行,原因在于sqlite3的版本太低,虽然我切换到了python3.5的环境,但sqlite3的版本也只有3.6. google了许久,终于找到方法: 1. 进入页面 http://www6.atomicorp.com/channels/atomic/centos/6/x86_64/RPMS/ 2. 下载 atomic-sqlite-sqlite-3.8.5-2.el6.art.x86_64.rpm

  • python批量替换页眉页脚实例代码

    简介 本文分享的实例代码主要通过python语言实现批量替换页眉页脚的操作功能,具体如下. 代码 #!/usr/bin/env python # -*- coding: utf-8 -*- import win32com,os,sys,re from win32com.client import Dispatch, constants # 打开新的文件 suoyou = os.listdir('d:\\daizhuan') #print suoyou for i in suoyou: wenji

  • python导出hive数据表的schema实例代码

    本文研究的主要问题是python语言导出hive数据表的schema,分享了实现代码,具体如下. 为了避免运营提出无穷无尽的查询需求,我们决定将有查询价值的数据从mysql导入hive中,让他们使用HUE这个开源工具进行查询.想必他们对表结构不甚了解,还需要为之提供一个表结构说明,于是编写了一个脚本,从hive数据库中将每张表的字段即类型查询出来,代码如下: #coding=utf-8 import pyhs2 from xlwt import * hiveconn = pyhs2.connec

随机推荐