tensorflow中tf.app.flags.FLAGS用法介绍 |
您所在的位置:网站首页 › edd次元壁风动漫 › tensorflow中tf.app.flags.FLAGS用法介绍 |
tf 中定义了 tf.app.flags.FLAGS ,用于接受从终端传入的命令行参数,相当于对python中的命令行参数模块optpars做了一层封装。 例:#coding:utf-8 # 学习使用 tf.app.flags 使用,全局变量 # 可以再命令行中运行也是比较方便,如果只写 python app_flags.py 则代码运行时默认程序里面设置的默认设置 # 若 python app_flags.py --train_data_path --max_sentence_len 100 # --embedding_size 100 --learning_rate 0.05 代码再执行的时候将会按照上面的参数来运行程序 import tensorflow as tf FLAGS = tf.app.flags.FLAGS # tf.app.flags.DEFINE_string("param_name", "default_val", "description") tf.app.flags.DEFINE_string("train_data_path", "/home/yongcai/chinese_fenci/train.txt", "training data dir") tf.app.flags.DEFINE_string("log_dir", "./logs", " the log dir") tf.app.flags.DEFINE_integer("max_sentence_len", 80, "max num of tokens per query") tf.app.flags.DEFINE_integer("embedding_size", 50, "embedding size") tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning rate") def main(unused_argv): train_data_path = FLAGS.train_data_path print("train_data_path", train_data_path) max_sentence_len = FLAGS.max_sentence_len print("max_sentence_len", max_sentence_len) embdeeing_size = FLAGS.embedding_size print("embedding_size", embdeeing_size) abc = tf.add(max_sentence_len, embdeeing_size) init = tf.global_variables_initializer() #with tf.Session() as sess: #sess.run(init) #print("abc", sess.run(abc)) sv = tf.train.Supervisor(logdir=FLAGS.log_dir, init_op=init) with sv.managed_session() as sess: print("abc:", sess.run(abc)) # sv.saver.save(sess, "/home/yongcai/tmp/") # 使用这种方式保证了,如果此文件被其他文件 import的时候,不会执行main 函数 if __name__ == '__main__': tf.app.run() # 解析命令行参数,调用main 函数 main(sys.argv)调用方法: 其中参数可以根据需求进行修改。 python app_flags.py --train_data_path --max_sentence_len 100 --embedding_size 100 --learning_rate 0.05如果这样调用: python app_flags.py则会执行程序时会自动调用程序中 default 中的参数。 解释和optpars中的参数类型类似是通过参数 “type=xxx” 定义的,tf中每个合法类型都有对应的 “DEFINE_xxx”函数。常用: tf.app.flags.DEFINE_string() :定义一个用于接收string类型数值的变量;tf.app.flags.DEFINE_integer() : 定义一个用于接收int类型数值的变量;tf.app.flags.DEFINE_float() : 定义一个用于接收float类型数值的变量;tf.app.flags.DEFINE_boolean() : 定义一个用于接收bool类型数值的变量;“DEFINE_xxx”函数带3个参数,分别是变量名称,默认值,用法描述,例如: tf.app.flags.DEFINE_string('ckpt_path', 'model/model.ckpt-100000', '''Checkpoint directory to restore''')下面给一个完整的例子定义一个名称是 "ckpt_path" 的变量,默认值是 ckpt_path = 'model/model.ckpt-100000',描述信息表明这是一个用于保存节点信息的路径。 # -*- coding=utf-8 -*- import tensorflow as tf FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('ckpt_path', 'model/model.ckpt-100000', '''模型保存路径''') tf.app.flags.DEFINE_float('learning_rate',0.0001,'''初始学习率''') tf.app.flags.DEFINE_integer('train_steps', 50000, '''总的训练轮数''') tf.app.flags.DEFINE_boolean('is_use_gpu', False, '''是否使用GPU''') print '模型保存路径: {}'.format(FLAGS.ckpt_path) print '初始学习率: {}'.format(FLAGS.learning_rate) print '总的训练次数: {}'.format(FLAGS.train_steps) print '是否使用GPU: {}'.format(FLAGS.is_use_gpu)使用 '-h' 指令查看帮助信息: python flags_test.py -h![]() 按默认设置执行程序: ![]() 传入用户自定义的命令行参数: python flags_test.py --ckpt_path abc/cba --learning_rate 0.001 --train_steps 10000 --is_use_gpu True![]() |
今日新闻 |
推荐新闻 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |