GeophyAI

Python与地球物理数据处理

0%

styleGAN源码解读之network.py(一)

styleGAN—-network.py

源码

源码文件位于dnnlib/tflib中过长,将近600行,这里只放出Network类中部分函数,当存在函数调用关系时再给出相应的源码及其解读。(在代码的复制粘贴过程中源代码格式遭到破坏(排版出现问题),请勿直接复制以下代码,如有需要请下载源代码。另外,文中对run方法的测试均是通过调用pretrained.py后得到的结果,在运行不同配置train.py时得到的结果可能会略有差异,一般体现在minibatch大小上。)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
class Network:
def run(self,
*in_arrays: Tuple[Union[np.ndarray, None], ...],
input_transform: dict = None,
output_transform: dict = None,
return_as_list: bool = False,
print_progress: bool = False,
minibatch_size: int = None,
num_gpus: int = 1,
assume_frozen: bool = False,
**dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:

assert len(in_arrays) == self.num_inputs
assert not all(arr is None for arr in in_arrays)
assert input_transform is None or util.is_top_level_function(input_transform["func"])
assert output_transform is None or util.is_top_level_function(output_transform["func"])
output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
num_items = in_arrays[0].shape[0]
if minibatch_size is None:
minibatch_size = num_items

# Construct unique hash key from all arguments that affect the TensorFlow graph.
key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
def unwind_key(obj):
if isinstance(obj, dict):
return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
if callable(obj):
return util.get_top_level_function_name(obj)
return obj
key = repr(unwind_key(key))

# Build graph.
if key not in self._run_cache:
with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
with tf.device("/cpu:0"):
in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))

out_split = []
for gpu in range(num_gpus):
with tf.device("/gpu:%d" % gpu):
net_gpu = self.clone() if assume_frozen else self
in_gpu = in_split[gpu]

if input_transform is not None:
in_kwargs = dict(input_transform)
in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)

assert len(in_gpu) == self.num_inputs
out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)

if output_transform is not None:
out_kwargs = dict(output_transform)
out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)

assert len(out_gpu) == self.num_outputs
out_split.append(out_gpu)

with tf.device("/cpu:0"):
out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
self._run_cache[key] = in_expr, out_expr

# Run minibatches.

# Done.

return out_arrays

代码微调

如果你的机器环境为单节点单卡,请把以上函数中注释#Build graph以下的所有 with tf.device()设置为你的GPU,例如with tf.device("/gpu:0")(若只有一张显卡,其对应的设备号为0),以下代码可以查看当前环境下可以使用的CPU/GPU设备号:

1
2
>>> from tensorflow.python.client import device_lib
>>> print(device_lib.list_local_devices())

代码逐行分析

assert

1
2
3
4
assert len(in_arrays) == self.num_inputs
assert not all(arr is None for arr in in_arrays)
assert input_transform is None or util.is_top_level_function(input_transform["func"])
assert output_transform is None or util.is_top_level_function(output_transform["func"])

我们首先看一下assert以及allnot all的用法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# exampe1--assert
>>> assert 0,'请检查输入'
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AssertionError: 请检查输入
>>> assert 1,'pass' # 无输出
>>> assert True or False # 无输出
# example2--all, not all
"""
Signature: all(iterable, /)
Docstring:
若可迭代变量iterable中所有的值均为True,则all(iterable)返回True。
若iterable为空,也返回True。
"""
>>> all([0, 1, 2, 3])
False
>>> all([1, 1, 2, 3])
True
>>> not all([0, 1, 2, 3])
True

这里,第一行的assert len(in_arrays) == self.num_inputs首先判断输入是否与期望输入长度相等,in_arraystuplelen(in_arrays)=self.num_inputs=2in_arrays[0] 是一个size(1,512)np.arrayn_array[1]Nonetype
第二行的assert not all(arr is None for arr in in_arrays)将对元组in_arrays中的元素进行判断,若所有元素均为Nonetype则抛出AssertionError
第三和第四行分别对字典input_transform以及output_transform进行判断,若其为空或者并没有加载到内存中,则抛出AssertionError

