styleGAN—-pretrained-example.py
此代码使用预训练的styleGAN模型karras2019stylegan-ffhq-1024x1024.pkl(百度网盘,密码1q60)生成器生成单张随机人脸图像。(如有需要请下载源代码
源码
1 | import os |
微调预训练模型路径后的代码
由于代码中所使用的预训练模型存储在GoogleDrive上可能无法通过代码下载,因此我们提前下载下来并从本地读取。这里修改代码为:
1 | #这段代码运行在styleGAN根目录下的./ipynb/sytle_test.ipynb中 |
函数逐行分析
tflib.init_tf()
main函数中第一句tflib.init_tf()函数用于初始化tensorflow,此函数位于dnnlib/tflib/tfutil.py中,init_tf()所依赖的其它函数有_sanitize_tf_config和create_session,这两个函数同样位于此文件中。
1 | # 本机环境为tenforlow-gpu1.14.0,运行源代码过程中经常提示版本相关的warning |
以上函数中,def init_tf(config_dict:dict=None)->None定义了名为init_tf的函数:该函数接受名为config_dict的变量,config_dict:dict=None表示config_dict的期望类型为字典以及初始值为None;其中的冒号:是一种Type Annotation,其用于提示变量的类型为dict;->用于注释函数的返回类型为None。
if tf.get_default_session() is not None: return用于获取当前会话(Session),如果会话不为空则退出init_tf,若未创建会话则执行后续代码。
cfg=_sanitize_tf_config(config_dict)调用了_sanitize_tf_config函数,此函数中创建了一个参数字典,我们可以通过以下方式打印其中的值:
1 | for key, value in zip(cfg.keys(), cfg.values()): |
在函数_sanitize_tf_config中,首先创建了名为cfg的字典:该字典中包含上个代码框中的默认key及其value,当该函数的输入变量config_dict不为空时则执行cfg.update(config_dict)将新的字典添加到cfg中,最后返回字典cfg。
np_random_seed默认为None 。当tf_random_seed为‘auto’(默认状态)时,执行tf_random_seed=np.random.randint(1 << 31),即生成一个小于1<<31=2147483648的值并赋给tf_random_seed,若为其它非‘auto’及非None值,则将其作为种子输入到tf.set_random_seed()中。
接下来的代码 key.split(".")以”.“分割了字典中的key并找到名为”env“的项,后续操作等价于os.environ['TF_CPP_MIN_LOG_LEVEL']=’1‘。即,将tf的log输出级别设置为1。
create_session(cfg,force_as_default=True)将创建的配置信息以字典变量方式传入函数create_session并强制设置当前会话为默认会话。
obj=getattr(obj,field)获取了obj的graph_options和gpu_options属性并分别设置其中的place_pruned_graph和allow_growth为True.
Gs.print_layers()
回到主函数中,Gs.print_layers()结果:
1 | G Params OutputShape WeightShape |
np.random.RandomState
代码rnd = np.random.RandomState(seed)基于Mesenne Twister算法生成伪随机数:
1 | >>>for i in range(5): |
latents = rnd.randn(1, Gs.input_shape[1])生成一个(1, 512)的随机向量。
fmt=dict(func=tflib.convert_images_to_uint8,nchw_to_nhwc=True)将函数tflib.convert_images_to_uint8以及bool值nchw_to_nhwc=True打包成字典fmt。
tflib.convert_images_to_uint8
1 | """将数据类型为float32的minibatch图像转换为uint8.以及调整数据格式. |
“NHWC”模式下,批次数据的排列顺序为[batch, height, width, channels](TF的默认设置)
“NCHW”模式下,批次数据的排列顺序为[batch, channels, height, width]
NHWC和NCHW的解读可参照深度学习(4):NCHW和NHWC,后续个人将完善这部分的学习。
Gs.run
images=Gs.run(latents,None,truncation_psi=0.7,randomize_noise=True,output_transform=fmt)调用dnnlib/tflib/network.py中Network类的run。关于network.py的解读将放在其它文章中。
os.makedirs & os.path.join
1 | os.makedirs(config.result_dir, exist_ok=True) |
os.makedirs(config.result_dir, exist_ok=True)会根据config.result_dir提供的路径创建相应的文件夹,exist_ok=True用于当上述文件夹已经存在时屏蔽OSError 。os.path.join(a)函数可以将多个路径分量(父路径以及文件名)通过分隔符'/'连接在一起,本例中将config.result_dir与‘example.png’连接并赋值给png_filename,相当于png_filename="%s/%s"%(config.result_dir,'example.png')。
PIL.Image.fromarray
PIL.Image.fromarray(images[0],'RGB').save(png_filename)将imgages[0]以RGB格式存储在文件png_filename中。