pytorch实现onehot编码转为普通label标签

yipeiwu_com6年前Python基础

label转onehot的很多,但是onehot转label的有点难找,所以就只能自己实现以下,用的topk函数,不知道有没有更好的实现

one_hot = torch.tensor([[0,0,1],[0,1,0],[0,1,0]])
print(one_hot)
label = torch.topk(one_hot, 1)[1].squeeze(1)
print(label)
tensor([[0, 0, 1],
[0, 1, 0],
[0, 1, 0]])

tensor([2, 1, 1])

以上这篇pytorch实现onehot编码转为普通label标签就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

关于pymysql模块的使用以及代码详解

pymysql模块的使用 查询一条数据fetchone() from pymysql import * conn = connect( host='127.0.0.1',...

解决PySide+Python子线程更新UI线程的问题

在我开发的系统,需要子线程去运行,然后把运行的结果发给UI线程,让UI线程知道运行的进度。 首先创建线程很简单 def newThread(self): d = Data() p...

Python进阶-函数默认参数(详解)

一、默认参数 python为了简化函数的调用,提供了默认参数机制: def pow(x, n = 2): r = 1 while n > 0: r *= x n...

python cv2读取rtsp实时码流按时生成连续视频文件方式

python cv2读取rtsp实时码流按时生成连续视频文件方式

我就废话不多说了,直接上代码吧! # coding: utf-8 import datetime import cv2 import os ip = '192.168.3.160...

TensorFlow模型保存和提取的方法

TensorFlow模型保存和提取的方法

一、TensorFlow模型保存和提取方法 1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对象saver的save...