关于pytorch处理类别不平衡的问题

yipeiwu_com5年前Python基础

当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。

下面的代码展示了如何使用WeightedRandomSampler来完成抽样。

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
     np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

print 'target train 0/1: {}/{}'.format(
 len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
 train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
 print "batch index {}, 0/1: {}/{}".format(
  i,
  len(np.where(target.numpy() == 0)[0]),
  len(np.where(target.numpy() == 1)[0]))

核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2

以上这篇关于pytorch处理类别不平衡的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

对python借助百度云API对评论进行观点抽取的方法详解

对python借助百度云API对评论进行观点抽取的方法详解

通过百度云API接口抽取得到产品评论的观点,也掠去了很多评论中无用的内容以及符号,为后续进行文本主题挖掘或者规则的提取提供基础。 工具 1、百度云账号,申请应用接口(自然语言处理) 2...

Python paramiko模块的使用示例

paramiko模块提供了ssh及sft进行远程登录服务器执行命令和上传下载文件的功能。这是一个第三方的软件包,使用之前需要安装。 1 基于用户名和密码的 sshclient 方式登录...

python字典DICT类型合并详解

python字典DICT类型合并详解

本文为大家分享了python字典DICT类型合并的方法,供大家参考,具体内容如下 我要的字典的键值有些是数据库中表的字段名, 但是有些却不是, 我需要把它们整合到一起, 因此有些这篇文章...

python中WSGI是什么,Python应用WSGI详解

为了让大家更好的对python中WSGI有更好的理解,我们先从最简单的认识WSGI着手,然后介绍一下WSGI几个经常使用到的接口,了解基本的用法和功能,最后,我们通过实例了解一下WSGI...

python twilio模块实现发送手机短信功能

python twilio模块实现发送手机短信功能

前排提示:这个模块不是用于对陌生人进行短信轰炸和电话骚扰的,这个模块也没有这个功能,如果是抱着这个心态来的,可以关闭网页了 语言:python 步骤一:安装twilio模块 pip in...