tensorflow 只恢复部分模型参数的实例

yipeiwu_com6年前Python基础

我就废话不多说了,直接上代码吧!

import tensorflow as tf

def model_1():
  with tf.variable_scope("var_a"):
    a = tf.Variable(initial_value=[1, 2, 3], name="a")

  vars = [var for var in tf.trainable_variables() if var.name.startswith("var_a")]
  print(len(vars))
  return vars

def model_2():

  vars1 = model_1()

  with tf.variable_scope("var_b"):
    a = tf.Variable(initial_value=[1, 2, 3], name="a")

  vars2 = [var for var in tf.trainable_variables() if var.name.startswith("var")]
  print(len(vars2))
  return vars1


def pretrain_model1():
  print("-------- model 1 ------")
  vars = model_1()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sess, "./model.ckpt")

def train_model2():
  print("-------- model 2 ------")

  model1_vars = model_2()

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(var_list=model1_vars)
    saver.restore(sess, "./model.ckpt")
    vars = sess.run([model1_vars])
    for var in vars:
      print(var)

step = 2
if step == 1:
  pretrain_model1()
else:
  train_model2()

以上这篇tensorflow 只恢复部分模型参数的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持【听图阁-专注于Python设计】。

相关文章

Python实现的一个简单LRU cache

起因:我的同事需要一个固定大小的cache,如果记录在cache中,直接从cache中读取,否则从数据库中读取。python的dict 是一个非常简单的cache,但是由于数据量很大,内...

Python代码缩进和测试模块示例详解

前言 Python代码缩进和测试模块是大家学习python必不可少的一部分,本文主要介绍了关于Python代码缩进和测试模块的相关内容,分享出来供大家参考学习,下面话不多说了,来一起看看...

Python显示进度条的方法

Python显示进度条的方法

本文实例讲述了Python显示进度条的方法,是Python程序设计中非常实用的技巧。分享给大家供大家参考。具体方法如下: 首先,进度条和一般的print区别在哪里呢? 答案就是print...

Python实现合并excel表格的方法分析

本文实例讲述了Python实现合并excel表格的方法。分享给大家供大家参考,具体如下: 需求 将一个文件夹中的excel表格合并成我们想要的形式,主要要pandas中的concat()...

Python基于matplotlib画箱体图检验异常值操作示例【附xls数据文件下载】

Python基于matplotlib画箱体图检验异常值操作示例【附xls数据文件下载】

本文实例讲述了Python基于matplotlib画箱体图检验异常值操作。分享给大家供大家参考,具体如下: # -*- coding:utf-8 -*- #! python3 imp...