GeophyAI

Python与地球物理数据处理

0%

styleGAN源码解读之pretrained-example.py

styleGAN—-pretrained-example.py

此代码使用预训练的styleGAN模型karras2019stylegan-ffhq-1024x1024.pkl(百度网盘,密码1q60)生成器生成单张随机人脸图像。(如有需要请下载源代码)

源码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import config
import sys

def main():
    """Generate one image with the pre-trained generator and save it as a PNG.

    Steps, in order: initialize TensorFlow through ``tflib.init_tf()``,
    unpickle the ``(_G, _D, Gs)`` triple from the download URL (cached under
    ``config.cache_dir``), draw one latent vector from a ``RandomState``
    seeded with 5, run ``Gs`` on it, and write the first returned image to
    ``<config.result_dir>/example.png``.
    """
    # Initialize TensorFlow.
    tflib.init_tf()
    # Load pre-trained network.
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file -- only use this with a model file from a trusted source.
    url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
    with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
        _G, _D, Gs = pickle.load(f)
        # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
        # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
        # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.

    # Print network details.
    #Gs.print_layers()

    # Pick latent vector.
    # The fixed seed makes the latent vector identical on every run; its
    # length is taken from the generator's declared input shape.
    rnd = np.random.RandomState(5)
    latents = rnd.randn(1, Gs.input_shape[1])
    # Generate image.
    # fmt bundles the output conversion function and its keyword argument;
    # it is handed to Gs.run as output_transform.
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)

    # Save image.
    os.makedirs(config.result_dir, exist_ok=True)
    png_filename = os.path.join(config.result_dir, 'example.png')
    PIL.Image.fromarray(images[0], 'RGB').save(png_filename)

if __name__ == "__main__":
    main()

微调预训练模型路径后的代码

由于代码中所使用的预训练模型存储在GoogleDrive上可能无法通过代码下载,因此我们提前下载下来并从本地读取。这里修改代码为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#这段代码运行在styleGAN根目录下的./ipynb/sytle_test.ipynb中
import os #module 'os' from '/xxx/Anaconda/envs/xxx/lib/python3.6/os.py'
import pickle # module 'pickle' from '/xxx/Anaconda/envs/xxx/lib/python3.x/pickle.py'
import numpy as np #略
import PIL.Image #略
import dnnlib # ./dnnlib
import dnnlib.tflib as tflib # module 'dnnlib.tflib' from '../dnnlib/tflib/__init__.py'
import config # <module 'config' from '../config.py'>
import sys # module 'sys' (built-in)
#输入print(module_name)可查看其文件位置()
#例如
#print(dnnlib)
#<module 'dnnlib' from '../dnnlib/__init__.py'>
def main():
    """Generate one image from a locally stored pre-trained model and save it.

    Same flow as the original example, except that the pickle is opened from
    a local path instead of being downloaded, and the layer table of ``Gs``
    is printed before the image is generated. The result is written to
    ``<config.result_dir>/example.png``.
    """
    # Initialize TensorFlow.
    tflib.init_tf()
    # Load the .pkl file, rebuild the pickled objects and obtain the pre-trained model.
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file -- only use this with a model file from a trusted source.
    url = r'/xxx/xxx/xxx/karras2019stylegan-ffhq-1024x1024.pkl'
    with open(url, 'rb') as f:
        _G, _D, Gs = pickle.load(f)
        # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
        # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
        # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.

    # Print the architecture of Gs, comparable to .summary() in tf; print_layers
    # is a custom function written by the StyleGAN authors (see network.py).
    Gs.print_layers()

    # Draw a random latent vector (fixed seed, so the same vector every run).
    rnd = np.random.RandomState(5)
    latents = rnd.randn(1, Gs.input_shape[1])
    # Generate the image.
    # fmt bundles the output conversion function and its keyword argument;
    # it is handed to Gs.run as output_transform.
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)

    # Save the image.
    os.makedirs(config.result_dir, exist_ok=True)
    png_filename = os.path.join(config.result_dir, 'example.png')
    PIL.Image.fromarray(images[0], 'RGB').save(png_filename)

if __name__ == "__main__":
    main()

函数逐行分析

tflib.init_tf()