is_top_level_function

or之后的函数原型为:

1
2
def is_top_level_function(obj: Any) -> bool:
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__

is_top_level_function用来判断obj是否为顶层函数,即是否是由def所定义,然后确认该函数是否已经读取到内存中。sys.modules是一个字典,加载到内存的模块将会以字典的形式存储在其中(第一次导入时自动记录所导入的模块为字典,第二导入则直接从字典中取出相应的键值)。(callable的应用可以查看链接, 它用于检查一个对象是否是可调用的)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

>>> import sys
>>> class Father(object):
... "sdfsfdfs"
... def __init__(self, name):
... self.name=name
... print ( "name: %s" %( self.name) )
... def getName(self):
... return 'Father ' + self.name
...
>>> sys.modules[Father.__module__].__dict__
{'__name__': '__main__',
'__doc__': None,
'__package__': None,
'__loader__': <class '_frozen_importlib.BuiltinImporter'>,
'__spec__': None, '__annotations__': {},
'__builtins__': <module 'builtins' (built-in)>,
'sys': <module 'sys' (built-in)>,
'Father': <class '__main__.Father'>}

_handle_legacy_output_transforms

**dynamic_kwargs 所接收的动态参数将会被存储为字典,所以output_transformdynamic_kwargs是两个字典,在运行pretrained.py时,二者的keyvalue分别为:

1
2
3
4
>>> print(output_transform)
{'func': <function convert_images_to_uint8 at 0x7f7dd9468378>, 'nchw_to_nhwc': True}
>>> print(dynamic_kwargs)
{'truncation_psi': 0.7, 'randomize_noise': True}

即分别为在Gs.run中输入的(truncation_psi=0.7, randomize_noise=True, output_transform=fmt)

*args 和 **kwargs

*args 和 **kwargs常用于获取函数的额外参数,通过以下的小例子我们可以清楚看出二者的区别:

1
2
3
4
5
6
7
8
9
10
def args_and_kwargs(*args, **kwargs):
print(type(args))
print(args)
print(type(kwargs))
print(kwargs)
>>> args_and_kwargs(6, 8, 'im args', num1 = 6, num2 = 8, strnum = 'im kwargs')
<class 'tuple'>
(6, 8, 'im args')
<class 'dict'>
{'num1': 6, 'num2': 8, 'strnum': 'im kwargs'}

_handle_legacy_output_transforms

该函数对output_transform, dynamic_kwargs两字典中的值进行判断,这里我们只看运行pretrained.py时该函数的内部调用情况:

1
2
3
4
5
6
_print_legacy_warning = True
def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
global _print_legacy_warning
legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
return output_transform, dynamic_kwargs

首先,any的用法如下:

1
2
3
4
5
6
7
8
9
10
11
"""
Signature: any(iterable, /)
Docstring:
若可迭代变量iterable中的任一元素为True,则返回True。
若iterable为空,则返回False
"""
>>> test = [True, False, False, False]
>>> any(test)
True
>>> not any(test)
False

所以当 dynamic_kwargslegacy_kwargs交集为空时,则原样返回。在刚开始看到类似kwarg in dynamic_kwargs for kwarg in legacy_kwargs的代码时可能会感到困惑,我们举一个简单的例子:

1
2
3
4
5
6
7
8
9
10
11
12
>>> a = [1, 2, 3, 4]
>>> b = [3, 4 ,5 ,6]
>>> c = [5, 6, 7, 8]
>>> d = [a in b for a in c]
[True, True, False, False]
# d等价于下列算法中的k
k=[]
for a in c:
if a in b:
k.append(True)
else:
k.append(False)

接下来的num_items = in_arrays[0].shape[0]是对获取输入latent矢量的个数(每个长度为512)。

key

