diff --git a/nets/deep_sort/residual_net.py b/nets/deep_sort/residual_net.py index 8afa1c67..750d3e47 100644 --- a/nets/deep_sort/residual_net.py +++ b/nets/deep_sort/residual_net.py @@ -10,9 +10,10 @@ def _batch_norm_fn(x, scope=None): def create_link( - incoming, network_builder, scope, nonlinearity=tf.nn.elu, + incoming, scope, nonlinearity=tf.nn.elu, weights_initializer=tf.truncated_normal_initializer(stddev=1e-3), - regularizer=None, is_first=False, summarize_activations=True): + i_block_bias_initializer, i_block_increase_dim, regularizer=None, + is_first=False, summarize_activations=True): if is_first: network = incoming else: @@ -22,8 +23,9 @@ def create_link( tf.summary.histogram(scope+"/activations", network) pre_block_network = incoming - post_block_network = network_builder(network, scope) - + post_block_network = create_inner_block( + network, scope, nonlinearity, weights_initializer, i_block_bias_initializer, + regularizer, i_block_increase_dim, summarize_activations) incoming_dim = pre_block_network.get_shape().as_list()[-1] outgoing_dim = post_block_network.get_shape().as_list()[-1] if incoming_dim != outgoing_dim: @@ -74,11 +76,7 @@ def residual_block(incoming, scope, nonlinearity=tf.nn.elu, increase_dim=False, is_first=False, summarize_activations=True): - def network_builder(x, s): - return create_inner_block( - x, s, nonlinearity, weights_initializer, bias_initializer, - regularizer, increase_dim, summarize_activations) - return create_link( - incoming, network_builder, scope, nonlinearity, weights_initializer, + incoming, scope, nonlinearity, weights_initializer, + bias_initializer, increase_dim, regularizer, is_first, summarize_activations)