tensorflow中tf.app.flags.FLAGS用法介绍

您所在的位置:网站首页 edd次元壁风动漫 tensorflow中tf.app.flags.FLAGS用法介绍

tensorflow中tf.app.flags.FLAGS用法介绍

2023-11-22 11:18| 来源: 网络整理| 查看: 265

tf 中定义了 tf.app.flags.FLAGS ,用于接受从终端传入的命令行参数,相当于对python中的命令行参数模块optpars做了一层封装。

例:#coding:utf-8   # 学习使用 tf.app.flags 使用,全局变量 # 可以再命令行中运行也是比较方便,如果只写 python app_flags.py 则代码运行时默认程序里面设置的默认设置 # 若 python app_flags.py --train_data_path --max_sentence_len 100 #    --embedding_size 100 --learning_rate 0.05  代码再执行的时候将会按照上面的参数来运行程序   import tensorflow as tf   FLAGS = tf.app.flags.FLAGS   # tf.app.flags.DEFINE_string("param_name", "default_val", "description") tf.app.flags.DEFINE_string("train_data_path", "/home/yongcai/chinese_fenci/train.txt", "training data dir") tf.app.flags.DEFINE_string("log_dir", "./logs", " the log dir") tf.app.flags.DEFINE_integer("max_sentence_len", 80, "max num of tokens per query") tf.app.flags.DEFINE_integer("embedding_size", 50, "embedding size")  tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning rate")     def main(unused_argv):     train_data_path = FLAGS.train_data_path     print("train_data_path", train_data_path)     max_sentence_len = FLAGS.max_sentence_len     print("max_sentence_len", max_sentence_len)     embdeeing_size = FLAGS.embedding_size     print("embedding_size", embdeeing_size)     abc = tf.add(max_sentence_len, embdeeing_size)       init = tf.global_variables_initializer()       #with tf.Session() as sess:         #sess.run(init)         #print("abc", sess.run(abc))       sv = tf.train.Supervisor(logdir=FLAGS.log_dir, init_op=init)     with sv.managed_session() as sess:         print("abc:", sess.run(abc))           # sv.saver.save(sess, "/home/yongcai/tmp/")     # 使用这种方式保证了,如果此文件被其他文件 import的时候,不会执行main 函数 if __name__ == '__main__':     tf.app.run()   # 解析命令行参数,调用main 函数 main(sys.argv)

调用方法:

其中参数可以根据需求进行修改。

python app_flags.py --train_data_path --max_sentence_len 100 --embedding_size 100 --learning_rate 0.05

如果这样调用:

python app_flags.py

则会执行程序时会自动调用程序中 default 中的参数。

解释

和optpars中的参数类型类似是通过参数 “type=xxx” 定义的,tf中每个合法类型都有对应的 “DEFINE_xxx”函数。常用:

tf.app.flags.DEFINE_string() :定义一个用于接收string类型数值的变量;tf.app.flags.DEFINE_integer() : 定义一个用于接收int类型数值的变量;tf.app.flags.DEFINE_float() : 定义一个用于接收float类型数值的变量;tf.app.flags.DEFINE_boolean() : 定义一个用于接收bool类型数值的变量;

“DEFINE_xxx”函数带3个参数,分别是变量名称,默认值,用法描述,例如:

tf.app.flags.DEFINE_string('ckpt_path', 'model/model.ckpt-100000', '''Checkpoint directory to restore''')下面给一个完整的例子

定义一个名称是 "ckpt_path" 的变量,默认值是 ckpt_path = 'model/model.ckpt-100000',描述信息表明这是一个用于保存节点信息的路径。

# -*- coding=utf-8 -*- import tensorflow  as tf FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('ckpt_path', 'model/model.ckpt-100000', '''模型保存路径''') tf.app.flags.DEFINE_float('learning_rate',0.0001,'''初始学习率''') tf.app.flags.DEFINE_integer('train_steps', 50000, '''总的训练轮数''') tf.app.flags.DEFINE_boolean('is_use_gpu', False, '''是否使用GPU''') print '模型保存路径: {}'.format(FLAGS.ckpt_path) print '初始学习率: {}'.format(FLAGS.learning_rate) print '总的训练次数: {}'.format(FLAGS.train_steps) print '是否使用GPU: {}'.format(FLAGS.is_use_gpu)

使用 '-h' 指令查看帮助信息:

python flags_test.py -h

按默认设置执行程序:

传入用户自定义的命令行参数:

python flags_test.py --ckpt_path abc/cba --learning_rate 0.001 --train_steps 10000 --is_use_gpu True


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3