diff --git a/deepctr/models/fgcnn.py b/deepctr/models/fgcnn.py index 3ee1eaa4..80dc07bf 100644 --- a/deepctr/models/fgcnn.py +++ b/deepctr/models/fgcnn.py @@ -70,8 +70,12 @@ def FGCNN(linear_feature_columns, dnn_feature_columns, conv_kernel_width=(7, 7, combined_input = concat_func([origin_input, new_features], axis=1) else: combined_input = origin_input - inner_product = tf.keras.layers.Flatten()(InnerProductLayer()( - tf.keras.layers.Lambda(unstack, mask=[None] * int(combined_input.shape[1]))(combined_input))) + + #inner_product = tf.keras.layers.Flatten()(InnerProductLayer()( + # tf.keras.layers.Lambda(unstack, mask=[None] * int(combined_input.shape[1]))(combined_input))) + + inner_product = tf.keras.layers.Flatten()(InnerProductLayer()(tf.split(combined_input,combined_input.shape[1],1))) + linear_signal = tf.keras.layers.Flatten()(combined_input) dnn_input = tf.keras.layers.Concatenate()([linear_signal, inner_product]) dnn_input = tf.keras.layers.Flatten()(dnn_input)