143 lines
5.0 KiB
Python
143 lines
5.0 KiB
Python
import logging, os
|
|
# Mute the Python logging module for every record at WARNING severity or lower;
# only ERROR and CRITICAL records stay visible, process-wide.
logging.disable(level=logging.WARNING)
# TensorFlow's native runtime reads TF_CPP_MIN_LOG_LEVEL when the library is
# loaded, presumably why this is assigned before `import tensorflow` runs.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="3")
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
"""This is the configuration file.
|
|
"""
|
|
|
|
|
|
################################################################################
|
|
# Settings for Basic Operators
|
|
################################################################################
|
|
|
|
conf_basic_ops = dict()

# kernel_initializer for convolutions and transposed convolutions.
# NOTE(review): if this is set to None, the choice is left to the layer code that
# consumes this config; in TF 1.x, `tf.get_variable` then falls back to a
# Glorot (Xavier) *uniform* initializer (not Glorot normal) — confirm against
# the consuming layers.
conf_basic_ops['kernel_initializer'] = tf.glorot_uniform_initializer()

# momentum for batch normalization
conf_basic_ops['momentum'] = 0.997

# epsilon for batch normalization
conf_basic_ops['epsilon'] = 1e-5

# Activation type. String options: 'relu', 'relu6'
conf_basic_ops['relu_type'] = 'relu'

# Reject unsupported activation names here, at config load time, rather than
# deep inside model construction. Explicit raise (not assert) so the check
# survives `python -O`.
if conf_basic_ops['relu_type'] not in ('relu', 'relu6'):
    raise ValueError(
        "conf_basic_ops['relu_type'] must be 'relu' or 'relu6', got {!r}".format(
            conf_basic_ops['relu_type']))
|
|
|
|
################################################################################
|
|
# Settings for Attention Modules
|
|
################################################################################
|
|
|
|
# Set the attention in same_gto
conf_attn_same = dict()

# Define the relationship between total_key_filters and output_filters.
# total_key_filters = output_filters // key_ratio
conf_attn_same['key_ratio'] = 1

# Define the relationship between total_value_filters and output_filters.
# total_value_filters = output_filters // value_ratio
conf_attn_same['value_ratio'] = 1

# number of heads
conf_attn_same['num_heads'] = 2

# dropout rate, 0.0 means no dropout
conf_attn_same['dropout_rate'] = 0.0

# whether to use softmax on attention_weights
conf_attn_same['use_softmax'] = False

# whether to use bias terms in input/output transformations
conf_attn_same['use_bias'] = True
|
|
|
|
# Set the attention in up_gto
conf_attn_up = dict()

# total_key_filters = output_filters // key_ratio
conf_attn_up['key_ratio'] = 1
# total_value_filters = output_filters // value_ratio
conf_attn_up['value_ratio'] = 1
# number of heads
conf_attn_up['num_heads'] = 2
# dropout rate, 0.0 means no dropout (float, matching the other attention configs)
conf_attn_up['dropout_rate'] = 0.0
# whether to use softmax on attention_weights
conf_attn_up['use_softmax'] = False
# whether to use bias terms in input/output transformations
conf_attn_up['use_bias'] = True
|
|
|
|
# Attention hyper-parameters used inside down_gto. Key order is preserved:
# key_ratio, value_ratio, num_heads, dropout_rate, use_softmax, use_bias.
conf_attn_down = dict(
    key_ratio=1,        # total_key_filters = output_filters // key_ratio
    value_ratio=1,      # total_value_filters = output_filters // value_ratio
    num_heads=2,        # number of attention heads
    dropout_rate=0.0,   # 0.0 disables dropout
    use_softmax=False,  # apply softmax to attention_weights?
    use_bias=True,      # bias terms in input/output transformations?
)
|
|
|
|
################################################################################
|
|
# Describing the U-net
|
|
################################################################################
|
|
|
|
conf_unet = dict()

