Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com5年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

在PYQT5中QscrollArea(滚动条)的使用方法

在PYQT5中QscrollArea(滚动条)的使用方法

如下所示: import sys from PyQt5.QtWidgets import * class MainWindow(QMainWindow): def __in...

python+OpenCV实现车牌号码识别

python+OpenCV实现车牌号码识别

基于python+OpenCV的车牌号码识别,供大家参考,具体内容如下 车牌识别行业已具备一定的市场规模,在电子警察、公路卡口、停车场、商业管理、汽修服务等领域已取得了部分应用。一个典型...

python client使用http post 到server端的代码

复制代码 代码如下:import urllib, httplib  import utils  import json     ...

django 按时间范围查询数据库实例代码

从前台中获得时间范围,在django后台处理request中数据,完成format,按照范围调用函数查询数据库。 介绍一个简单的功能,就是从web表单里获取用户指定的时间范围,然后在数据...

Pytorch保存模型用于测试和用于继续训练的区别详解

保存模型 保存模型仅仅是为了测试的时候,只需要 torch.save(model.state_dict, path) path 为保存的路径 但是有时候模型及数据太多,难以一次性训...