Tensorflow 自定义loss的情况下初始化部分变量方式

yipeiwu_com6年前Python基础

一般情况下,tensorflow里面变量初始化过程为:

  #variables ...........
  #..................... 
  init = tf.initialize_all_variables()
  sess.run(init)

这里 tf.initialize_all_variables() 会初始化所有的变量。

实际过程中,假设有a, b, c三个变量,其中a已经被初始化了,只想单独初始化b,c,那么:

  #variables ...
  ...
  init = tf.variables_initializer([b,c])
  sess.run(init)

此外,如果自行修改了optimizer,如下代码就会报错:

  #definition of variables a, b, c ...
  ....
  my_optimizer = tf.train.RMSProp(learning_rate = 0.1).minimize(my_cost)
  init = tf.variables_initializer([b,c])
  sess.run(init)

这是因为自己定义的optimizer会生成新的variables,但是在init里面并没有初始化,所以无法访问,会报错。解决方法如下:

  a = tf.Variables(...)      #line N
  temp = set(tf.all_variables()) 
  b = tf.Variables(...)
  c = tf.Variables(...) 
  #definition of my optimizer
  optimizer = tf.train.......
  init = tf.variables_initializer(set(tf.all_varialbles())-temp) # line M
  sess.run(init)

首先,temp = set(tf.all_variables()) 将该行(line N)代码之前的所有变量保存在temp中,接下来定义变量b, c,以及自定义的optimizer,然后 set(tf.all_varialbles()存储了改行(line M)之前的所有变量(包括optimizer生成的变量以及temp中所含的变量),set(tf.all_varialbles())-temp相减得到line N~M这几行定义的变量。

以上这篇Tensorflow 自定义loss的情况下初始化部分变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python创建日历实例

本文讲述了Python创建日历的方法,与以往不同的是,本文实例不使用Python提供的calendar实现,相信对大家的Python程序设计有一定的借鉴价值。 此程序在windows下测...

python 字典(dict)遍历的四种方法性能测试报告

python中,遍历dict的方法有四种。但这四种遍历的性能如何呢?我做了如下的测试 l = [(x,x) for x in xrange(10000)] d = dict(l)...

django中使用事务及接入支付宝支付功能

django中使用事务及接入支付宝支付功能

之前一直想记录一下在项目中使用到的事务以及支付宝支付功能,自己一直犯懒没有完,趁今天有点兴致,在这记录一下。 商城项目必备的就是支付订单的功能,所以就会涉及到订单的保存以及支付接口的引入...

python读取文本中数据并转化为DataFrame的实例

python读取文本中数据并转化为DataFrame的实例

在技术问答中看到一个这样的问题,感觉相对比较常见,就单开一篇文章写下来。 从纯文本格式文件 “file_in”中读取数据,格式如下: 需要输出成“file_out”,格式如下: 数据...

python中pandas.DataFrame排除特定行方法示例

前言 大家在使用Python进行数据分析时,经常要使用到的一个数据结构就是pandas的DataFrame,关于python中pandas.DataFrame的基本操作,大家可以查看这篇...