0%

SPGAN-tensorflow

在阅读SPGAN代码源码的过程中,学习到的关于tensorflow的一些知识。

1. tf.ConfigProto

参考链接:https://blog.csdn.net/dcrmg/article/details/79091941

tf.ConfigProto用于对sessison会话的参数配置。

  • log_device_placement=True: 可以获取到 operations 和 Tensor 被指派到哪个设备(几号CPU或几号GPU)上运行,会在终端打印出各项操作是在哪个设备上运行的
  • allow_soft_placement=True: 允许tf自动选择一个存在并且可用的设备来运行操作.在tf中,通过命令 “with tf.device(‘/cpu:0’):”,允许手动设置操作运行的设备
  • config.gpu_options.allow_growth = True: 动态申请显存
  • config.gpu_options.per_process_gpu_memory_fraction = 0.4: 占用40%显存,限制GPU使用率.
1
2
3
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
1
2
3
4
5
6
7
8
# 限制GPU使用率
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.4 #占用40%显存
session = tf.Session(config=config)
等同于
gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.4)
config=tf.ConfigProto(gpu_options=gpu_options)
session = tf.Session(config=config)
1
2
3
4
5
6
# 设置使用哪块GPU
方法一: 在python中设置
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #使用 GPU 0
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' # 使用 GPU 0,1
方法二: 在执行时设置
CUDA_VISIBLE_DEVICES=0,1 python yourcode.py

2. tf读取数据集图片的方式

参考链接:

https://www.jb51.net/article/134550.htm

https://www.jb51.net/article/134547.htm

tf的流程是文件系统—>文件名队列—>内存队列

推荐使用方法一

方法一:使用WholeFileReader输入queue,decode输出是Tensor,eval后是ndarray

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import tensorflow as tf
import os
import matplotlib.pyplot as plt
def file_name(file_dir): #来自//www.jb51.net/article/134543.htm
for root, dirs, files in os.walk(file_dir): #模块os中的walk()函数遍历文件夹下所有的文件
print(root) #当前目录路径
print(dirs) #当前路径下所有子目录
print(files) #当前路径下所有非目录子文件

def file_name2(file_dir): #特定类型的文件
L=[]
for root, dirs, files in os.walk(file_dir):
for file in files:
if os.path.splitext(file)[1] == '.jpg':
L.append(os.path.join(root, file))
return L

path = file_name2('test')

#以下参考//www.jb51.net/article/134547.htm (十图详解TensorFlow数据读取机制)
#path2 = tf.train.match_filenames_once(path)
file_queue = tf.train.string_input_producer(paths, shuffle=True, num_epochs=2) #创建输入队列
image_reader = tf.WholeFileReader()
key, image = image_reader.read(file_queue)
image = tf.image.decode_jpeg(image, channerls=3)

with tf.Session() as sess:
# coord = tf.train.Coordinator() #协同启动的线程
# threads = tf.train.start_queue_runners(sess=sess, coord=coord) #启动线程运行队列
# coord.request_stop() #停止所有的线程
# coord.join(threads)

tf.local_variables_initializer().run()
threads = tf.train.start_queue_runners(sess=sess)

#print (type(image))
#print (type(image.eval()))
#print(image.eval().shape)
for _ in path+path:
plt.figure
plt.imshow(image.eval())
plt.show()

方法二:使用gfile读图片,decode输出是Tensor,eval后是ndarray

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

print(tf.__version__)

image_raw = tf.gfile.FastGFile('test/a.jpg','rb').read()  #bytes
img = tf.image.decode_jpeg(image_raw) #Tensor
#img2 = tf.image.convert_image_dtype(img, dtype = tf.uint8)

with tf.Session() as sess:
  print(type(image_raw)) # bytes
  print(type(img)) # Tensor
  #print(type(img2))

  print(type(img.eval())) # ndarray !!!
  print(img.eval().shape)
  print(img.eval().dtype)

#  print(type(img2.eval()))
#  print(img2.eval().shape)
#  print(img2.eval().dtype)
  plt.figure(1)
  plt.imshow(img.eval())
  plt.show()

方法三:使用read_file,decode输出是Tensor,eval后是ndarray

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np

print(tf.__version__)

image_raw = tf.gfile.FastGFile('test/a.jpg','rb').read()  #bytes
img = tf.image.decode_jpeg(image_raw) #Tensor
#img2 = tf.image.convert_image_dtype(img, dtype = tf.uint8)

with tf.Session() as sess:
  print(type(image_raw)) # bytes
  print(type(img)) # Tensor
  #print(type(img2))

  print(type(img.eval())) # ndarray !!!
  print(img.eval().shape)
  print(img.eval().dtype)

#  print(type(img2.eval()))
#  print(img2.eval().shape)
#  print(img2.eval().dtype)
  plt.figure(1)
  plt.imshow(img.eval())
  plt.show()

3. tf.train.shuffle_batch

参考链接:

https://www.jianshu.com/p/9cfe9cadde06

https://blog.csdn.net/ying86615791/article/details/73864381

1
2
3
4
5

img_batch = tf.train.shuffle_batch([img],
batch_size=batch_size, capacity=capacity,
min_after_dequeue=min_after_dequeue,num_threads=num_threads,
allow_smaller_final_batch=allow_smaller_final_batch)

tf不是像pytorch一样全局打乱,而是每一次在较短的队列中打乱。其中,队列的最长长度是capacity,最短长度是min_after_dequeue。

4. tf.summary

参考链接:https://blog.csdn.net/hongxue8888/article/details/78610305

1
summary_writer = tf.summary.FileWriter('./summaries/' + dataset + '_spgan' , sess.graph)

5. tf.train.Saver

参考链接:http://www.cnblogs.com/denny402/p/6940134.html

1
saver = tf.train.Saver(max_to_keep= 30)

6. saver.restore

参考链接:https://blog.csdn.net/changeforeve/article/details/80268522

1
2
3
4
5
6
7
8
9
10
11
12
13
def load_checkpoint(checkpoint_dir, sess, saver):
print(" [*] Loading checkpoint...")
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
print(ckpt)
if ckpt and ckpt.model_checkpoint_path:
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
ckpt_path = os.path.join(checkpoint_dir, ckpt_name)
saver.restore(sess, ckpt_path)
print(" [*] Loading successful!")
return ckpt_path
else:
print(" [*] No suitable checkpoint!")
return None

7. tf.train.Coordinator

参考链接:

https://blog.csdn.net/weixin_42052460/article/details/80714539

https://www.jianshu.com/p/d063804fb272

1
2
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

8. tf.identity

参考链接:

https://stackoverflow.com/questions/34877523/in-tensorflow-what-is-tf-identity-used-for

https://blog.csdn.net/hu_guan_jie/article/details/78495297

tf.idenity的逻辑就是等于号,区别是前者在计算图上加了个节点,使得可以多个设备之间可以通信,但是等于号为什么不行呢?

9. tf.reuse

参考链接:https://blog.csdn.net/UESTC_C2_403/article/details/72329786