Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com6年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python logging日志模块的详解

python logging日志模块的详解 日志级别 日志一共分成5个等级,从低到高分别是:DEBUG INFO WARNING ERROR CRITICAL。 DEBUG:详细的信...

win10系统Anaconda和Pycharm的Tensorflow2.0之CPU和GPU版本安装教程

win10系统Anaconda和Pycharm的Tensorflow2.0之CPU和GPU版本安装教程

tf2.0的三个优点: 1、方便搭建网络架构; 2、自动求导 3、GPU加速(便于大数据计算) 安装过程(概要提示) step1:安装annaconda3 step2:安装pycharm...

Python 专题一 函数的基础知识

Python 专题一 函数的基础知识

最近才开始学习Python语言,但就发现了它很多优势(如语言简洁、网络爬虫方面深有体会).我主要是通过《Python基础教程》和"51CTO学院 智普教育的python视频"学习,在看视...

Python3解释器知识点总结

Python3解释器知识点总结

Python3 解释器 Linux/Unix的系统上,一般默认的 python 版本为 2.x,我们可以将 python3.x 安装在 /usr/local/python3 目录中。...

python 监听salt job状态,并任务数据推送到redis中的方法

salt分发后,主动将已完成的任务数据推送到redis中,使用redis的生产者模式,进行消息传送 #coding=utf-8 import fnmatch,json,logging...