Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions deepctr/feature_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,20 @@ def input_from_feature_columns(features, feature_columns, l2_reg, seed, prefix='
if not support_group:
group_embedding_dict = list(chain.from_iterable(group_embedding_dict.values()))
return group_embedding_dict, dense_value_list


def input_from_seq_feature_columns(features, feature_columns, l2_reg, seed, prefix='', seq_mask_zero=True,
support_dense=True, support_group=True):

varlen_sparse_feature_columns = list(
filter(lambda x: isinstance(x, VarLenSparseFeat), feature_columns)) if feature_columns else []

embedding_matrix_dict = create_embedding_matrix(feature_columns, l2_reg, seed, prefix=prefix,
seq_mask_zero=seq_mask_zero)
dense_value_list = get_dense_input(features, feature_columns)
if not support_dense and len(dense_value_list) > 0:
raise ValueError("DenseFeat is not supported in dnn_feature_columns")

sequence_embed_dict = varlen_embedding_lookup(embedding_matrix_dict, features, varlen_sparse_feature_columns)
return sequence_embed_dict, dense_value_list

2 changes: 1 addition & 1 deletion deepctr/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet, CrossNetMix,
InnerProductLayer, InteractingLayer,
OutterProductLayer, FGCNNLayer, SENETLayer, BilinearInteraction,
FieldWiseBiInteraction, FwFMLayer, FEFMLayer)
FieldWiseBiInteraction, FwFMLayer, FEFMLayer, CoActionLayer)
from .normalization import LayerNormalization
from .sequence import (AttentionSequencePoolingLayer, BiasEncoding, BiLSTM,
KMaxPooling, SequencePoolingLayer, WeightedSequenceLayer,
Expand Down
80 changes: 80 additions & 0 deletions deepctr/layers/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,3 +1489,83 @@ def get_config(self):
'regularizer': self.regularizer,
})
return config

class CoActionLayer(Layer):
"""
build co-action output for target item(type) and user pref seq item(type)
Input shape
- 3D tensor with shape: ``(batch_size,field_size,embedding_size)``.

Output shape
- 2D tensor with shape: ``(batch_size, 2nd_mlp_dims)``.

References
- [CAN: Feature Co-Action for Click-Through Rate Prediction](https://arxiv.org/abs/2011.05625)
"""

def __init__(self, target_input, co_action_config, name, **kwargs):

self._target_input = target_input
self._co_action_config = co_action_config
self._layer_name = name
self._weight_orders, self._bias_orders = self._build_mlp()

super(CoActionLayer, self).__init__(**kwargs)

def build(self, input_shape):
if len(input_shape) != 3:
raise ValueError("Unexpected inputs dimensions % d,\
expect to be 3 dimensions" % (len(input_shape)))

super(CoActionLayer, self).build(input_shape) # Be sure to call this somewhere!

def _build_mlp(self):
target_emb = tf.reduce_sum(self._target_input, axis=1) # avoid target varlength
weight_orders, bias_orders = [], []
idx = 0
for i in range(self._co_action_config['orders']):
weight, bias = [], []
for w, b in zip(self._co_action_config['target_emb_w'], self._co_action_config['target_emb_b']):
weight.append(tf.reshape(target_emb[:, idx:idx + w[0] * w[1]], [-1, w[0], w[1]]))
idx += w[0] * w[1]
if b == 0:
bias.append(None)
else:
bias.append(tf.reshape(target_emb[:, idx:idx + b], [-1, 1, b]))
idx += b
weight_orders.append(weight)
bias_orders.append(bias)
if not self._co_action_config['indep_action']:
break
return weight_orders, bias_orders

def co_action_op(self, hist_pref_seq, mask=None):
inputs = []
for i in range(self._co_action_config['orders']):
inputs.append(tf.math.pow(hist_pref_seq, i + 1.0))
out_seq = []
for i, h in enumerate(inputs):
if self._co_action_config['indep_action']:
weight, bias = self._weight_orders[i], self._bias_orders[i]
else:
weight, bias = self._weight_orders[0], self._bias_orders[0]
for j, (w, b) in enumerate(zip(weight, bias)):
h = tf.matmul(h, w)
if b is not None:
h = h + b
if j != len(weight) - 1:
h = tf.nn.tanh(h)
out_seq.append(h)
out_seq = tf.concat(out_seq, 2)
if mask is not None:
mask = tf.expand_dims(mask, axis=-1)
out_seq = out_seq * mask
out = tf.reduce_sum(out_seq, 1)
return out

