🚀 30+ 主流推荐模型 | 🎯 开箱即用 | 📦 一键部署 ONNX | 🤖 支持生成式推荐 (HSTU/HLLM)
English | 简体中文
在线文档: https://datawhalechina.github.io/torch-rechub/ (英文)| https://datawhalechina.github.io/torch-rechub/zh/ (简体中文)
Torch-RecHub —— 10 行代码实现工业级推荐系统。30+ 主流模型开箱即用,支持一键 ONNX 部署,让你专注于业务而非工程。
| 特性 | Torch-RecHub | 其他框架 |
|---|---|---|
| 代码行数 | 10行 完成训练+评估+部署 | 100+ 行 |
| 模型覆盖 | 30+ 主流模型 | 有限 |
| 生成式推荐 | ✅ HSTU/HLLM (Meta 2024) | ❌ |
| ONNX 一键导出 | ✅ 内置支持 | 需手动适配 |
| 学习曲线 | 极低 | 陡峭 |
- 模块化设计: 易于添加新的模型、数据集和评估指标。
- 基于 PyTorch: 利用 PyTorch 的动态图和 GPU 加速能力。
- 丰富的模型库: 涵盖 30+ 经典和前沿推荐算法(召回、排序、多任务、生成式推荐等)。
- 标准化流程: 提供统一的数据加载、训练和评估流程。
- 易于配置: 通过配置文件或命令行参数轻松调整实验设置。
- 可复现性: 旨在确保实验结果的可复现性。
- ONNX 导出: 支持将训练好的模型导出为 ONNX 格式,便于部署到生产环境。
- 其他特性: 例如,支持负采样、多任务学习等。
- Python 3.9+
- PyTorch 1.7+ (建议使用支持 CUDA 的版本以获得 GPU 加速)
- NumPy
- Pandas
- SciPy
- Scikit-learn
稳定版(推荐用户使用):
pip install torch-rechub最新版:
# 首先安装 uv(如果尚未安装)
pip install uv
# 克隆并安装
git clone https://github.com/datawhalechina/torch-rechub.git
cd torch-rechub
uv sync以下是一个简单的示例,展示如何在 MovieLens 数据集上训练模型(例如 DSSM):
# 克隆仓库(如果使用最新版)
git clone https://github.com/datawhalechina/torch-rechub.git
cd torch-rechub
uv sync
# 运行示例
python examples/matching/run_ml_dssm.py
# 或使用自定义参数:
python examples/matching/run_ml_dssm.py --model_name dssm --device 'cuda:0' --learning_rate 0.001 --epoch 50 --batch_size 4096 --weight_decay 0.0001 --save_dir 'saved/dssm_ml-100k'训练完成后,模型文件将保存在 saved/dssm_ml-100k 目录下(或你配置的其他目录)。
torch-rechub/ # 根目录
├── README.md # 项目文档
├── pyproject.toml # 项目配置和依赖
├── torch_rechub/ # 核心代码库
│ ├── basic/ # 基础组件
│ │ ├── activation.py # 激活函数
│ │ ├── features.py # 特征工程
│ │ ├── layers.py # 神经网络层
│ │ ├── loss_func.py # 损失函数
│ │ └── metric.py # 评估指标
│ ├── models/ # 推荐模型实现
│ │ ├── matching/ # 召回模型(DSSM/MIND/GRU4Rec等)
│ │ ├── ranking/ # 排序模型(WideDeep/DeepFM/DIN等)
│ │ └── multi_task/ # 多任务模型(MMoE/ESMM等)
│ ├── trainers/ # 训练框架
│ │ ├── ctr_trainer.py # CTR预测训练器
│ │ ├── match_trainer.py # 召回模型训练器
│ │ └── mtl_trainer.py # 多任务学习训练器
│ └── utils/ # 工具函数
│ ├── data.py # 数据处理工具
│ ├── match.py # 召回工具
│ ├── mtl.py # 多任务工具
│ └── onnx_export.py # ONNX 导出工具
├── examples/ # 示例脚本
│ ├── matching/ # 召回任务示例
│ ├── ranking/ # 排序任务示例
│ └── generative/ # 生成式推荐示例(HSTU、HLLM 等)
├── docs/ # 文档(VitePress,多语言)
├── tutorials/ # Jupyter教程
├── tests/ # 单元测试
├── config/ # 配置文件
└── scripts/ # 工具脚本
本框架目前支持 30+ 主流推荐模型:
| 模型 | 论文 | 简介 |
|---|---|---|
| DeepFM | IJCAI 2017 | FM + Deep 联合训练 |
| Wide&Deep | DLRS 2016 | 记忆 + 泛化能力结合 |
| DCN | KDD 2017 | 显式特征交叉网络 |
| DCN-v2 | WWW 2021 | 增强版交叉网络 |
| DIN | KDD 2018 | 注意力机制捕捉用户兴趣 |
| DIEN | AAAI 2019 | 兴趣演化建模 |
| BST | DLP-KDD 2019 | Transformer 序列建模 |
| AFM | IJCAI 2017 | 注意力因子分解机 |
| AutoInt | CIKM 2019 | 自动特征交互学习 |
| FiBiNET | RecSys 2019 | 特征重要性 + 双线性交互 |
| DeepFFM | RecSys 2019 | 场感知因子分解机 |
| EDCN | KDD 2021 | 增强型交叉网络 |
| 模型 | 论文 | 简介 |
|---|---|---|
| DSSM | CIKM 2013 | 经典双塔召回模型 |
| YoutubeDNN | RecSys 2016 | YouTube 深度召回 |
| YoutubeSBC | RecSys 2019 | 采样偏差校正版本 |
| MIND | CIKM 2019 | 多兴趣动态路由 |
| SINE | WSDM 2021 | 稀疏兴趣网络 |
| GRU4Rec | ICLR 2016 | GRU 序列推荐 |
| SASRec | ICDM 2018 | 自注意力序列推荐 |
| NARM | CIKM 2017 | 神经注意力会话推荐 |
| STAMP | KDD 2018 | 短期注意力记忆优先 |
| ComiRec | KDD 2020 | 可控多兴趣推荐 |
| 模型 | 论文 | 简介 |
|---|---|---|
| ESMM | SIGIR 2018 | 全空间多任务建模 |
| MMoE | KDD 2018 | 多门控专家混合 |
| PLE | RecSys 2020 | 渐进式分层提取 |
| AITM | KDD 2021 | 自适应信息迁移 |
| SharedBottom | - | 经典多任务共享底层 |
| 模型 | 论文 | 简介 |
|---|---|---|
| HSTU | Meta 2024 | 层级序列转换单元,支撑 Meta 万亿参数推荐系统 |
| HLLM | 2024 | 层级大语言模型推荐,融合 LLM 语义理解能力 |
框架内置了对以下常见数据集格式的支持或提供了处理脚本:
- MovieLens
- Amazon
- Criteo
- Avazu
- Census-Income
- BookCrossing
- Ali-ccp
- Yidian
- ...
我们期望的数据格式通常是包含以下字段的交互文件:
- 用户 ID
- 物品 ID
- 评分(可选)
- 时间戳(可选)
具体格式要求请参考 tutorials 目录下的示例代码。
你可以方便地集成你自己的数据集,只需确保它符合框架要求的数据格式,或编写自定义的数据加载器。
所有模型使用案例参考 /examples
from torch_rechub.models.ranking import DeepFM
from torch_rechub.trainers import CTRTrainer
from torch_rechub.utils.data import DataGenerator
dg = DataGenerator(x, y)
train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=256)
model = DeepFM(deep_features=deep_features, fm_features=fm_features, mlp_params={"dims": [256, 128], "dropout": 0.2, "activation": "relu"})
ctr_trainer = CTRTrainer(model)
ctr_trainer.fit(train_dataloader, val_dataloader)
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
ctr_trainer.export_onnx("deepfm.onnx")from torch_rechub.models.multi_task import SharedBottom, ESMM, MMOE, PLE, AITM
from torch_rechub.trainers import MTLTrainer
task_types = ["classification", "classification"]
model = MMOE(features, task_types, 8, expert_params={"dims": [32,16]}, tower_params_list=[{"dims": [32, 16]}, {"dims": [32, 16]}])
mtl_trainer = MTLTrainer(model)
mtl_trainer.fit(train_dataloader, val_dataloader)
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
mtl_trainer.export_onnx("mmoe.onnx")from torch_rechub.models.matching import DSSM
from torch_rechub.trainers import MatchTrainer
from torch_rechub.utils.data import MatchDataGenerator
dg = MatchDataGenerator(x, y)
train_dl, test_dl, item_dl = dg.generate_dataloader(test_user, all_item, batch_size=256)
model = DSSM(user_features, item_features, temperature=0.02,
user_params={
"dims": [256, 128, 64],
"activation": 'prelu',
},
item_params={
"dims": [256, 128, 64],
"activation": 'prelu',
})
match_trainer = MatchTrainer(model)
match_trainer.fit(train_dl)
match_trainer.export_onnx("dssm.onnx")
# 双塔模型可分别导出用户塔和物品塔:
# match_trainer.export_onnx("user_tower.onnx", mode="user")
# match_trainer.export_onnx("dssm_item.onnx", tower="item")# 可视化模型架构(需要安装: pip install torch-rechub[visualization])
graph = ctr_trainer.visualization(depth=4) # 生成计算图
ctr_trainer.visualization(save_path="model.pdf", dpi=300) # 保存为高清 PDF感谢所有的贡献者!
我们欢迎各种形式的贡献!请查看 CONTRIBUTING.md 了解详细的贡献指南。
我们也欢迎通过 Issues 报告 Bug 或提出功能建议。
本项目采用 MIT 许可证。
如果你在研究或工作中使用了本框架,请考虑引用:
@misc{torch_rechub,
title = {Torch-RecHub},
author = {Datawhale},
year = {2022},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/datawhalechina/torch-rechub}},
note = {A PyTorch-based recommender system framework providing easy-to-use and extensible solutions}
}- 项目负责人: 1985312383
- GitHub Disscussions
最后更新: [2025-12-04]
