diff --git a/deepctr/feature_column.py b/deepctr/feature_column.py
index 5cc1930e..b79eda12 100644
--- a/deepctr/feature_column.py
+++ b/deepctr/feature_column.py
@@ -212,3 +212,20 @@ def input_from_feature_columns(features, feature_columns, l2_reg, seed, prefix='
     if not support_group:
         group_embedding_dict = list(chain.from_iterable(group_embedding_dict.values()))
     return group_embedding_dict, dense_value_list
+
+
+def input_from_seq_feature_columns(features, feature_columns, l2_reg, seed, prefix='', seq_mask_zero=True,
+                                   support_dense=True):
+    """Like input_from_feature_columns, but returns the raw (un-pooled) sequence embeddings keyed by feature name."""
+    varlen_sparse_feature_columns = list(
+        filter(lambda x: isinstance(x, VarLenSparseFeat), feature_columns)) if feature_columns else []
+
+    embedding_matrix_dict = create_embedding_matrix(feature_columns, l2_reg, seed, prefix=prefix,
+                                                    seq_mask_zero=seq_mask_zero)
+    dense_value_list = get_dense_input(features, feature_columns)
+    if not support_dense and len(dense_value_list) > 0:
+        raise ValueError("DenseFeat is not supported in dnn_feature_columns")
+
+    sequence_embed_dict = varlen_embedding_lookup(embedding_matrix_dict, features, varlen_sparse_feature_columns)
+    return sequence_embed_dict, dense_value_list
+
diff --git a/deepctr/layers/__init__.py b/deepctr/layers/__init__.py
index 1bfd40ef..2e32e551 100644
--- a/deepctr/layers/__init__.py
+++ b/deepctr/layers/__init__.py
@@ -5,7 +5,7 @@
 from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet,
                           CrossNetMix, InnerProductLayer, InteractingLayer,
                           OutterProductLayer, FGCNNLayer, SENETLayer, BilinearInteraction,
-                          FieldWiseBiInteraction, FwFMLayer, FEFMLayer)
+                          FieldWiseBiInteraction, FwFMLayer, FEFMLayer, CoActionLayer)
 from .normalization import LayerNormalization
 from .sequence import (AttentionSequencePoolingLayer, BiasEncoding, BiLSTM,
                        KMaxPooling, SequencePoolingLayer, WeightedSequenceLayer,
diff --git a/deepctr/layers/interaction.py b/deepctr/layers/interaction.py
index d26eb2c1..a0c6a76e 100644
--- a/deepctr/layers/interaction.py
+++ b/deepctr/layers/interaction.py
@@ -1489,3 +1489,83 @@ def get_config(self):
             'regularizer': self.regularizer,
         })
         return config
+
+
+class CoActionLayer(Layer):
+    """Co-action unit: the target item embedding is reshaped into the weights and biases
+    of a small per-sample MLP, which is then applied to every step of the user's
+    behavior (preference) sequence.
+
+      Input shape
+        - 3D tensor with shape: ``(batch_size, field_size, embedding_size)``.
+
+      Output shape
+        - 2D tensor with shape: ``(batch_size, orders * last_mlp_dim)``.
+
+      References
+        - [CAN: Feature Co-Action for Click-Through Rate Prediction](https://arxiv.org/abs/2011.05625)
+    """
+
+    def __init__(self, target_input, co_action_config, name, **kwargs):
+
+        self._target_input = target_input
+        self._co_action_config = co_action_config
+        self._weight_orders, self._bias_orders = self._build_mlp()
+
+        super(CoActionLayer, self).__init__(name=name, **kwargs)
+
+    def build(self, input_shape):
+        if len(input_shape) != 3:
+            raise ValueError("Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(input_shape),))
+
+        super(CoActionLayer, self).build(input_shape)  # Be sure to call this somewhere!
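+
+    # NOTE: _build_mlp below slices the flattened target embedding into per-sample MLP
+    # parameters. With target_emb_w=[[16, 8], [8, 4]] and target_emb_b=[0, 0], one
+    # order consumes 16*8 + 8*4 = 160 values, so the target feature's embedding_dim
+    # must be at least 160 (or orders * 160 when indep_action=True, since each order
+    # then gets its own MLP parameters).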
+
+    def _build_mlp(self):
+        target_emb = tf.reduce_sum(self._target_input, axis=1)  # sum over the target field so a var-length target yields one flattened embedding
+        weight_orders, bias_orders = [], []
+        idx = 0
+        for i in range(self._co_action_config['orders']):
+            weight, bias = [], []
+            for w, b in zip(self._co_action_config['target_emb_w'], self._co_action_config['target_emb_b']):
+                weight.append(tf.reshape(target_emb[:, idx:idx + w[0] * w[1]], [-1, w[0], w[1]]))
+                idx += w[0] * w[1]
+                if b == 0:
+                    bias.append(None)
+                else:
+                    bias.append(tf.reshape(target_emb[:, idx:idx + b], [-1, 1, b]))
+                    idx += b
+            weight_orders.append(weight)
+            bias_orders.append(bias)
+            if not self._co_action_config['indep_action']:
+                break
+        return weight_orders, bias_orders
+
+    def co_action_op(self, hist_pref_seq, mask=None):
+        # order-i input is the element-wise (i+1)-th power of the sequence embedding
+        inputs = []
+        for i in range(self._co_action_config['orders']):
+            inputs.append(tf.math.pow(hist_pref_seq, i + 1.0))
+        out_seq = []
+        for i, h in enumerate(inputs):
+            if self._co_action_config['indep_action']:
+                weight, bias = self._weight_orders[i], self._bias_orders[i]
+            else:
+                weight, bias = self._weight_orders[0], self._bias_orders[0]
+            for j, (w, b) in enumerate(zip(weight, bias)):
+                h = tf.matmul(h, w)
+                if b is not None:
+                    h = h + b
+                if j != len(weight) - 1:
+                    h = tf.nn.tanh(h)
+            out_seq.append(h)
+        out_seq = tf.concat(out_seq, 2)
+        if mask is not None:
+            mask = tf.expand_dims(mask, axis=-1)
+            out_seq = out_seq * mask
+        out = tf.reduce_sum(out_seq, 1)  # sum-pool over the sequence dimension
+        return out
+
+    def call(self, inputs, **kwargs):
+        result = self.co_action_op(inputs, mask=None)
+        return result
+
+    def compute_output_shape(self, input_shape):
+        # one block of size last_mlp_dim per order is concatenated before sum-pooling
+        output_dim = self._co_action_config['orders'] * self._co_action_config['target_emb_w'][-1][1]
+        return (input_shape[0], output_dim)
diff --git a/deepctr/models/__init__.py b/deepctr/models/__init__.py
index 2d19714b..d3c35ee3 100644
--- a/deepctr/models/__init__.py
+++ b/deepctr/models/__init__.py
@@ -16,11 +16,12 @@
 from .multitask import SharedBottom, ESMM, MMOE, PLE
 from .nfm import NFM
 from .onn import ONN
+from .can import CAN
 from .pnn import PNN
 from .sequence import DIN, DIEN, DSIN, BST
 from .wdl import WDL
 from .xdeepfm import xDeepFM
 
-__all__ = ["AFM", "CCPM", "DCN", "IFM", "DIFM", "DCNMix", "MLR", "DeepFM", "MLR", "NFM", "DIN", "DIEN", "FNN", "PNN",
-           "WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", 'FLEN', "FwFM", "BST", "DeepFEFM",
+__all__ = ["AFM", "CCPM", "DCN", "IFM", "DIFM", "DCNMix", "MLR", "DeepFM", "NFM", "DIN", "DIEN", "FNN", "CAN",
+           "PNN", "WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", "FLEN", "FwFM", "BST", "DeepFEFM",
            "SharedBottom", "ESMM", "MMOE", "PLE"]
diff --git a/deepctr/models/can.py b/deepctr/models/can.py
new file mode 100644
index 00000000..a1a3b6e3
--- /dev/null
+++ b/deepctr/models/can.py
@@ -0,0 +1,74 @@
+# -*- coding:utf-8 -*-
+"""
+
+Author:
+    Weichen Shen, weichenswc@163.com
+
+Reference:
+    [1] CAN: Feature Co-Action for Click-Through Rate Prediction
+    (https://arxiv.org/abs/2011.05625)
+
+"""
+from tensorflow.python.keras.models import Model
+from tensorflow.python.keras.layers import Dense
+from ..feature_column import build_input_features, input_from_seq_feature_columns
+from ..layers.core import PredictionLayer
+from ..layers.interaction import CoActionLayer
+from ..layers.utils import concat_func
+
+
+# template_can_config = [
+#     {
+#         'name': 'co_action_for_item',
+#         'target': 'item_id',  # target whose embedding is reshaped into MLP parameters
+#         'pref_seq': ['hist_item_id', ],  # behavior sequence(s) the MLP is applied to
+#         'co_action_conf': {
+#             'target_emb_w': [[16, 8], [8, 4]],
+#             'target_emb_b': [0, 0],
+#             'indep_action': False,
+#             'orders': 3,  # element-wise powers for the multi-order enhancement
+#         }
+#     },
+#     {
+#         'name': 'co_action_for_cate',
+#         'target': 'cate_id',
+#         'pref_seq': ['hist_cate_id', ],
+#         'co_action_conf': {
+#             'target_emb_w': [[16, 8], [8, 4]],
+#             'target_emb_b': [0, 0],
+#             'indep_action': False,
+#             'orders': 3,
+#         }
+#     }
+# ]
+
+
+def CAN(dnn_feature_columns, co_action_config, l2_reg_embedding=1e-5, seed=1024, task='binary'):
+    """Instantiates the CAN architecture.
+
+    :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
+    :param co_action_config: A list of dicts (see the template above), each binding one target feature to the
+        history preference sequence feature(s) it co-acts with, plus the co-action MLP configuration.
+    :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
+    :param seed: integer, to use as random seed.
+    :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
+    :return: A Keras model instance.
+    """
+
+    features = build_input_features(dnn_feature_columns)
+    inputs_list = list(features.values())
+    sequence_embed_dict, _ = input_from_seq_feature_columns(features, dnn_feature_columns, l2_reg_embedding,
+                                                            seed, prefix='', support_dense=False)
+
+    # co-action of each target feature with its history preference sequence(s)
+    can_output_list = []
+    for conf in co_action_config:
+        cur_can_layer = CoActionLayer(sequence_embed_dict[conf['target']], conf['co_action_conf'], name=conf['name'])
+        for hist_pref_seq in conf['pref_seq']:
+            can_output_list.append(cur_can_layer(sequence_embed_dict[hist_pref_seq]))
+
+    can_output = concat_func(can_output_list)
+    final_logit = Dense(1, use_bias=False)(can_output)
+    output = PredictionLayer(task)(final_logit)
+    model = Model(inputs=inputs_list, outputs=output)
+    return model
diff --git a/examples/run_can.py b/examples/run_can.py
new file mode 100644
index 00000000..4d494c64
--- /dev/null
+++ b/examples/run_can.py
@@ -0,0 +1,70 @@
+import numpy as np
+
+from deepctr.models import CAN
+from deepctr.feature_column import SparseFeat, VarLenSparseFeat, get_feature_names
+
+# CoActionLayer builds tensors from the target feature's embedding at construction
+# time, so under TF 2.x the model must be built in graph mode.
+import tensorflow as tf
+tf.compat.v1.disable_eager_execution()
+
+
+def get_xy_fd():
+
+    feature_columns = [
+        # target embedding size: sum(w[0] * w[1] for w in co_action_conf['target_emb_w'])
+        #                        + sum(co_action_conf['target_emb_b']) = 16*8 + 8*4 + 0 + 0 = 160
+        VarLenSparseFeat(SparseFeat('item_id', 3 + 1, embedding_dim=160), maxlen=1, length_name="target_length"),
+        VarLenSparseFeat(SparseFeat('cate_id', 2 + 1, embedding_dim=160), maxlen=2, length_name="cate_target_length"),
+
+        # history preference sequence embedding size: co_action_conf['target_emb_w'][0][0] = 16
+        VarLenSparseFeat(SparseFeat('hist_item_id', 3 + 1, embedding_dim=16), maxlen=4, length_name="seq_length"),
+        VarLenSparseFeat(SparseFeat('hist_cate_id', 2 + 1, embedding_dim=16), maxlen=4, length_name="seq_length")]
+    # Notice: History behavior sequence feature name must start with "hist_".
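+
+    # Toy batch of 3 samples: sequences are zero-padded to maxlen, and the *_length
+    # arrays below hold the true lengths referenced by length_name above.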
+    iid = np.array([[1], [2], [3]])  # target item id, maxlen=1
+    cate_id = np.array([[1, 2], [2, 0], [2, 0]])  # an item may belong to multiple categories
+    hist_iid = np.array([[1, 2, 3, 0], [3, 2, 1, 0], [1, 2, 0, 0]])
+    hist_cate_id = np.array([[1, 2, 2, 0], [2, 2, 1, 0], [1, 2, 0, 0]])
+    seq_length = np.array([3, 3, 2])  # the actual length of the behavior sequence
+    target_length = np.array([1, 1, 1])
+    cate_target_length = np.array([2, 1, 1])
+
+    feature_dict = {'item_id': iid, 'cate_id': cate_id,
+                    'hist_item_id': hist_iid, 'hist_cate_id': hist_cate_id,
+                    'seq_length': seq_length, 'target_length': target_length,
+                    'cate_target_length': cate_target_length}
+    x = {name: feature_dict[name] for name in get_feature_names(feature_columns)}
+    y = np.array([1, 0, 1])
+    return x, y, feature_columns
+
+
+if __name__ == "__main__":
+    x, y, feature_columns = get_xy_fd()
+
+    co_action_config = [
+        {
+            'name': 'co_action_for_item',
+            'target': 'item_id',  # target whose embedding is reshaped into MLP parameters
+            'pref_seq': ['hist_item_id', ],  # behavior sequence(s) the MLP is applied to
+            'co_action_conf': {
+                'target_emb_w': [[16, 8], [8, 4]],
+                'target_emb_b': [0, 0],
+                'indep_action': False,
+                'orders': 3,  # element-wise powers for the multi-order enhancement
+            }
+        },
+        {
+            'name': 'co_action_for_cate',
+            'target': 'cate_id',
+            'pref_seq': ['hist_cate_id', ],
+            'co_action_conf': {
+                'target_emb_w': [[16, 8], [8, 4]],
+                'target_emb_b': [0, 0],
+                'indep_action': False,
+                'orders': 3,
+            }
+        }
+    ]
+
+    model = CAN(feature_columns, co_action_config=co_action_config)
+    model.compile('adam', 'binary_crossentropy',
+                  metrics=['binary_crossentropy'])
+    history = model.fit(x, y, verbose=1, epochs=10, validation_split=0.5)
diff --git a/tests/layers/interaction_test.py b/tests/layers/interaction_test.py
index 5f162f42..7910975a 100644
--- a/tests/layers/interaction_test.py
+++ b/tests/layers/interaction_test.py
@@ -147,3 +147,22 @@ def test_BilinearInteraction(bilinear_type):
     with CustomObjectScope({'BilinearInteraction': layers.BilinearInteraction}):
         layer_test(layers.BilinearInteraction, kwargs={'bilinear_type': bilinear_type}, input_shape=[(
             BATCH_SIZE, 1, EMBEDDING_SIZE)] * FIELD_SIZE)
+
+
+def test_CoActionLayer():
+    # CoActionLayer takes another feature's embedding tensor as a constructor argument,
+    # so it cannot go through the generic, config-round-trip-based layer_test helper.
+    import numpy as np
+    import tensorflow as tf
+    conf = {'target_emb_w': [[EMBEDDING_SIZE, 8], [8, 4]], 'target_emb_b': [0, 0],
+            'indep_action': False, 'orders': 2}
+    target_dim = sum(w[0] * w[1] for w in conf['target_emb_w']) + sum(conf['target_emb_b'])
+    target_input = tf.constant(np.random.random((BATCH_SIZE, 1, target_dim)), dtype=tf.float32)
+    seq_input = tf.constant(np.random.random((BATCH_SIZE, FIELD_SIZE, EMBEDDING_SIZE)), dtype=tf.float32)
+    output = layers.CoActionLayer(target_input, conf, name='co_action_test')(seq_input)
+    assert output.get_shape().as_list() == [BATCH_SIZE, conf['orders'] * conf['target_emb_w'][-1][1]]
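+
+# Worked shape example for the run_can.py config: with target_emb_w=[[16, 8], [8, 4]],
+# target_emb_b=[0, 0] and orders=3, each order's MLP maps (batch, seq_len, 16) to
+# (batch, seq_len, 4); the three order outputs are concatenated to (batch, seq_len, 12)
+# and sum-pooled over seq_len to give a (batch, 12) co-action vector per target/sequence pair.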