Notes on TensorFlow picked up while reading the SPGAN source code.
1. tf.ConfigProto Reference: https://blog.csdn.net/dcrmg/article/details/79091941
tf.ConfigProto configures the parameters of a tf.Session.
log_device_placement=True: reports which device (which CPU or GPU) each operation and Tensor is assigned to; the placement of every op is printed to the terminal (see the sketch after the snippets below).
allow_soft_placement=True: lets TF automatically fall back to an existing, available device when the requested one cannot run an op. In TF, `with tf.device('/cpu:0'):` is how you pin operations to a device manually.
config.gpu_options.allow_growth = True: allocate GPU memory on demand instead of grabbing it all up front.
config.gpu_options.per_process_gpu_memory_fraction = 0.4: cap this process at 40% of the GPU's memory.
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.4
session = tf.Session(config=config)

which is equivalent to:

gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4)
config = tf.ConfigProto(gpu_options=gpu_options)
session = tf.Session(config=config)
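To illustrate log_device_placement, here is a minimal sketch (not from SPGAN; the constants and names are made up) whose per-op placement lines get logged to the terminal:

import tensorflow as tf

a = tf.constant([1.0, 2.0], name='a')
b = tf.constant([3.0, 4.0], name='b')
c = tf.add(a, b, name='c')

# With log_device_placement=True the runtime logs one line per op, e.g.
# "c: (Add): /job:localhost/replica:0/task:0/device:GPU:0".
config = tf.ConfigProto(log_device_placement=True, allow_soft_placement=True)
with tf.Session(config=config) as sess:
    print(sess.run(c))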
Method 1: set it inside Python

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'    # use GPU 0 only
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'  # use GPUs 0 and 1

Method 2: set it when launching the script

CUDA_VISIBLE_DEVICES=0,1 python yourcode.py
2. How TF reads dataset images Reference:
https://www.jb51.net/article/134550.htm
https://www.jb51.net/article/134547.htm
TF's input pipeline goes: file system -> filename queue -> memory queue.
Method 1 is recommended.
Method 1: feed a filename queue to WholeFileReader; decode yields a Tensor, and eval() yields an ndarray.

import tensorflow as tf
import os
import matplotlib.pyplot as plt

def file_name(file_dir):
    # Walk the directory tree and print its contents (for inspection).
    for root, dirs, files in os.walk(file_dir):
        print(root)
        print(dirs)
        print(files)

def file_name2(file_dir):
    # Collect the paths of all .jpg files under file_dir.
    L = []
    for root, dirs, files in os.walk(file_dir):
        for file in files:
            if os.path.splitext(file)[1] == '.jpg':
                L.append(os.path.join(root, file))
    return L

path = file_name2('test')

# filename queue -> reader -> decoder
file_queue = tf.train.string_input_producer(path, shuffle=True, num_epochs=2)
image_reader = tf.WholeFileReader()
key, image = image_reader.read(file_queue)
image = tf.image.decode_jpeg(image, channels=3)

with tf.Session() as sess:
    # num_epochs is backed by a local variable, so initialize local variables.
    tf.local_variables_initializer().run()
    threads = tf.train.start_queue_runners(sess=sess)
    for _ in path + path:  # two epochs over the file list
        plt.figure()
        plt.imshow(image.eval())
        plt.show()
Method 2: read the image with gfile; decode yields a Tensor, and eval() yields an ndarray.

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

print(tf.__version__)
image_raw = tf.gfile.FastGFile('test/a.jpg', 'rb').read()  # raw bytes
img = tf.image.decode_jpeg(image_raw)

with tf.Session() as sess:
    print(type(image_raw))   # <class 'bytes'>
    print(type(img))         # Tensor
    print(type(img.eval()))  # numpy.ndarray
    print(img.eval().shape)
    print(img.eval().dtype)
    plt.figure(1)
    plt.imshow(img.eval())
    plt.show()
Method 3: read the image with tf.read_file; decode yields a Tensor, and eval() yields an ndarray.

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

print(tf.__version__)
image_raw = tf.read_file('test/a.jpg')  # a string Tensor: the read happens inside the graph
img = tf.image.decode_jpeg(image_raw)

with tf.Session() as sess:
    print(type(image_raw))   # Tensor
    print(type(img))         # Tensor
    print(type(img.eval()))  # numpy.ndarray
    print(img.eval().shape)
    print(img.eval().dtype)
    plt.figure(1)
    plt.imshow(img.eval())
    plt.show()
3. tf.train.shuffle_batch Reference:
https://www.jianshu.com/p/9cfe9cadde06
https://blog.csdn.net/ying86615791/article/details/73864381
img_batch = tf.train.shuffle_batch(
    [img],
    batch_size=batch_size,
    capacity=capacity,
    min_after_dequeue=min_after_dequeue,
    num_threads=num_threads,
    allow_smaller_final_batch=allow_smaller_final_batch)
Unlike PyTorch, TF does not shuffle the whole dataset globally; each dequeue shuffles within a bounded queue, whose maximum length is capacity and whose minimum (kept after every dequeue) is min_after_dequeue.
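To make the call above concrete, here is a hedged sketch that batches the decoded image from method 1 in section 2; the batch size, image size, and queue sizes are assumptions, not SPGAN's values:

import tensorflow as tf

# `image` is the decoded tensor from method 1 above; shuffle_batch needs a
# static shape, so resize to a fixed size first (256x256 is an assumption).
image_resized = tf.image.resize_images(image, [256, 256])

# Rule of thumb: capacity = min_after_dequeue + 3 * batch_size.
img_batch = tf.train.shuffle_batch(
    [image_resized],
    batch_size=16,
    capacity=2000 + 3 * 16,
    min_after_dequeue=2000,
    num_threads=4)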
4. tf.summary Reference: https://blog.csdn.net/hongxue8888/article/details/78610305
summary_writer = tf.summary.FileWriter('./summaries/' + dataset + '_spgan', sess.graph)
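FileWriter alone only writes the graph; to log curves you also need summary ops. A minimal self-contained sketch (the toy loss and log directory are made up, not SPGAN's):

import tensorflow as tf

# Toy graph: minimize a quadratic so there is a curve to plot.
x = tf.Variable(5.0)
loss = tf.square(x - 1.0)
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

tf.summary.scalar('loss', loss)   # register a scalar summary
merged = tf.summary.merge_all()   # single op that evaluates every summary

with tf.Session() as sess:
    writer = tf.summary.FileWriter('./summaries/demo', sess.graph)
    sess.run(tf.global_variables_initializer())
    for step in range(100):
        summary, _ = sess.run([merged, train_op])
        writer.add_summary(summary, step)  # view with: tensorboard --logdir ./summaries
    writer.close()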
5. tf.train.Saver Reference: http://www.cnblogs.com/denny402/p/6940134.html
saver = tf.train.Saver(max_to_keep=30)
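max_to_keep=30 means only the 30 most recent checkpoints are kept on disk. A minimal sketch of periodic saving (the path and save interval are assumptions):

import os
import tensorflow as tf

os.makedirs('./checkpoints', exist_ok=True)

step_var = tf.Variable(0, name='global_step', trainable=False)
inc = tf.assign_add(step_var, 1)

saver = tf.train.Saver(max_to_keep=30)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1, 101):
        sess.run(inc)
        if step % 10 == 0:
            # Produces ./checkpoints/model-10, model-20, ...; older ones are pruned past 30.
            saver.save(sess, './checkpoints/model', global_step=step)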
6. saver.restore Reference: https://blog.csdn.net/changeforeve/article/details/80268522
def load_checkpoint(checkpoint_dir, sess, saver):
    print(" [*] Loading checkpoint...")
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    print(ckpt)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        ckpt_path = os.path.join(checkpoint_dir, ckpt_name)
        saver.restore(sess, ckpt_path)
        print(" [*] Loading successful!")
        return ckpt_path
    else:
        print(" [*] No suitable checkpoint!")
        return None
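A hypothetical call site for the helper above (the checkpoint directory name is a made-up example, not SPGAN's):

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    load_checkpoint('./checkpoints/market_spgan', sess, saver)  # assumed directory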
7. tf.train.Coordinator Reference:
https://blog.csdn.net/weixin_42052460/article/details/80714539
https://www.jianshu.com/p/d063804fb272
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
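The usual full pattern: the coordinator lets the main thread ask the queue-runner threads to stop and then waits for them. A minimal sketch (the two filenames are placeholders):

import tensorflow as tf

queue = tf.train.string_input_producer(['a.jpg', 'b.jpg'], num_epochs=1)
fname = queue.dequeue()

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # num_epochs uses a local variable
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            print(sess.run(fname))
    except tf.errors.OutOfRangeError:
        print('queue exhausted')       # raised after the last epoch
    finally:
        coord.request_stop()           # signal all threads to stop
        coord.join(threads)            # wait until they exit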
8. tf.identity Reference:
https://stackoverflow.com/questions/34877523/in-tensorflow-what-is-tf-identity-used-for
https://blog.csdn.net/hu_guan_jie/article/details/78495297
tf.identity is logically just assignment (=); the difference is that it adds a node to the computation graph, which is what allows, for example, a value to be copied between devices. Why isn't plain = enough? A Python assignment only binds a new name to the same tensor and creates no op, so it cannot carry control dependencies or trigger a cross-device copy; tf.identity creates a real op that can.
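The classic demonstration from the StackOverflow answer linked above: inside a control_dependencies block, plain assignment ignores the dependency because it creates no node, while tf.identity respects it:

import tensorflow as tf

x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1)

with tf.control_dependencies([x_plus_1]):
    y = x               # `=` adds no node: y IS x, so the dependency is ignored
    z = tf.identity(x)  # identity adds a node here, so the dependency fires

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print([sess.run(y) for _ in range(3)])  # [0.0, 0.0, 0.0]
    print([sess.run(z) for _ in range(3)])  # [1.0, 2.0, 3.0]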
9. tf.reuse (variable sharing) Reference: https://blog.csdn.net/UESTC_C2_403/article/details/72329786
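This refers to variable reuse via tf.variable_scope(..., reuse=True). A minimal sketch of sharing one weight matrix between two inputs (the scope name and shapes are made up):

import tensorflow as tf

def linear(x):
    # get_variable either creates 'w' or, under reuse, returns the existing one.
    w = tf.get_variable('w', shape=[2, 2], initializer=tf.ones_initializer())
    return tf.matmul(x, w)

x1 = tf.placeholder(tf.float32, [None, 2])
x2 = tf.placeholder(tf.float32, [None, 2])

with tf.variable_scope('layer'):
    y1 = linear(x1)                      # creates layer/w
with tf.variable_scope('layer', reuse=True):
    y2 = linear(x2)                      # reuses layer/w instead of raising an error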