def call(self, inputs, **kwargs):
result = self.co_action_op(inputs, mask=None)
return result

def compute_output_shape(self, input_shape):
return (None, 1)
5 changes: 3 additions & 2 deletions deepctr/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
from .multitask import SharedBottom, ESMM, MMOE, PLE
from .nfm import NFM
from .onn import ONN
from .can import CAN
from .pnn import PNN
from .sequence import DIN, DIEN, DSIN, BST
from .wdl import WDL
from .xdeepfm import xDeepFM

__all__ = ["AFM", "CCPM", "DCN", "IFM", "DIFM", "DCNMix", "MLR", "DeepFM", "MLR", "NFM", "DIN", "DIEN", "FNN", "PNN",
"WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", 'FLEN', "FwFM", "BST", "DeepFEFM",
__all__ = ["AFM", "CCPM", "DCN", "IFM", "DIFM", "DCNMix", "MLR", "DeepFM", "MLR", "NFM", "DIN", "DIEN", "FNN", 'CAN',
"PNN", "WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", 'FLEN', "FwFM", "BST", "DeepFEFM",
"SharedBottom", "ESMM", "MMOE", "PLE"]
74 changes: 74 additions & 0 deletions deepctr/models/can.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# -*- coding:utf-8 -*-
"""

Author:
Weichen Shen, [email protected]

Reference:
[1] Xiao J, Ye H, He X, et al. Attentional factorization machines: Learning the weight of feature interactions via attention networks[J]. arXiv preprint arXiv:1708.04617, 2017.
(https://arxiv.org/abs/1708.04617)

"""
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.layers import Dense
from ..feature_column import build_input_features, input_from_seq_feature_columns
from ..layers.core import PredictionLayer
from ..layers.interaction import CoActionLayer
from ..layers.utils import concat_func


# template_can_config = [
# {
# 'name': 'co_action_for_item',
# 'target': 'item_id', # target emb need to reshape
# 'pref_seq': ['hist_item_id', ], # seq emb need to co-action
# 'co_action_conf': {
# 'target_emb_w': [[16, 8], [8, 4]],
# 'target_emb_b': [0, 0],
# 'indep_action': False,
# 'orders': 3, # exp non_linear trans
# }
# },
# {
# 'name': 'co_action_for_cate',
# 'target': 'cate_id',
# 'pref_seq': ['hist_cate_id', ],
# 'co_action_conf': {
# 'target_emb_w': [[16, 8], [8, 4]],
# 'target_emb_b': [0, 0],
# 'indep_action': False,
# 'orders': 3, # exp non_linear trans
# }
# }
# ]


def CAN(dnn_feature_columns, co_action_config, l2_reg_embedding=1e-5, seed=1024, task='binary'):
"""Instantiates the CAN architecture.

:param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
:param co_action_config: A dict containing the bindings with all the features(target, hist_pref_seq) .
:param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector
:param seed: integer ,to use as random seed.
:param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
:return: A Keras model instance.
"""

features = build_input_features(dnn_feature_columns)
inputs_list = list(features.values())
sequence_embed_dict, _ = input_from_seq_feature_columns(features, dnn_feature_columns, l2_reg_embedding,
seed, prefix='', support_dense=False)

# co-action for target type with multi hist pref seq
can_output_list = []
for conf in co_action_config:
cur_can_layer = CoActionLayer(sequence_embed_dict[conf['target']], conf['co_action_conf'], name=conf['name'])
print(cur_can_layer._layer_name)
for his_pref_seq in conf['pref_seq']:
can_output_list.append(cur_can_layer(sequence_embed_dict[his_pref_seq]))

can_output = concat_func(can_output_list)
final_logit = Dense(1, use_bias=False)(can_output)
output = PredictionLayer(task)(final_logit)
model = Model(inputs=inputs_list, outputs=output)
return model
70 changes: 70 additions & 0 deletions examples/run_can.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import numpy as np

from deepctr.models import CAN
from deepctr.feature_column import SparseFeat, VarLenSparseFeat, DenseFeat, get_feature_names

# for tf2.x
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

def get_xy_fd():

feature_columns = [
# target emb size: sum([w[0] * w[1] for w in can_config['target_emb_w']]) + sum(can_config['target_emb_b'])
VarLenSparseFeat(SparseFeat('item_id', 3 + 1, embedding_dim=190), maxlen=1, length_name="target_length"),
VarLenSparseFeat(SparseFeat('cate_id', 2 + 1, embedding_dim=190), maxlen=2, length_name="cate_target_length"),

# hist pref seq emb size : can_config['target_emb_w'][0][0],
VarLenSparseFeat(SparseFeat('hist_item_id', 3 + 1, embedding_dim=16), maxlen=4, length_name="seq_length"),
VarLenSparseFeat(SparseFeat('hist_cate_id', 2 + 1, embedding_dim=16), maxlen=4, length_name="seq_length")]
# Notice: History behavior sequence feature name must start with "hist_".
behavior_feature_list = ["item_id", "cate_id"]
iid = np.array([1, 2, 3]) # target 1
cate_id = np.array([[1,2], [2,0], [2,0]]) # cate maybe multi
hist_iid = np.array([[1, 2, 3, 0], [3, 2, 1, 0], [1, 2, 0, 0]])
hist_cate_id = np.array([[1, 2, 2, 0], [2, 2, 1, 0], [1, 2, 0, 0]])
seq_length = np.array([3, 3, 2]) # the actual length of the behavior sequence
target_length = np.array([1, 1, 1])
cate_target_length = np.array([2, 1, 1])

feature_dict = {'item_id': iid, 'cate_id': cate_id,
'hist_item_id': hist_iid, 'hist_cate_id': hist_cate_id,
'seq_length': seq_length, 'target_length': target_length,
'cate_target_length': cate_target_length}
x = {name: feature_dict[name] for name in get_feature_names(feature_columns)}
y = np.array([1, 0, 1])
return x, y, feature_columns, behavior_feature_list


if __name__ == "__main__":
x, y, feature_columns, behavior_feature_list = get_xy_fd()

co_action_config = [
{
'name': 'co_action_for_item',
'target': 'item_id', # target emb need to reshape
'pref_seq': ['hist_item_id', ], # seq emb need to co-action
'co_action_conf': {
'target_emb_w': [[16, 8], [8, 4]],
'target_emb_b': [0, 0],
'indep_action': False,
'orders': 3, # exp non_linear trans
}
},
{
'name': 'co_action_for_cate',
'target': 'cate_id',
'pref_seq': ['hist_cate_id', ],
'co_action_conf': {
'target_emb_w': [[16, 8], [8, 4]],
'target_emb_b': [0, 0],
'indep_action': False,
'orders': 3, # exp non_linear trans
}
}
]

model = CAN(feature_columns, co_action_config=co_action_config)
model.compile('adam', 'binary_crossentropy',
metrics=['binary_crossentropy'])
history = model.fit(x, y, verbose=1, epochs=10, validation_split=0.5)
5 changes: 5 additions & 0 deletions tests/layers/interaction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,8 @@ def test_BilinearInteraction(bilinear_type):
with CustomObjectScope({'BilinearInteraction': layers.BilinearInteraction}):
layer_test(layers.BilinearInteraction, kwargs={'bilinear_type': bilinear_type}, input_shape=[(
BATCH_SIZE, 1, EMBEDDING_SIZE)] * FIELD_SIZE)

def test_CAN():
with CustomObjectScope({'CAN': layers.CoActionLayer}):
layer_test(layers.CoActionLayer, kwargs={}, input_shape=(
BATCH_SIZE, FIELD_SIZE, EMBEDDING_SIZE))