main函数中第一句tflib.init_tf()函数用于初始化tensorflow,此函数位于dnnlib/tflib/tfutil.py中,init_tf()所依赖的其它函数有_sanitize_tf_config和create_session,这两个函数同样位于此文件中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# 本机环境为tensorflow-gpu 1.14.0,运行源代码过程中经常提示版本相关的warning
# 所以这里使用tensorflow.compat.v1来运行
import tensorflow.compat.v1 as tf
#--- 初始化TF
def init_tf(config_dict: dict = None) -> None:
    """Initialize TensorFlow session using good default settings.

    Does nothing when a default session already exists. Otherwise it builds
    the configuration with ``_sanitize_tf_config``, seeds NumPy and
    TensorFlow, exports the ``env.*`` entries as environment variables, and
    creates a session that is installed as the default one.

    Args:
        config_dict: Optional overrides applied on top of the defaults from
            ``_sanitize_tf_config``.
    """
    # Skip if a default session has already been created.
    if tf.get_default_session() is not None:
        return

    # Build the config dict via _sanitize_tf_config, then set the random seeds.
    cfg = _sanitize_tf_config(config_dict)
    np_random_seed = cfg["rnd.np_random_seed"]
    if np_random_seed is not None:
        np.random.seed(np_random_seed)
    tf_random_seed = cfg["rnd.tf_random_seed"]
    if tf_random_seed == "auto":
        # Derive the TF seed from NumPy's global random state.
        tf_random_seed = np.random.randint(1 << 31)
    if tf_random_seed is not None:
        tf.set_random_seed(tf_random_seed)

    # Set TF environment variables: every "env.<NAME>" key is exported as <NAME>.
    for key, value in list(cfg.items()):
        fields = key.split(".")
        if fields[0] == "env":
            assert len(fields) == 2
            os.environ[fields[1]] = str(value)

    # Create the TF session via create_session and make it the default session.
    create_session(cfg, force_as_default=True)

#--- 设置配置信息
def _sanitize_tf_config(config_dict: dict = None) -> dict:
    """Return the TF configuration dict: defaults updated with ``config_dict``.

    Keys are dotted names. ``init_tf`` reads the ``rnd.*`` keys as seeds and
    exports ``env.*`` keys as environment variables; ``create_session``
    applies every other key as an attribute path on a ``tf.ConfigProto``.

    Args:
        config_dict: Optional overrides; entries replace defaults that have
            the same key and add any new keys.

    Returns:
        A new dict containing the defaults merged with the overrides.
    """
    # Defaults.
    cfg = dict()
    cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
    cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
    cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
    cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
    cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.

    # User overrides.
    if config_dict is not None:
        cfg.update(config_dict)
    return cfg

#--- 根据配置信息创建会话
def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
    """Create a ``tf.Session`` configured from ``config_dict``.

    Args:
        config_dict: Optional overrides, merged over the defaults by
            ``_sanitize_tf_config``. Keys whose first dotted field is
            ``rnd`` or ``env`` are ignored here; every other key is treated
            as an attribute path on ``tf.ConfigProto``.
        force_as_default: When true, the new session is also entered as the
            default session before it is returned.

    Returns:
        The newly created session.
    """
    # Set the TF config attributes.
    cfg = _sanitize_tf_config(config_dict)
    config_proto = tf.ConfigProto()
    for key, value in cfg.items():
        fields = key.split(".")
        if fields[0] not in ["rnd", "env"]:
            # Walk the dotted path down to the parent object, then set the
            # last field on it (e.g. config_proto.gpu_options.allow_growth).
            obj = config_proto
            for field in fields[:-1]:
                obj = getattr(obj, field)
            setattr(obj, fields[-1], value)
    # Create the session.
    session = tf.Session(config=config_proto)
    if force_as_default:
        # pylint: disable=protected-access
        # Make this session the default one: enter its as_default() context
        # manually; the matching __exit__ is not called in this function.
        # NOTE(review): enforce_nesting = False presumably relaxes TF's check
        # on the order in which default-session contexts exit -- confirm.
        session._default_session = session.as_default()
        session._default_session.enforce_nesting = False
        session._default_session.__enter__() # pylint: disable=no-member

    return session

以上函数中,def init_tf(config_dict:dict=None)->None定义了名为init_tf的函数:该函数接受名为config_dict的变量,config_dict:dict=None表示config_dict的期望类型为字典以及初始值为None;其中的冒号:是一种Type Annotation,用于提示变量的类型为dict;->用于注释函数的返回类型为None。