"""
Describe your U-Net under the following framework:

********************************************************************************************
layers                                                  | output_filters

first_convolution + encoding_block_1 (same)             | first_output_filters
+ encoding_block_i, i = 2, 3, ..., depth (down)         | first_output_filters*(2**(i-1))
+ bottom_block                                          | first_output_filters*(2**(depth-1))
+ decoding_block_j, j = depth-1, depth-2, ..., 1 (up)   | first_output_filters*(2**(j-1))
+ output_layer
********************************************************************************************

Specifically,
encoding_block_1 (same) = one or more res_block
encoding_block_i (down) = downsampling + zero or more res_block, i = 2, 3, ..., depth-1
encoding_block_depth (down) = downsampling
bottom_block = a combination of same_gto and res_block
decoding_block_j (up) = upsampling + zero or more res_block, j = depth-1, depth-2, ..., 1

Identity skip connections are between the output of encoding_block_i and
the output of upsampling in decoding_block_i, i = 1, 2, ..., depth-1.
The combination method is either 'add' or 'concat' (see conf_unet['skip_method']).
"""

# Set U-Net depth.
conf_unet['depth'] = 3

# Set the output_filters for first_convolution and encoding_block_1 (same).
conf_unet['first_output_filters'] = 96

# Set the encoding block sizes, i.e., number of res_block in encoding_block_i, i = 1, 2, ..., depth.
# It is an integer list whose length equals depth.
# The first entry should be positive since encoding_block_1 = one or more res_block.
# The last entry should be zero since encoding_block_depth (down) = downsampling.
conf_unet['encoding_block_sizes'] = [1, 1, 0]

# Set the decoding block sizes, i.e., number of res_block in decoding_block_j, j = depth-1, depth-2, ..., 1.
# It is an integer list whose length equals depth-1.
conf_unet['decoding_block_sizes'] = [1, 1]

# Set the downsampling methods for each encoding_block_i, i = 2, 3, ..., depth.
# It is a string list whose length equals depth-1.
# String options: 'down_gto_v1', 'down_gto_v2', 'down_res_block', 'convolution'
conf_unet['downsampling'] = ['convolution', 'convolution']

# Set the combination method for identity skip connections.
# String options: 'add', 'concat'
conf_unet['skip_method'] = 'concat'

# Set the output layer
# NOTE(review): no output-layer settings are defined in this section — TODO
# confirm whether they are configured elsewhere.

# Check
# Explicit ValueErrors instead of `assert`: asserts are stripped under
# `python -O`, which would let an inconsistent configuration through silently.
# depth >= 2 is required because encoding_block_sizes[0] must be positive while
# encoding_block_sizes[-1] must be zero, which is impossible for a single entry.
if conf_unet['depth'] < 2:
    raise ValueError(
        "conf_unet['depth'] must be at least 2, got {!r}".format(conf_unet['depth']))
if conf_unet['depth'] != len(conf_unet['encoding_block_sizes']):
    raise ValueError(
        "len(conf_unet['encoding_block_sizes']) must equal depth ({}), got {}".format(
            conf_unet['depth'], len(conf_unet['encoding_block_sizes'])))
if conf_unet['encoding_block_sizes'][0] <= 0:
    raise ValueError(
        "conf_unet['encoding_block_sizes'][0] must be positive "
        "(encoding_block_1 = one or more res_block)")
if conf_unet['encoding_block_sizes'][-1] != 0:
    raise ValueError(
        "conf_unet['encoding_block_sizes'][-1] must be zero "
        "(encoding_block_depth = downsampling only)")
if conf_unet['depth'] != len(conf_unet['decoding_block_sizes']) + 1:
    raise ValueError(
        "len(conf_unet['decoding_block_sizes']) must equal depth-1 ({}), got {}".format(
            conf_unet['depth'] - 1, len(conf_unet['decoding_block_sizes'])))
if any(size < 0 for size in conf_unet['encoding_block_sizes'] + conf_unet['decoding_block_sizes']):
    raise ValueError("conf_unet block sizes must be non-negative")
if conf_unet['depth'] != len(conf_unet['downsampling']) + 1:
    raise ValueError(
        "len(conf_unet['downsampling']) must equal depth-1 ({}), got {}".format(
            conf_unet['depth'] - 1, len(conf_unet['downsampling'])))
if not set(conf_unet['downsampling']) <= {'down_gto_v1', 'down_gto_v2', 'down_res_block', 'convolution'}:
    raise ValueError(
        "conf_unet['downsampling'] entries must be one of 'down_gto_v1', "
        "'down_gto_v2', 'down_res_block', 'convolution'; got {!r}".format(
            conf_unet['downsampling']))
if conf_unet['skip_method'] not in ['add', 'concat']:
    raise ValueError(
        "conf_unet['skip_method'] must be 'add' or 'concat', got {!r}".format(
            conf_unet['skip_method']))