
An Introduction to tf.train.MonitoredSession

2019-08-20 | Source: IT技术 | Author: seo实验室小编
 

MonitoredSession wraps session.run() and integrates auxiliary operations into the run loop, such as logging, checkpoint saving, and summary writing.


The base class is generally used at inference time; during training you use its subclass, tf.train.MonitoredTrainingSession.

1 MonitoredTrainingSession

1.1 Constructor

MonitoredTrainingSession(
    master='',
    is_chief=True,
    checkpoint_dir=None,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=600,
    save_summaries_steps=USE_DEFAULT,
    save_summaries_secs=USE_DEFAULT,
    config=None,
    stop_grace_period_secs=120,
    log_step_count_steps=100,
    max_wait_secs=7200
)

Official example (from the TensorFlow docs):

saver_hook = CheckpointSaverHook(...)
summary_hook = SummarySaverHook(...)
with MonitoredSession(session_creator=ChiefSessionCreator(...),
                      hooks=[saver_hook, summary_hook]) as sess:
  while not sess.should_stop():
    sess.run(train_op)

First, when the MonitoredSession is initialized, the following operations run in order:

  • Calls each hook's begin() method; hook-local initialization usually goes here. For example, the _step attribute of the _LoggerHook in the cats-vs-dogs example records the execution step and is only used within that hook class.
  • Finalizes the graph by calling scaffold.finalize()
  • Creates the session
  • Initializes the model using the ops provided by the Scaffold
  • Restores the model's variables if a checkpoint exists
  • Launches queue runners
  • Calls hook.after_create_session()

Then, each time run() is called, the following happen in order:

  • Calls hook.before_run()
  • Calls the underlying TensorFlow session.run()
  • Calls hook.after_run()
  • Returns the session.run() result the user asked for
  • If an AbortedError or UnavailableError occurs, recovers or reinitializes the session before running again

Finally, when close() is called on exit, the following happen in order:

  • Calls hook.end()
  • Closes the queues and the session
  • Suppresses OutOfRange errors (which signal that the input queues are exhausted)

1.2 Hook

These hooks are therefore the main objects to focus on.

1.2.1 LoggingTensorHook

tf.train.LoggingTensorHook (official description):

Prints the given tensors every N local steps, every N seconds, or at end.

__init__(
    tensors,
    every_n_iter=None,
    every_n_secs=None,
    formatter=None
)
  • tensors: dict that maps string-valued tags to tensors/tensor names, or iterable of tensors/tensor names.

Usage example:

# Set up logging for predictions
tensors_to_log = {"probabilities": "softmax_tensor"}
logging_hook = tf.train.LoggingTensorHook(
    tensors=tensors_to_log, every_n_iter=50)

1.2.2 SummarySaverHook

tf.train.SummarySaverHook

Saves summaries every N steps.

__init__(
    save_steps=None,
    save_secs=None,
    output_dir=None,
    summary_writer=None,
    scaffold=None,
    summary_op=None
)

output_dir takes a directory path.

summary_op takes the merged summary op, i.e. the result of tf.summary.merge_all().

1.2.3 CheckpointSaverHook

tf.train.CheckpointSaverHook

The MonitoredTrainingSession constructor shown above only exposes save_checkpoint_secs, with no per-step option (later TF 1.x releases add save_checkpoint_steps); to save by step, pass this hook explicitly.

Saves checkpoints every N steps or seconds.

__init__(
    checkpoint_dir,
    save_secs=None,
    save_steps=None,
    saver=None,
    checkpoint_basename='model.ckpt',
    scaffold=None,
    listeners=None
)

Required: save_secs or save_steps (exactly one of the two); saver is optional.

1.2.4 NanTensorHook

tf.train.NanTensorHook

This looks like a debugging aid; adding it to the training loop may slow training down.

  • Monitors the loss tensor and stops training if the loss is NaN.
  • Can either fail with an exception or just stop training.

__init__(
    loss_tensor,
    fail_on_nan_loss=True
)

1.2.5 FeedFnHook

tf.train.FeedFnHook

This appears to be used to generate the feed_dict for each run.

Runs feed_fn and sets the feed_dict accordingly.

__init__(feed_fn)

1.2.6 GlobalStepWaiterHook

tf.train.GlobalStepWaiterHook

For distributed training. This hook delays execution until the global step reaches wait_until_step; it is used to gradually start workers in distributed settings. One example usage is setting wait_until_step=int(K*log(task_id+1)), assuming that task_id=0 is the chief.

1.2.7 ProfilerHook

tf.train.ProfilerHook

Captures CPU/GPU profiling information every N steps or seconds, writing timeline files that can be inspected in chrome://tracing.

References

  • tf.train.MonitoredSession: https://www.tensorflow.org/versions/master/api_docs/Python/tf/train/MonitoredSession
  • resnet_main.py: https://github.com/tensorflow/models/blob/master/research/resnet/resnet_main.py
  • tf.train.MonitoredTrainingSession: https://www.tensorflow.org/versions/master/api_docs/python/tf/train/MonitoredTrainingSession
  • A complete TensorFlow training run on your own dataset: https://zhuanlan.zhihu.com/p/32490882
  • tf.train.LoggingTensorHook: https://www.tensorflow.org/api_docs/python/tf/train/LoggingTensorHook
  • tf.train.SummarySaverHook: https://www.tensorflow.org/versions/master/api_docs/python/tf/train/SummarySaverHook
  • tf.train.CheckpointSaverHook: https://www.tensorflow.org/versions/master/api_docs/python/tf/train/CheckpointSaverHook
  • tf.train.NanTensorHook: https://www.tensorflow.org/versions/master/api_docs/python/tf/train/NanTensorHook#__init__