if tf.get_default_session() is not None: return用于获取当前会话(Session),如果会话不为空则退出init_tf,若未创建会话则执行后续代码。

cfg=_sanitize_tf_config(config_dict)调用了_sanitize_tf_config函数,此函数中创建了一个参数字典,我们可以通过以下方式打印其中的值:

1
2
3
4
5
6
7
8
# Print every configuration entry as "key : value".
# dict.items() yields the (key, value) pairs directly, so zipping keys()
# with values() is unnecessary.
for key, value in cfg.items():
    print(key, ':', value)
###--------结果如下--------###
rnd.np_random_seed : None
rnd.tf_random_seed : auto
env.TF_CPP_MIN_LOG_LEVEL : 1
graph_options.place_pruned_graph : True
gpu_options.allow_growth : True

在函数_sanitize_tf_config中,首先创建了名为cfg的字典:该字典中包含上个代码框中的默认key及其value,当该函数的输入变量config_dict不为空时则执行cfg.update(config_dict)将新的字典添加到cfg中,最后返回字典cfg

np_random_seed默认为None。当tf_random_seed为‘auto’(默认状态)时,执行tf_random_seed=np.random.randint(1 << 31),即生成一个位于[0, 1<<31)区间内(小于2147483648)的随机整数并赋给tf_random_seed;随后只要tf_random_seed不为None(无论是由‘auto’生成的值还是用户指定的值),就将其作为种子输入到tf.set_random_seed()中。

接下来的代码 key.split(".")以”.“分割了字典中的key并找到名为”env“的项,后续操作等价于os.environ['TF_CPP_MIN_LOG_LEVEL']=’1‘。即,将tflog输出级别设置为1。

create_session(cfg,force_as_default=True)将创建的配置信息以字典变量方式传入函数create_session并强制设置当前会话为默认会话。

obj=getattr(obj,field)获取了obj(即config_proto)的graph_options和gpu_options属性,并分别将其中的place_pruned_graph和allow_growth设置为True。

Gs.print_layers()

