Tensorflow 自定义loss的情况下初始化部分变量方式

yipeiwu_com6年前Python基础

一般情况下,tensorflow里面变量初始化过程为:

  #variables ...........
  #..................... 
  init = tf.initialize_all_variables()
  sess.run(init)

这里 tf.initialize_all_variables() 会初始化所有的变量。

实际过程中,假设有a, b, c三个变量,其中a已经被初始化了,只想单独初始化b,c,那么:

  #variables ...
  ...
  init = tf.variables_initializer([b,c])
  sess.run(init)

此外,如果自行修改了optimizer,如下代码就会报错:

  #definition of variables a, b, c ...
  ....
  my_optimizer = tf.train.RMSProp(learning_rate = 0.1).minimize(my_cost)
  init = tf.variables_initializer([b,c])
  sess.run(init)

这是因为自己定义的optimizer会生成新的variables,但是在init里面并没有初始化,所以无法访问,会报错。解决方法如下:

  a = tf.Variables(...)      #line N
  temp = set(tf.all_variables()) 
  b = tf.Variables(...)
  c = tf.Variables(...) 
  #definition of my optimizer
  optimizer = tf.train.......
  init = tf.variables_initializer(set(tf.all_varialbles())-temp) # line M
  sess.run(init)

首先,temp = set(tf.all_variables()) 将该行(line N)代码之前的所有变量保存在temp中,接下来定义变量b, c,以及自定义的optimizer,然后 set(tf.all_varialbles()存储了改行(line M)之前的所有变量(包括optimizer生成的变量以及temp中所含的变量),set(tf.all_varialbles())-temp相减得到line N~M这几行定义的变量。

以上这篇Tensorflow 自定义loss的情况下初始化部分变量方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python基础之getpass模块详细介绍

Python基础之getpass模块详细介绍

本文主要给大家介绍了关于Python中getpass模块的相关内容,分享出来供大家参考学习,话不多说了,来一起看看详细的介绍: getpass模块提供了平台无关的在命令行下输入密码的方法...

DES加密解密算法之python实现版(图文并茂)

DES加密解密算法之python实现版(图文并茂)

一、DSE算法背景介绍 1. DES的采用 1979年,美国银行协会批准使用 1980年,美国国家标准局(ANSI)赞同DES作为私人使用的标准,称之为DEA(ANSI X.392) 1...

Python基于hashlib模块的文件MD5一致性加密验证示例

本文实例讲述了Python基于hashlib模块的文件MD5一致性加密验证。分享给大家供大家参考,具体如下: 使用hashlib模块,可对文件MD5一致性加密验证: #python...

使用Django2快速开发Web项目的详细步骤

使用Django2快速开发Web项目的详细步骤

Django 是一款基于 Python 编写并且采用 MVC 设计模式的开源的 Web 应用框架,早期是作为劳伦斯出版集团新闻网站的 CMS 内容管理系统而开发,后于 2005 年 7...

python之matplotlib学习绘制动态更新图实例代码

python之matplotlib学习绘制动态更新图实例代码

简介 通过定时器Timer触发事件,定时更新绘图,可以形成动态更新图片。下面的实例是学习《matplotlib for python developers》一文的笔记。 实现 实现代...