pytorch自定义二值化网络层方式

yipeiwu_com6年前Python基础

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python编写一个验证码图片数据标注GUI程序附源码

Python编写一个验证码图片数据标注GUI程序附源码

做验证码图片的识别,不论是使用传统的ORC技术,还是使用统计机器学习或者是使用深度学习神经网络,都少不了从网络上采集大量相关的验证码图片做数据集样本来进行训练。 采集验证码图片,可以直接...

Python中利用原始套接字进行网络编程的示例

在实验中需要自己构造单独的HTTP数据报文,而使用SOCK_STREAM进行发送数据包,需要进行完整的TCP交互。 因此想使用原始套接字进行编程,直接构造数据包,并在IP层进行发送,即采...

Python FFT合成波形的实例

使用Python numpy模块带的FFT函数合成矩形波和方波,增加对离散傅里叶变换的理解。 导入模块 import numpy as np import matplotlib.py...

python+selenium实现自动抢票功能实例代码

简介 什么是Selenium? Selenium是ThoughtWorks公司的一个强大的开源Web功能测试工具系列,采用Javascript来管理整个测试过程,包括读入测试套件、执行测...

Flask之请求钩子的实现

请求钩子 通过装饰器为一个模块添加请求钩子, 对当前模块的请求进行额外的处理. 比如权限验证. 说白了,就是在执行视图函数前后你可以进行一些处理,Flask使用装饰器为我们提供了注册...