pytorch自定义二值化网络层方式

yipeiwu_com6年前Python基础

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

对Python中range()函数和list的比较

使用Python的人都知道range()函数和list很方便,今天再用到他的时候发现了很多以前看到过但是忘记的细节。这里记录一下range()和list。 >>>...

解决phantomjs截图失败,phantom.exit位置的问题

刚刚学习使用phantomjs,根据网上帖子自己手动改了一个延时截图功能,发现延时功能就是不能执行,最后一点点排查出了问题。 看代码: var page = require('web...

Python通过matplotlib画双层饼图及环形图简单示例

Python通过matplotlib画双层饼图及环形图简单示例

(1) 饼图(pie),即在一个圆圈内分成几块,显示不同数据系列的占比大小,这也是我们在日常数据的图形展示中最常用的图形之一。 在python中常用matplotlib的pie来绘制,基...

神经网络理论基础及Python实现详解

神经网络理论基础及Python实现详解

一、多层前向神经网络 多层前向神经网络由三部分组成:输出层、隐藏层、输出层,每层由单元组成; 输入层由训练集的实例特征向量传入,经过连接结点的权重传入下一层,前一层的输出是下一层的输入;...

浅谈python中requests模块导入的问题

浅谈python中requests模块导入的问题

今天使用Pycharm来抓取网页图片时候,要导入requests模块,但是在pycharm中import requests 时候报错。 原因: python中还没有安装requests库...