input_transformoutput_transformnum_gpusassume_frozendynamic_kwargs包装为字典,其中第一、二、五位的键值为字典,即字典中的字典。

1
2
3
4
5
6
def unwind_key(obj):
if isinstance(obj, dict):
return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
if callable(obj):
return util.get_top_level_function_name(obj)
return obj

接着,作者定了unwind_key用于“解压”字典或者函数module,这里我们先介绍第一种情况,即obj为字典:

1
2
3
4
5
6
7
8
9
test = {'k1': 11,'k2': 22,
'k3': {'k1ink3': 'value13','k2ink3': 'value23',}}
# 这里我们定义一个test字典,其中k1、k2和k3位于一级字典中,k1ink3和k2ink3位于k3的二级字典中
>>> print(test)
{'k1': 11, 'k2': 22, 'k3': {'k1ink3': 'value13', 'k2ink3': 'value23'}}
>>> print(unwind_key(test))
[('k1', 11), ('k2', 22), ('k3', [('k1ink3', 'value13'), ('k2ink3', 'value23')])]
>>> repr(unwind_key(test))
[('k1', 11), ('k2', 22), ('k3', [('k1ink3', 'value13'), ('k2ink3', 'value23')])]

unwind_key(test)会把字典test转换为list,而repr(unwind_key(test))则将unwind_key(test)转换成了str

tfutil.absolute_name_scope

1
2
3
4
# ?
def absolute_name_scope(scope: str) -> tf.name_scope:
"""强制进入特定scope."""
return tf.name_scope(scope + "/")

tfutil.absolute_name_scope(self.scope + “/_Run”)创建了一个参数命名空间,在这个语境下,所有的参数名将变为“%s/_Run/variable_name”%(self.scope)的形式,即变为了Gs/_Run/xxx/xxx的形式。

注:在tfutil.py_sanitize_tf_config函数中添加cfg["log_device_placement"] = True即可输出所有变量的位置(CPU/GPU)。

tf.control_dependencies(control_inputs)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
>>> tf.control_dependencies(control_inputs)
"""
Signature: tf.control_dependencies(control_inputs)
Docstring:
Wrapper for `Graph.control_dependencies()` using the default graph.

See `tf.Graph.control_dependencies`
for more details.

当动态图机制启用时,会自动调用list`control_inputs`中的可执行对象。

Args:
control_inputs: op或者Tensor的列表,context中的内容会在control_inputs执行完成后才执行。也可以为‘None’来清空控制依赖

Returns:
返回一个上下文管理器。
"""

in_expr 和 in_split

1
2
3
with tf.device("/gpu:0"):
in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))

这里我做了修改,把所有变量操作都放到GPU上运行。
in_expr是一个长度为len(input_names) 的列表,其中的每一个变量都是类型为tf.float32,名称为self.input_names[i]tf.Tensor
in_split是将in_expr沿minibatch方向划分为num_gpustf.Tensor

多GPU环境的graph构建

for gpu in range(num_gpus): with tf.device("/gpu:%d" % gpu):即在每个GPU上执行该上下文,若只用了单卡,for循环不起作用且在"/gpu:0"上执行。net_gpu = self.clone() if assume_frozen else self,当assume_frozenFalse时将返回实例本身(默认),当其为True时会调用self.clone(),我们做一个简单的测试:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# 这里我们以源代码为模板,构造以下函数
import multiprocessing
class User(object):
def __init__(self,name):
self.name = name
self.cpu_count = multiprocessing.cpu_count()
print("本机共有%d个CPU。"%(self.cpu_count))

def clone(self, name: str = None):
new = object.__new__(User)
return new

def run(self,):
for cpu in range(self.cpu_count):
with tf.device("/cpu:%d"%cpu):
net_gpu = self.clone() if True else self
print('/cpu:%d/%s'%(cpu,str(net_gpu)))

