Pytorch 实现计算分类器准确率(总分类及子分类)

yipeiwu_com6年前Python基础

分类器平均准确率计算:

correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      correct += (prediction == labels).sum().float()
      total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:

correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
      images = Variable(images.cuda())
      labels = Variable(labels.cuda())

      output = model(images)

      prediction = torch.argmax(output, 1)
      res = prediction == labels
      for label_idx in range(len(labels)):
        label_single = label[label_idx]
        correct[label_single] += res[label_idx].item()
        total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
      try:
        acc = correct[acc_idx]/total[acc_idx]
      except:
        acc = 0
      finally:
        acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

基于Python3 逗号代码 和 字符图网格(详谈)

逗号代码 假定有下面这样的列表: spam=['apples','bananas','tofu',' cats'] 编写一个函数,它以一个列表值作为参数,返回一个字符串。该字符串包...

python二维键值数组生成转json的例子

今天出于需要,要将爬虫爬取的一些数据整理成二维数组,再编码成json字符串传入数据库 那么问题就来了,在php中这个过程很简便 ,类似这样: $arr[$key1][$key2]=...

python计算书页码的统计数字问题实例

本文实例讲述了python计算书页码的统计数字问题,是Python程序设计中一个比较典型的应用实例。分享给大家供大家参考。具体如下: 问题描述:对给定页码n,计算出全部页码中分别用到多少...

python3 实现的人人影视网站自动签到

这是一个自动化程度较高的程序,运行本程序后会从chrome中读取cookies用于登录人人影视签到, 并且会自动添加一个windows 任务计划,这个任务计划每天下午两点会执行本程序进行...

Windows系统配置python脚本开机启动的3种方法分享

Windows系统配置python脚本开机启动的3种方法分享

测试环境:windows Server 2003 R2 一、开始菜单启动项实现 用户必须登录才可执行。 测试脚本(python代码): 复制代码 代码如下: import time fo...