Skip to content

Commit 222d87d

Browse files
authored
Merge pull request #122 from 1985312383/main
Add HSTU model
2 parents b7a48ad + c30a1ab commit 222d87d

File tree

13 files changed

+1815
-3
lines changed

13 files changed

+1815
-3
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ jobs:
6060
- name: Install Dependencies
6161
run: |
6262
pip install --upgrade pip
63-
# 安装特定版本的格式化工具以确保一致性
64-
pip install yapf==0.32.0 isort==5.10.1 flake8>=3.8.0 mypy>=0.800 toml>=0.10.2
63+
# 安装特定版本的格式化工具以确保一致性(与 pyproject.toml 保持一致)
64+
pip install yapf==0.43.0 isort==5.13.2 flake8>=3.8.0 mypy>=0.800 toml>=0.10.2
6565
6666
- name: Format & Lint
6767
run: |

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ share/python-wheels/
2727
*.egg-info/
2828
.installed.cfg
2929
*.egg
30+
*.pt
31+
*.pth
32+
*.pkl
33+
*.dat
3034
MANIFEST
3135

3236
# PyInstaller

docs/zh/blog/hstu_reproduction.md

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
## HSTU 模型在 torch-rechub 中的复现说明
2+
3+
本文件总结 torch-rechub 中对 Meta HSTU(Hierarchical Sequential Transduction Units)模型的复现情况,重点说明:
4+
5+
- 当前实现的整体架构与关键设计细节;
6+
- 与 Meta 官方开源实现/论文的一致之处;
7+
- 有意简化或仍然存在差异的部分。
8+
9+
---
10+
11+
## 1. 整体架构概览
12+
13+
### 1.1 模块划分
14+
15+
与 HSTU 相关的主要模块如下:
16+
17+
- **模型主体**`torch_rechub/models/generative/hstu.py`
18+
- `HSTUModel`:Embedding + HSTUBlock + 输出投影
19+
- **核心层与 Block**`torch_rechub/basic/layers.py`
20+
- `HSTULayer`:单层 HSTU 转导单元(多头注意力 + 门控 + FFN)
21+
- `HSTUBlock`:多层 HSTULayer 堆叠
22+
- **相对位置偏置与词表工具**`torch_rechub/utils/hstu_utils.py`
23+
- `RelPosBias``VocabMask``VocabMapper`
24+
- **时间感知数据预处理**`examples/generative/data/ml-1m/preprocess_ml_hstu.py`
25+
- **数据集与数据生成器**`torch_rechub/utils/data.py`
26+
- `SeqDataset``SequenceDataGenerator`
27+
- **训练与评估**
28+
- `torch_rechub/trainers/seq_trainer.py``SeqTrainer`
29+
- `examples/generative/run_hstu_movielens.py`:示例脚本、评估指标
30+
31+
### 1.2 数据与任务
32+
33+
- 数据集:MovieLens-1M `ratings.dat`(包含时间戳)
34+
- 任务形式:**Next-item prediction**(给定历史序列,预测下一个 item)
35+
- 训练目标:自回归式的 next-token 交叉熵损失(仅使用序列最后一个位置的 logits)
36+
- 评估指标:HR@K、NDCG@K(K=10, 50, 200)
37+
38+
---
39+
40+
## 2. HSTULayer 与 HSTUBlock 实现细节
41+
42+
### 2.1 HSTULayer:核心转导单元
43+
44+
`torch_rechub/basic/layers.py::HSTULayer` 实现了论文中的“Sequential Transduction Unit”核心思想:
45+
46+
1. **输入与线性投影**
47+
- 输入形状:`(B, L, D)`
48+
- 通过 `proj1: Linear(D → 2·H·dqk + 2·H·dv)` 同时产生 Q / K / U / V:
49+
- Q, K 形状:`(B, H, L, dqk)`
50+
- U, V 形状:`(B, H, L, dv)`
51+
52+
2. **多头自注意力 + causal mask**
53+
- 注意力打分:`scores = (Q @ K^T) / sqrt(dqk)`,形状 `(B, H, L, L)`
54+
- 使用严格的 **causal mask**:位置 i 只能看到 `≤ i` 的 token,防止未来信息泄露。
55+
- 可选加上相对位置偏置 `RelPosBias`
56+
- softmax 后得到 `attn_weights`,再与 V 相乘得到 `attn_output`
57+
58+
3. **门控机制(Gated Attention)**
59+
- 将注意力输出 `attn_output` 与门控向量 U 进行逐元素门控:
60+
- `gated_output = attn_output * sigmoid(U)`,形状 `(B, L, H·dv)`
61+
62+
4. **输出投影与残差 + FFN**
63+
- 使用 `proj2: Linear(H·dv → D)` 将多头输出还原到模型维度。
64+
- 两个残差块:
65+
1. 自注意力 + 门控 + 投影 + Dropout + 残差
66+
2. LayerNorm + FFN(4D) + Dropout + 残差
67+
- 使用 `LayerNorm` 做 pre-norm,提升深层训练稳定性。
68+
69+
### 2.2 HSTUBlock:多层堆叠
70+
71+
`HSTUBlock` 是多个 `HSTULayer` 的简单堆叠:
72+
73+
- 初始化时构建 `n_layers` 个 HSTULayer;
74+
- 前向传播中按顺序依次传递;
75+
- 未做层间不同窗口/不同参数共享的“显式层级结构”,这一点属于对论文中“Hierarchical”概念的工程化简化。
76+
77+
这一设计与 Meta 官方开源代码的风格一致:通过多层堆叠来实现逐层抽象的“层级”表示,而不是显式的多分辨率分支。
78+
79+
---
80+
81+
## 3. 时间戳建模与时间嵌入
82+
83+
### 3.1 数据预处理中的时间差计算
84+
85+
文件:`examples/generative/data/ml-1m/preprocess_ml_hstu.py`
86+
87+
核心设计:
88+
89+
- 对每个用户的交互序列,使用滑动窗口生成 `(history, target)` 样本:
90+
- history = 序列前缀;target = 当前 prefix 之后的一个 item;
91+
- 对于每个 history,计算 **相对于查询时间的时间差**
92+
- 查询时间 = history 中最后一个事件的时间戳 `query_timestamp`
93+
- 对每个历史事件 `ts`,时间差为 `query_timestamp - ts`
94+
- 例如时间戳 `[100, 200, 300, 400]` → 时间差 `[300, 200, 100, 0]`
95+
- 时间差以秒为单位保存为 `seq_time_diffs`,与 `seq_tokens` 同长;
96+
- 所有序列截断/左侧 padding 到固定长度 `max_seq_len`,padding 的时间差为 0。
97+
98+
这与 Meta 官方 HSTU 代码中 `query_time - timestamps` 的处理方式保持一致,而不是相邻事件时间间隔的形式。
99+
100+
### 3.2 模型中的时间嵌入与 bucket 化
101+
102+
文件:`torch_rechub/models/generative/hstu.py`
103+
104+
1. **时间嵌入表**
105+
- `self.time_embedding = nn.Embedding(num_time_buckets + 1, d_model, padding_idx=0)`
106+
- 其中 bucket 0 作为 padding bucket。
107+
108+
2. **时间差 → bucket 的映射**
109+
110+
```python
111+
# 伪代码
112+
# 1) 秒 → 分钟
113+
minutes = time_diffs.float() / 60.0
114+
# 2) 避免 log(0)
115+
minutes = clamp(minutes, min=1e-6)
116+
# 3) 按 sqrt 或 log 映射到 bucket
117+
if fn == 'sqrt':
118+
bucket = sqrt(minutes)
119+
elif fn == 'log':
120+
bucket = log(minutes)
121+
# 4) 截断到 [0, num_time_buckets-1]
122+
```
123+
124+
3. **嵌入融合与 Alpha 缩放**
125+
126+
- Token Embedding 使用 Alpha 缩放:`token_emb = token_embedding(x) * sqrt(d_model)`
127+
- Position Embedding 为标准的绝对位置嵌入;
128+
- Time Embedding 通过上述 bucket 索引查表得到;
129+
- 最终序列表示:`embeddings = token_emb + pos_emb + time_emb`
130+
131+
这部分在最近一次提交中完成了对 Meta 官方实现的细节对齐:
132+
133+
- 修复了时间差计算方式(由相邻间隔 → 与查询时间差);
134+
- 增加了 `/60.0` 的时间单位转换;
135+
- 增加了 `alpha = sqrt(d_model)` 的缩放。
136+
137+
---
138+
139+
## 4. 训练与评估流水线
140+
141+
### 4.1 SeqDataset 与 SequenceDataGenerator
142+
143+
文件:`torch_rechub/utils/data.py`
144+
145+
- 近期提交中已**移除旧 3 元组格式的向后兼容逻辑**,统一为 4 元组:
146+
- `(seq_tokens, seq_positions, seq_time_diffs, targets)`
147+
- `SeqDataset` 负责将 NumPy 数组转换为 PyTorch 张量;
148+
- `SequenceDataGenerator` 根据给定的 train/val/test 划分构造 DataLoader。
149+
150+
### 4.2 SeqTrainer:训练与评估
151+
152+
文件:`torch_rechub/trainers/seq_trainer.py`
153+
154+
- `train_one_epoch`
155+
- 输入 batch 形如:`(seq_tokens, seq_positions, seq_time_diffs, targets)`
156+
- 将张量移动到设备;
157+
- 调用 `model(seq_tokens, seq_time_diffs)` 得到 `(B, L, V)` logits;
158+
- 只取最后一个位置 `logits[:, -1, :]``targets` 做交叉熵损失;
159+
- `evaluate`
160+
- 与训练阶段类似,同样只使用序列最后一个位置;
161+
- 统计平均 loss 与 top-1 准确率,用于早停与模型选择。
162+
163+
### 4.3 示例脚本与推荐指标
164+
165+
文件:`examples/generative/run_hstu_movielens.py`
166+
167+
- 负责加载预处理好的 MovieLens 数据(真实数据),构造数据加载器与模型;
168+
- 使用 `SeqTrainer` 进行训练与验证;
169+
- `evaluate_ranking` 函数在测试集上计算 HR@K 与 NDCG@K:
170+
- 模型同样使用最后一个位置的 logits;
171+
- 对所有候选 item 排序,计算 top-K 命中率与折损累计增益。
172+
173+
近期在修复时间戳处理逻辑后,测试集指标相比旧实现有显著提升(以 K=10 为例):
174+
175+
- HR@10:约从 0.17 提升到 0.21+
176+
- NDCG@10:约从 0.08 提升到 0.11+
177+
178+
这表明时间衰减建模对生成式推荐效果有明显正向作用。
179+
180+
---
181+
182+
## 5. 与 Meta 官方实现的一致性与差异
183+
184+
### 5.1 主要一致点
185+
186+
与 Meta 官方 HSTU / DLRM-HSTU 实现相比,本框架在以下方面保持较高一致性:
187+
188+
- **核心层结构**:HSTULayer 采用 Q/K/V/U 四路线性投影、多头注意力、门控机制与两段残差 FFN,结构上与官方实现高度一致;
189+
- **因果掩码**:在注意力打分阶段使用严格的 causal mask,保证生成式任务的因果性;
190+
- **时间差定义**:使用 `query_time - timestamps` 形式的时间差,而非相邻事件间隔;
191+
- **时间 bucket 化与嵌入**:支持 sqrt/log 两种 bucket 映射,配合时间嵌入表,与官方思路对齐;
192+
- **Alpha 缩放**:对 token embedding 乘以 `sqrt(d_model)`,与官方实现中的缩放策略一致;
193+
- **训练目标**:自回归式的 next-item 交叉熵目标,等价于语言模型式训练。
194+
195+
### 5.2 主要差异与简化
196+
197+
目前实现仍有以下差异或有意简化:
198+
199+
1. **未包含 DLRM 与多任务头**
200+
- 官方 DLRM-HSTU 实现支持复杂的特征交叉与多任务学习;
201+
- 本框架专注于单任务的 next-item prediction,未实现 DLRM 部分与多目标头。
202+
203+
2. **相对位置偏置为简化版本**
204+
- 当前的 `RelPosBias` 基于 `|i - j|` 距离做线性分桶;
205+
- 未显式区分方向(正/负距离)、也未使用更复杂的 log-scaling bucket 公式;
206+
- 这在工程上更简单,但与官方实现存在细节差异。
207+
208+
3. **仅提供单步 next-item 预测接口**
209+
- 训练和评估阶段都是“给定完整历史 → 预测下一个 item”;
210+
- 尚未封装多步自回归解码接口(如 beam search 生成未来 N 步序列);
211+
- 对于大多数推荐 benchmark(只评估下一步)已经足够,但与“通用生成式序列模型”相比功能较少。
212+
213+
4. **部分初始化细节不同**
214+
- 当前使用 `xavier_uniform_` 初始化大部分线性层和嵌入;
215+
- 官方实现中某些嵌入可能使用基于维度的 `uniform(-sqrt(1/N), sqrt(1/N))`
216+
- 这类初始化差异对最终收敛影响有限,但不是 100% bit-level 复现。
217+
218+
---
219+
220+
## 6. 近期提交总结
221+
222+
- 引入了 HSTU 模型、HSTULayer/HSTUBlock、SeqTrainer、SeqDataset 等完整骨架;
223+
- 实现了基本的生成式 next-item 训练与评估流程;
224+
- 时间戳处理、时间嵌入与部分细节尚处于初版实现阶段。
225+
226+
- 重构 MovieLens 预处理脚本:
227+
- 使用滑动窗口策略大幅增加训练样本;
228+
- 按用户划分 train/val/test,避免数据泄漏;
229+
- 正确使用 `query_time - timestamps` 形式的时间差;
230+
- 修复时间嵌入实现:
231+
- 添加秒 → 分钟的时间单位转换;
232+
- 增加 `alpha = sqrt(d_model)` 缩放;
233+
- 与官方时间建模逻辑对齐;
234+
- 清理向后兼容逻辑:
235+
- 移除 3 元组数据格式,统一为 4 元组 `(tokens, positions, time_diffs, targets)`
236+
- 简化 SeqDataset、SequenceDataGenerator、SeqTrainer 代码结构;
237+
- 训练与评估结果显示所有排名指标均有显著提升,验证了时间建模修复的必要性和有效性。
238+
239+
---
240+
241+
## 7. 小结
242+
243+
- 当前实现已经在 **HSTU 核心层结构、时间建模与训练目标** 上与 Meta 官方实现高度对齐;
244+
- 同时刻意简化了 DLRM、多任务头、复杂特征工程等工程部分,使得该实现更适合作为研究和教学的参考版本;
245+
- 如果后续需要进一步逼近“论文级完全复现”,推荐优先完善:
246+
1. RelPosBias 的 bucket 公式与方向建模;
247+
2. padding mask 的显式支持;
248+
3. 多步自回归解码接口与更复杂的下游任务场景。
249+

0 commit comments

Comments
 (0)