
Commit b7a48ad

Merge pull request #121 from zerolovesea/autoint
feat(ranking model): Implementation of AutoInt model
2 parents 6d41ce6 + e15a44b commit b7a48ad

File tree

5 files changed: +184 -2 lines changed

examples/ranking/run_criteo.py

Lines changed: 4 additions & 1 deletion
@@ -7,7 +7,7 @@
 from tqdm import tqdm

 from torch_rechub.basic.features import DenseFeature, SparseFeature
-from torch_rechub.models.ranking import DCN, EDCN, DCNv2, DeepFFM, DeepFM, FatDeepFFM, FiBiNet, WideDeep
+from torch_rechub.models.ranking import DCN, EDCN, AutoInt, DCNv2, DeepFFM, DeepFM, FatDeepFFM, FiBiNet, WideDeep
 from torch_rechub.trainers import CTRTrainer
 from torch_rechub.utils.data import DataGenerator

@@ -72,6 +72,8 @@ def main(dataset_path, model_name, epoch, learning_rate, batch_size, weight_deca
         model = FiBiNet(features=dense_feas + sparse_feas, reduction_ratio=3, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})
     elif model_name == "edcn":
         model = EDCN(features=dense_feas + sparse_feas, n_cross_layers=3, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})
+    elif model_name == "autoint":
+        model = AutoInt(dense_features=dense_feas, sparse_features=sparse_feas, num_layers=3, num_heads=2, dropout=0.2, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})
     elif model_name == "deepffm":
         model = DeepFFM(linear_features=ffm_linear_feas, cross_features=ffm_cross_feas, embed_dim=10, mlp_params={"dims": [1600, 1600], "dropout": 0.5, "activation": "relu"})
     elif model_name == "fat_deepffm":

@@ -104,6 +106,7 @@ def main(dataset_path, model_name, epoch, learning_rate, batch_size, weight_deca
     python run_criteo.py --model_name dcn
     python run_criteo.py --model_name dcn_v2
     python run_criteo.py --model_name edcn
+    python run_criteo.py --model_name autoint
     python run_criteo.py --model_name deepffm
     python run_criteo.py --model_name fat_deepffm
     """

tests/test_e2e_ranking.py

Lines changed: 2 additions & 0 deletions
@@ -72,6 +72,8 @@ def test_ranking_e2e(model_class, ranking_data):
         params = {"features": features, "attention_dim": 16, "mlp_params": {"dims": [32]}}
     elif model_name == 'FiBiNet':
         params = {"features": features, "reduction_ratio": 3, "mlp_params": {"dims": [32]}}
+    elif model_name == 'AutoInt':
+        params = {"sparse_features": sparse_feats, "dense_features": dense_feats, "num_layers": 3, "num_heads": 2, "dropout": 0.0, "mlp_params": {"dims": [32]}}
     elif model_name in ["DeepFFM", "FatDeepFFM"]:
         # DeepFFM needs special features
         ffm_feats = [SparseFeature(f.name, f.vocab_size, 16) for f in sparse_feats]

torch_rechub/basic/layers.py

Lines changed: 74 additions & 0 deletions
@@ -719,3 +719,77 @@ def forward(self, em):
         # [batch_size, num_field_crosses, embed_dim]
         aem = s.unsqueeze(-1) * em
         return aem.flatten(start_dim=1)
+
+
+class InteractingLayer(nn.Module):
+    """Multi-head self-attention based interacting layer, used in the AutoInt model.
+
+    Args:
+        embed_dim (int): the embedding dimension.
+        num_heads (int): the number of attention heads (default=2).
+        dropout (float): the dropout rate (default=0.0).
+        residual (bool): whether to use a residual connection (default=True).
+
+    Shape:
+        - Input: `(batch_size, num_fields, embed_dim)`
+        - Output: `(batch_size, num_fields, embed_dim)`
+    """
+
+    def __init__(self, embed_dim, num_heads=2, dropout=0.0, residual=True):
+        super().__init__()
+        if embed_dim % num_heads != 0:
+            raise ValueError("embed_dim must be divisible by num_heads")
+
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.scale = self.head_dim**-0.5
+        self.residual = residual
+
+        self.W_Q = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.W_K = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.W_V = nn.Linear(embed_dim, embed_dim, bias=False)
+
+        # Projection for the residual connection
+        self.W_Res = nn.Linear(embed_dim, embed_dim, bias=False) if residual else None
+        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
+
+    def forward(self, x):
+        """
+        Args:
+            x: input tensor with shape (batch_size, num_fields, embed_dim)
+        """
+        batch_size, num_fields, embed_dim = x.shape
+
+        # Linear projections
+        Q = self.W_Q(x)  # (batch_size, num_fields, embed_dim)
+        K = self.W_K(x)  # (batch_size, num_fields, embed_dim)
+        V = self.W_V(x)  # (batch_size, num_fields, embed_dim)
+
+        # Reshape for multi-head attention
+        # (batch_size, num_heads, num_fields, head_dim)
+        Q = Q.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+        K = K.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+        V = V.view(batch_size, num_fields, self.num_heads, self.head_dim).transpose(1, 2)
+
+        # Scaled dot-product attention
+        # (batch_size, num_heads, num_fields, num_fields)
+        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
+        attn_weights = F.softmax(attn_scores, dim=-1)
+
+        if self.dropout is not None:
+            attn_weights = self.dropout(attn_weights)
+
+        # Apply attention to values
+        # (batch_size, num_heads, num_fields, head_dim)
+        attn_output = torch.matmul(attn_weights, V)
+
+        # Concatenate heads
+        # (batch_size, num_fields, embed_dim)
+        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, num_fields, embed_dim)
+
+        # Residual connection
+        if self.residual and self.W_Res is not None:
+            attn_output = attn_output + self.W_Res(x)
+
+        return F.relu(attn_output)
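
A quick standalone shape check for the layer added above, assuming it is importable from torch_rechub.basic.layers once this commit lands. With embed_dim=16 and num_heads=2, each head attends over the fields with head_dim = 16 // 2 = 8, and the heads are concatenated back to embed_dim before the residual projection and ReLU:

    import torch
    from torch_rechub.basic.layers import InteractingLayer

    layer = InteractingLayer(embed_dim=16, num_heads=2, dropout=0.1)
    x = torch.randn(4, 10, 16)           # (batch_size=4, num_fields=10, embed_dim=16)
    out = layer(x)
    assert out.shape == (4, 10, 16)      # attention preserves the (batch, fields, embed) shape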

torch_rechub/models/ranking/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
-__all__ = ['WideDeep', 'DeepFM', 'DCN', 'DCNv2', 'EDCN', 'AFM', 'FiBiNet', 'DeepFFM', 'BST', 'DIN', 'DIEN', 'FatDeepFFM']
+__all__ = ['WideDeep', 'DeepFM', 'DCN', 'DCNv2', 'EDCN', 'AFM', 'FiBiNet', 'DeepFFM', 'BST', 'DIN', 'DIEN', 'FatDeepFFM', 'AutoInt']

 from .afm import AFM
+from .autoint import AutoInt
 from .bst import BST
 from .dcn import DCN
 from .dcn_v2 import DCNv2

torch_rechub/models/ranking/autoint.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+"""
+Date: create on 14/11/2025
+References:
+    paper: (CIKM'2019) AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks
+    url: https://arxiv.org/abs/1810.11921
+Authors: Yang Zhou, [email protected]
+"""
+
+import torch
+import torch.nn as nn
+
+from ...basic.layers import LR, MLP, EmbeddingLayer, InteractingLayer
+
+
+class AutoInt(torch.nn.Module):
+    """AutoInt Model
+
+    Args:
+        sparse_features (list): the list of `SparseFeature` Class
+        dense_features (list): the list of `DenseFeature` Class
+        num_layers (int): number of interacting layers
+        num_heads (int): number of attention heads
+        dropout (float): dropout rate for attention
+        mlp_params (dict): parameters for MLP, keys: {"dims": list, "activation": str, "dropout": float, "output_layer": bool}
+    """
+
+    def __init__(self, sparse_features, dense_features, num_layers=3, num_heads=2, dropout=0.0, mlp_params=None):
+        super(AutoInt, self).__init__()
+        self.sparse_features = sparse_features
+        self.dense_features = dense_features if dense_features is not None else []
+        if len(self.sparse_features) == 0:
+            raise ValueError("AutoInt requires at least one sparse feature to determine embed_dim.")
+        # all sparse features are assumed to share the same embedding dimension
+        self.embed_dim = self.sparse_features[0].embed_dim
+
+        # field nums = sparse + dense
+        self.num_sparse = len(self.sparse_features)
+        self.num_dense = len(self.dense_features)
+        self.num_fields = self.num_sparse + self.num_dense
+
+        # total dims = num_fields * embed_dim
+        self.dims = self.num_fields * self.embed_dim
+        self.num_layers = num_layers
+
+        self.sparse_embedding = EmbeddingLayer(self.sparse_features)
+
+        # dense feature embedding: project each scalar dense value to embed_dim
+        self.dense_embeddings = nn.ModuleDict()
+        for fea in self.dense_features:
+            self.dense_embeddings[fea.name] = nn.Linear(1, self.embed_dim, bias=False)
+
+        self.interacting_layers = torch.nn.ModuleList([InteractingLayer(self.embed_dim, num_heads=num_heads, dropout=dropout, residual=True) for _ in range(num_layers)])
+
+        self.linear = LR(self.dims)
+
+        self.attn_linear = nn.Linear(self.dims, 1)
+
+        if mlp_params is not None:
+            self.use_mlp = True
+            self.mlp = MLP(self.dims, **mlp_params)
+        else:
+            self.use_mlp = False
+
+    def forward(self, x):
+        # sparse feature embedding: [B, num_sparse, embed_dim]
+        sparse_emb = self.sparse_embedding(x, self.sparse_features, squeeze_dim=False)
+
+        dense_emb_list = []
+        for fea in self.dense_features:
+            v = x[fea.name].float().view(-1, 1, 1)
+            dense_emb = self.dense_embeddings[fea.name](v)  # [B, 1, embed_dim]
+            dense_emb_list.append(dense_emb)
+
+        if len(dense_emb_list) > 0:
+            dense_emb = torch.cat(dense_emb_list, dim=1)  # [B, num_dense, d]
+            embed_x = torch.cat([sparse_emb, dense_emb], dim=1)  # [B, num_fields, d]
+        else:
+            embed_x = sparse_emb  # [B, num_sparse, d]
+
+        embed_x_flatten = embed_x.flatten(start_dim=1)  # [B, num_fields * embed_dim]
+
+        # Multi-head self-attention layers
+        attn_out = embed_x
+        for layer in self.interacting_layers:
+            attn_out = layer(attn_out)  # [B, num_fields, embed_dim]
+
+        # Attention output projection
+        attn_out_flatten = attn_out.flatten(start_dim=1)  # [B, num_fields * embed_dim]
+        y_attn = self.attn_linear(attn_out_flatten)  # [B, 1]
+
+        # Linear part
+        y_linear = self.linear(embed_x_flatten)  # [B, 1]
+
+        y = y_attn + y_linear
+        # Deep MLP part
+        if self.use_mlp:
+            y_deep = self.mlp(embed_x_flatten)  # [B, 1]
+            y = y + y_deep
+
+        return torch.sigmoid(y.squeeze(1))
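
The forward pass takes the same dict-of-tensors batch that the other torch-rechub ranking models use: one tensor per feature name, which is how x[fea.name] is indexed above. A minimal forward sketch with hypothetical feature names and shapes (not part of this diff):

    import torch
    from torch_rechub.basic.features import DenseFeature, SparseFeature
    from torch_rechub.models.ranking import AutoInt

    dense_feas = [DenseFeature("price")]                 # hypothetical dense field
    sparse_feas = [SparseFeature("user_id", 1000, 16),   # (name, vocab_size, embed_dim)
                   SparseFeature("item_id", 500, 16)]
    model = AutoInt(sparse_features=sparse_feas, dense_features=dense_feas,
                    num_layers=2, num_heads=2, dropout=0.0, mlp_params={"dims": [32]})

    batch = {
        "price": torch.rand(8),                          # dense values, shape (B,)
        "user_id": torch.randint(0, 1000, (8,)),         # sparse ids, shape (B,)
        "item_id": torch.randint(0, 500, (8,)),
    }
    y = model(batch)                                     # shape (8,), sigmoid scores in (0, 1)

Since embed_dim is taken from the first sparse feature, all sparse features should share one embedding dimension; dense features are projected to that same dimension by their per-field nn.Linear(1, embed_dim) before entering the interacting layers.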