回到主函数中,Gs.print_layers()结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
G                               Params    OutputShape          WeightShape     
--- --- --- ---
latents_in - (?, 512) -
labels_in - (?, 0) -
lod - () -
dlatent_avg - (512,) -
G_mapping/latents_in - (?, 512) -
G_mapping/labels_in - (?, 0) -
G_mapping/PixelNorm - (?, 512) -
G_mapping/Dense0 262656 (?, 512) (512, 512)
G_mapping/Dense1 262656 (?, 512) (512, 512)
G_mapping/Dense2 262656 (?, 512) (512, 512)
G_mapping/Dense3 262656 (?, 512) (512, 512)
G_mapping/Dense4 262656 (?, 512) (512, 512)
G_mapping/Dense5 262656 (?, 512) (512, 512)
G_mapping/Dense6 262656 (?, 512) (512, 512)
G_mapping/Dense7 262656 (?, 512) (512, 512)
G_mapping/Broadcast - (?, 18, 512) -
G_mapping/dlatents_out - (?, 18, 512) -
Truncation - (?, 18, 512) -
G_synthesis/dlatents_in - (?, 18, 512) -
G_synthesis/4x4/Const 534528 (?, 512, 4, 4) (512,)
G_synthesis/4x4/Conv 2885632 (?, 512, 4, 4) (3, 3, 512, 512)
G_synthesis/ToRGB_lod8 1539 (?, 3, 4, 4) (1, 1, 512, 3)
G_synthesis/8x8/Conv0_up 2885632 (?, 512, 8, 8) (3, 3, 512, 512)
G_synthesis/8x8/Conv1 2885632 (?, 512, 8, 8) (3, 3, 512, 512)
G_synthesis/ToRGB_lod7 1539 (?, 3, 8, 8) (1, 1, 512, 3)
G_synthesis/Upscale2D - (?, 3, 8, 8) -
G_synthesis/Grow_lod7 - (?, 3, 8, 8) -
G_synthesis/16x16/Conv0_up 2885632 (?, 512, 16, 16) (3, 3, 512, 512)
G_synthesis/16x16/Conv1 2885632 (?, 512, 16, 16) (3, 3, 512, 512)
G_synthesis/ToRGB_lod6 1539 (?, 3, 16, 16) (1, 1, 512, 3)
G_synthesis/Upscale2D_1 - (?, 3, 16, 16) -
G_synthesis/Grow_lod6 - (?, 3, 16, 16) -
G_synthesis/32x32/Conv0_up 2885632 (?, 512, 32, 32) (3, 3, 512, 512)
G_synthesis/32x32/Conv1 2885632 (?, 512, 32, 32) (3, 3, 512, 512)
G_synthesis/ToRGB_lod5 1539 (?, 3, 32, 32) (1, 1, 512, 3)
G_synthesis/Upscale2D_2 - (?, 3, 32, 32) -
G_synthesis/Grow_lod5 - (?, 3, 32, 32) -
G_synthesis/64x64/Conv0_up 1442816 (?, 256, 64, 64) (3, 3, 512, 256)
G_synthesis/64x64/Conv1 852992 (?, 256, 64, 64) (3, 3, 256, 256)
G_synthesis/ToRGB_lod4 771 (?, 3, 64, 64) (1, 1, 256, 3)
G_synthesis/Upscale2D_3 - (?, 3, 64, 64) -
G_synthesis/Grow_lod4 - (?, 3, 64, 64) -
G_synthesis/128x128/Conv0_up 426496 (?, 128, 128, 128) (3, 3, 256, 128)
G_synthesis/128x128/Conv1 279040 (?, 128, 128, 128) (3, 3, 128, 128)
G_synthesis/ToRGB_lod3 387 (?, 3, 128, 128) (1, 1, 128, 3)
G_synthesis/Upscale2D_4 - (?, 3, 128, 128) -
G_synthesis/Grow_lod3 - (?, 3, 128, 128) -
G_synthesis/256x256/Conv0_up 139520 (?, 64, 256, 256) (3, 3, 128, 64)
G_synthesis/256x256/Conv1 102656 (?, 64, 256, 256) (3, 3, 64, 64)
G_synthesis/ToRGB_lod2 195 (?, 3, 256, 256) (1, 1, 64, 3)
G_synthesis/Upscale2D_5 - (?, 3, 256, 256) -
G_synthesis/Grow_lod2 - (?, 3, 256, 256) -
G_synthesis/512x512/Conv0_up 51328 (?, 32, 512, 512) (3, 3, 64, 32)
G_synthesis/512x512/Conv1 42112 (?, 32, 512, 512) (3, 3, 32, 32)
G_synthesis/ToRGB_lod1 99 (?, 3, 512, 512) (1, 1, 32, 3)
G_synthesis/Upscale2D_6 - (?, 3, 512, 512) -
G_synthesis/Grow_lod1 - (?, 3, 512, 512) -
G_synthesis/1024x1024/Conv0_up 21056 (?, 16, 1024, 1024) (3, 3, 32, 16)
G_synthesis/1024x1024/Conv1 18752 (?, 16, 1024, 1024) (3, 3, 16, 16)
G_synthesis/ToRGB_lod0 51 (?, 3, 1024, 1024) (1, 1, 16, 3)
G_synthesis/Upscale2D_7 - (?, 3, 1024, 1024) -
G_synthesis/Grow_lod0 - (?, 3, 1024, 1024) -
G_synthesis/images_out - (?, 3, 1024, 1024) -
G_synthesis/lod - () -
G_synthesis/noise0 - (1, 1, 4, 4) -
G_synthesis/noise1 - (1, 1, 4, 4) -
G_synthesis/noise2 - (1, 1, 8, 8) -
G_synthesis/noise3 - (1, 1, 8, 8) -
G_synthesis/noise4 - (1, 1, 16, 16) -
G_synthesis/noise5 - (1, 1, 16, 16) -
G_synthesis/noise6 - (1, 1, 32, 32) -
G_synthesis/noise7 - (1, 1, 32, 32) -
G_synthesis/noise8 - (1, 1, 64, 64) -
G_synthesis/noise9 - (1, 1, 64, 64) -
G_synthesis/noise10 - (1, 1, 128, 128) -
G_synthesis/noise11 - (1, 1, 128, 128) -
G_synthesis/noise12 - (1, 1, 256, 256) -
G_synthesis/noise13 - (1, 1, 256, 256) -
G_synthesis/noise14 - (1, 1, 512, 512) -
G_synthesis/noise15 - (1, 1, 512, 512) -
G_synthesis/noise16 - (1, 1, 1024, 1024) -
G_synthesis/noise17 - (1, 1, 1024, 1024) -
images_out - (?, 3, 1024, 1024) -
--- --- --- ---
Total 26219627