>>> test = User("Test")
>>> test.run()
本机共有12个CPU。
/cpu:0/<__main__.User object at 0x7f4c3ceaf630>
/cpu:1/<__main__.User object at 0x7f4c3ceaf550>
/cpu:2/<__main__.User object at 0x7f4c3ceaf710>
/cpu:3/<__main__.User object at 0x7f4c3ceaf588>
/cpu:4/<__main__.User object at 0x7f4c3ceaf748>
/cpu:5/<__main__.User object at 0x7f4c3ceaf7f0>
/cpu:6/<__main__.User object at 0x7f4c3ceafa90>
/cpu:7/<__main__.User object at 0x7f4c3ceaf828>
/cpu:8/<__main__.User object at 0x7f4c3ceaf7b8>
/cpu:9/<__main__.User object at 0x7f4c3ceaf780>
/cpu:10/<__main__.User object at 0x7f4c3ce82278>
/cpu:11/<__main__.User object at 0x7f4c3ce825f8>

in_gpu = in_split[gpu]是从分割后的元组in_split中按照设备号来访问其中的tf.Tensor,也就是将整个miniBatch均匀分配到各个GPU上。if input_transform is not None:判断是否需要对输入做变换,in_kwargs = dict(input_transform)input_transform存放到字典in_kwargs中,实质上input_transform也为字典,这里可能是为了避免后续的pop操作将input_transform中的值删除。对输入做完预处理之后,若in_gputf.Tensortf.Variable或者tf.Operation中的某一个,则返回[in_gpu],否则返回list(in_gpu)

net_gpu.get_output_for的返回值out_gpu<tf.Tensor 'Gs/_Run/Gs/images_out:0' shape=(?, 3, 1024, 1024) dtype=float32>,由于篇幅原因我们将这一函数单独放在另一篇文章中讲解。

接下来的output_transform板块与上文中提到的input_transform所执行的操作比较相似,其是对输出out_gpu做了后续的处理,在运行pretrained.py时,run函数并没有执行if input_transform后面的代码,因为输入中并没有给出任何关于对输入信息进行转换的关键词或列表,但是需要将输出转换为unit8,即output_transform = {'func': <function convert_images_to_uint8 at 0x7f9f38697378>,'nchw_to_nhwc':True}代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
"""根据输入的动态范围drange,将一个minibatch的图像从float32类型转换为unit8。可以作为Network.run()中的输出转换使用
"""
images = tf.cast(images, tf.float32)
if shrink > 1:
ksize = [1, 1, shrink, shrink]
images = tf.nn.avg_pool(images,ksize=ksize,strides=ksize,padding="VALID",data_format="NWHC")
if nchw_to_nhwc:
images = tf.transpose(images, [0, 2, 3, 1])
scale = 255 / (drange[1] - drange[0])
images = images * scale + (0.5 - drange[0] * scale)
return tf.saturate_cast(images, tf.uint8)

我们来看一下上面代码块中出现的几个常用函数,首先看一下tf.cast以及tf.saturate_cast的区别:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Signature: tf.saturate_cast(value, dtype, name=None)
Docstring:
将value安全地转换为dtype类型,不进行任何缩放。当转换可能会出现泄露(overflow或 underflow)问题时,会首先将数据clamp到合适范围之内。
"""
#例如,我们要把一个float32转换为uint8类型:
>>> x = tf.constant([-1, 256], dtype=tf.float32)
>>> x1 = tf.saturate_cast(x, tf.uint8)
>>> x2 = tf.cast(x, tf.uint8)
>>> with tf.Session() as sess:
dx1, dx2 = sess.run([x1, x2])
print('tf.saturate_cast:',dx1)
print('tf.cast:',dx2)
tf.saturate_cast: [0 255]
tf.cast: [0 0]

然后,out_split.append(out_gpu)将每个GPU得到的结果out_gpu添加到表out_split中,由tf.concat(outputs, axis=0)将所有结果沿minibatch方向连接,最后转化为list赋值给out_expr