pytorch 在sequential中使用view来reshape的例子

yipeiwu_com6年前Python基础

pytorch中view是tensor方法,然而在sequential中包装的是nn.module的子类,

因此需要自己定义一个方法:

import torch.nn as nn
class Reshape(nn.Module):
 def __init__(self, *args):
  super(Reshape, self).__init__()
  self.shape = args

 def forward(self, x):
  # 如果数据集最后一个batch样本数量小于定义的batch_batch大小,会出现mismatch问题。可以自己修改下,如只传入后面的shape,然后通过x.szie(0),来输入。
  return x.view(self.shape)
class Reshape(nn.Module):
 def __init__(self, *args):
  super(Reshape, self).__init__()
  self.shape = args
 def forward(self, x):
  return x.view((x.size(0),)+self.shape)

以上这篇pytorch 在sequential中使用view来reshape的例子就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python将string转换到float的实例方法

Python将string转换到float的实例方法

Python 如何转换string到float?简单几步,让你轻松解决。 打开软件,新建python项目,如图所示 右键菜单中创建.py文件,如图所示 步骤中文件输入代码如下...

Python实现将通信达.day文件读取为DataFrame

如下所示: import os import struct import pandas as pd def readTdxLdayFile(fname="C:\\TdxW_HuaT...

浅谈Pandas 排序之后索引的问题

如下所示: In [1]: import pandas as pd ...: df=pd.DataFrame({"a":[1,2,3,4,5],"b":[5,4,3,2,1]})...

python3获取两个日期之间所有日期,以及比较大小的实例

如下所示: import datetime #获取两个日期间的所有日期 def getEveryDay(begin_date,end_date): date_list = [...

Python利用matplotlib.pyplot绘图时如何设置坐标轴刻度

Python利用matplotlib.pyplot绘图时如何设置坐标轴刻度

前言 matplotlib.pyplot是一些命令行风格函数的集合,使matplotlib以类似于MATLAB的方式工作。每个pyplot函数对一幅图片(figure)做一些改动:比如创...