np.random.RandomState

代码rnd = np.random.RandomState(seed)基于Mersenne Twister算法生成伪随机数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
>>>for i in range(5):
rnd = np.random.RandomState(None)
print(rnd.randint(1, 10, 5))
[output]::
[7 6 7 5 6]
[7 4 5 3 9]
[1 7 9 3 6]
[5 3 3 1 3]
[9 3 5 2 9]
>>>for i in range(5):
rnd = np.random.RandomState(1)
print(rnd.randint(1, 10, 5))
[output]::
[6 9 6 1 1]
[6 9 6 1 1]
[6 9 6 1 1]
[6 9 6 1 1]
[6 9 6 1 1]
>>>rnd = np.random.RandomState(1)
for i in range(5):
print(rnd.randint(1, 10, 5))
[output]::
[6 9 6 1 1]
[2 8 7 3 5]
[6 3 5 3 5]
[8 8 2 8 1]
[7 8 7 2 1]

latents = rnd.randn(1, Gs.input_shape[1])生成一个(1, 512)的随机向量。
fmt=dict(func=tflib.convert_images_to_uint8,nchw_to_nhwc=True)将函数tflib.convert_images_to_uint8以及bool值nchw_to_nhwc=True打包成字典fmt。

tflib.convert_images_to_uint8

1
2
3
4
5
6
7
8
9
10
11
12
def convert_images_to_uint8(images, drange=(-1, 1), nchw_to_nhwc=False, shrink=1):
    """Convert a minibatch of float32 images to uint8, optionally changing layout.

    Can be used as an output transformation for Network.run().

    Args:
        images: Image tensor. It is treated as NCHW whenever ``shrink > 1``
            or ``nchw_to_nhwc`` is set.
        drange: Two-element ``(low, high)`` value range of the input; ``low``
            maps to 0 and ``high`` to 255. A tuple is used as the default so
            that the default object cannot be mutated between calls.
        nchw_to_nhwc: If true, transpose the axes with ``[0, 2, 3, 1]``
            before scaling.
        shrink: If greater than 1, average-pool the two spatial axes by this
            factor first.

    Returns:
        The converted tensor, saturate-cast to ``tf.uint8``.
    """
    images = tf.cast(images, tf.float32)
    if shrink > 1:
        # Same window and stride on both spatial axes: downscale by `shrink`.
        ksize = [1, 1, shrink, shrink]
        images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
    if nchw_to_nhwc:
        images = tf.transpose(images, [0, 2, 3, 1])
    scale = 255 / (drange[1] - drange[0])  # 127.5 for the default drange
    # Map drange linearly onto [0, 255].
    # NOTE(review): the extra 0.5 presumably makes the cast round to nearest
    # instead of truncating -- confirm against tf.saturate_cast.
    images = images * scale + (0.5 - drange[0] * scale)
    return tf.saturate_cast(images, tf.uint8)

“NHWC”模式下,批次数据的排列顺序为[batch, height, width, channels](TF的默认设置)
“NCHW”模式下,批次数据的排列顺序为[batch, channels, height, width]
NHWC和NCHW的解读可参照《深度学习(4):NCHW和NHWC》,后续个人将完善这部分的学习。

Gs.run

images=Gs.run(latents,None,truncation_psi=0.7,randomize_noise=True,output_transform=fmt)调用dnnlib/tflib/network.py中Network类的run方法。关于network.py的解读将放在其它文章中。

os.makedirs & os.path.join

1
2
os.makedirs(config.result_dir, exist_ok=True)
png_filename = os.path.join(config.result_dir,'example.png')

os.makedirs(config.result_dir, exist_ok=True)会根据config.result_dir提供的路径创建相应的文件夹,exist_ok=True用于当上述文件夹已经存在时屏蔽OSError
os.path.join()函数可以将多个路径分量(父路径以及文件名)通过系统路径分隔符(Linux/macOS下为'/')连接在一起,本例中将config.result_dir与‘example.png’连接并赋值给png_filename,在Linux下相当于png_filename="%s/%s"%(config.result_dir,'example.png')

PIL.Image.fromarray

PIL.Image.fromarray(images[0],'RGB').save(png_filename)将images[0]以RGB格式存储在文件png_filename中。