pytorch自定义二值化网络层方式

yipeiwu_com6年前Python基础

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:

import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    a = torch.ones_like(input)
    b = -torch.ones_like(input)
    output = torch.where(input>=0,a,b)
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_abs = torch.abs(input)
    ones = torch.ones_like(input)
    zeros = torch.zeros_like(input)
    input_grad = torch.where(input_abs<=1,ones, zeros)
    return input_grad

定义一个module

class BinarizedModule(nn.Module):
  def __init__(self):
    super(BinarizedModule, self).__init__()
    self.BF = BinarizedF()
  def forward(self,input):
    print(input.shape)
    output =self.BF(input)
    return output

进行测试

a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s

class BinarizedF(Function):
  def forward(self, input):
    self.save_for_backward(input)
    output = torch.ones_like(input)
    output[input<0] = -1
    return output
  def backward(self, output_grad):
    input, = self.saved_tensors
    input_grad = output_grad.clone()
    input_abs = torch.abs(input)
    input_grad[input_abs>1] = 0
    return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

python将unicode转为str的方法

问题:  将u'\u810f\u4e71'转换为'\u810f\u4e71'   方法:  s_unicode = u'\u810f\u4e...

Python实现基本数据结构中栈的操作示例

Python实现基本数据结构中栈的操作示例

本文实例讲述了Python实现基本数据结构中栈的操作。分享给大家供大家参考,具体如下: #! /usr/bin/env python #coding=utf-8 #Python实现基...

python中函数默认值使用注意点详解

当在函数中定义默认值时,值初始化只会进行一次,就是执行到def methodname时执行。看下面代码: from datetime import datetime def te...

python 为什么说eval要慎用

eval前言 In [1]: eval("2+3") Out[1]: 5 In [2]: eval('[x for x in range(9)]') Out[2]: [0, 1,...

详解C++编程中一元运算符的重载

可重载的一元运算符如下: !(逻辑“非”) &(取址) ~(二进制反码) *(取消指针引用) +(一元加) -(一元求反) ++(递增) --(递减)...