Skip to content

Commit 09d1ff2

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
benchmark of fbgemm op - regroup_kts (#2159)
Summary: Pull Request resolved: #2159 # context * added **fn-level** benchmark for the `regroup_keyed_tensor` * `keyed_tensor_regroup` further reduces the CPU runtime from 2.0ms to 1.3ms (35% improvement) without hurting the GPU runtime/memory usage # conclusion * CPU runtime **reduces 40%** from 1.8 ms to 1.1 ms * GPU runtime **reduces 60%** from 4.9 ms to 2.0 ms * GPU memory **reduces 33%** from 1.5 K to 1.0 K * **we should migrate to the new op** unless any unknown concern/blocker # traces * [files](https://drive.google.com/drive/folders/1iiEf30LeG_i0xobMZVhmMneOQ5slmX3U?usp=drive_link) ``` [[email protected] /data/sandcastle/boxes/fbsource (04ad34da3)]$ ll *.json -rw-r--r-- 1 hhy hhy 552501 Jul 10 16:01 'trace-[1 Op] KT_regroup_dup.json' -rw-r--r-- 1 hhy hhy 548847 Jul 10 16:01 'trace-[1 Op] KT_regroup.json' -rw-r--r-- 1 hhy hhy 559006 Jul 10 16:01 'trace-[2 Ops] permute_multi_embs_dup.json' -rw-r--r-- 1 hhy hhy 553199 Jul 10 16:01 'trace-[2 Ops] permute_multi_embs.json' -rw-r--r-- 1 hhy hhy 5104239 Jul 10 16:01 'trace-[Module] KTRegroupAsDict_dup.json' -rw-r--r-- 1 hhy hhy 346643 Jul 10 16:01 'trace-[Module] KTRegroupAsDict.json' -rw-r--r-- 1 hhy hhy 895096 Jul 10 16:01 'trace-[Old Prod] permute_pooled_embs.json' -rw-r--r-- 1 hhy hhy 561685 Jul 10 16:01 'trace-[Prod] KeyedTensor.regroup_dup.json' -rw-r--r-- 1 hhy hhy 559147 Jul 10 16:01 'trace-[Prod] KeyedTensor.regroup.json' -rw-r--r-- 1 hhy hhy 7958676 Jul 10 16:01 'trace-[pytorch generic] fallback_dup.json' -rw-r--r-- 1 hhy hhy 7978141 Jul 10 16:01 'trace-[pytorch generic] fallback.json' ``` * pytorch generic {F1755208341} * current prod {F1755209251} * permute_multi_embedding (2 Ops) {F1755210682} * KT.regroup (1 Op) {F1755210008} * regroupAsDict (Module) {F1755210990} * metrics |Operator|CPU runtime|GPU runtime|GPU memory|notes| |---|---|---|---|---| |**[fallback] pytorch generic**|3.9 ms|3.2 ms|1.0 K|CPU-bounded, allow duplicates| |**[prod] _fbgemm_permute_pooled_embs**|1.9 ms|4.9 ms|1.5 K|GPU-bounded, 
does **NOT** allow duplicates, PT2 non-compatible `pin_and_move`| |**[hybrid python/cu] keyed_tensor_regroup**|1.5 ms|2.0 ms|1.0 K|both GPU runtime and memory improved, **ALLOW** duplicates, PT2 friendly| |**[pure c++/cu] permute_multi_embedding**|1.0 ms|2.0 ms|1.0 K|both CPU and GPU runtime/memory improved, **ALLOW** duplicates, PT2 friendly| Reviewed By: dstaay-fb Differential Revision: D58907223 fbshipit-source-id: 108ce355b9191cba6fe6a79e54dc7291b8463f7b
1 parent 4f114bc commit 09d1ff2

File tree

2 files changed

+42
-4
lines changed

2 files changed

+42
-4
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,19 @@ def permute_multi_embedding(
188188
return permuted_values
189189

190190

191+
@torch.fx.wrap
def regroup_kts(
    keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
) -> List[torch.Tensor]:
    """Regroup keyed tensors through the fbgemm ``regroup_keyed_tensor`` op.

    Args:
        keyed_tensors: the KeyedTensors to regroup; they are broken down by
            ``_desugar_keyed_tensors`` before being handed to the op.
        groups: the requested groups of keys, forwarded to the op unchanged.

    Returns:
        The result of ``torch.ops.fbgemm.regroup_keyed_tensor`` (annotated
        as a list of tensors).
    """
    # The helper returns (keys, lengths, values); the op expects the values
    # first, followed by keys and lengths.
    desugared = _desugar_keyed_tensors(keyed_tensors)
    kt_keys, kt_lengths, kt_values = desugared
    return torch.ops.fbgemm.regroup_keyed_tensor(
        kt_values, kt_keys, kt_lengths, groups
    )
202+
203+
191204
@torch.fx.wrap
192205
def _fbgemm_permute_pooled_embs(
193206
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]

torchrec/sparse/tests/jagged_tensor_benchmark.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
from torchrec.distributed.benchmark.benchmark_utils import benchmark, BenchmarkResult
1919
from torchrec.modules.regroup import KTRegroupAsDict
2020
from torchrec.sparse.jagged_tensor import (
21+
_fbgemm_permute_pooled_embs,
2122
_regroup_keyed_tensors,
2223
KeyedJaggedTensor,
2324
KeyedTensor,
2425
permute_multi_embedding,
26+
regroup_kts,
2527
)
2628
from torchrec.sparse.tests.utils import build_groups, build_kts
2729

@@ -213,7 +215,7 @@ def main(
213215
).float()
214216
groups = build_groups(kts, n_groups, duplicates=duplicates)
215217
bench(
216-
"_regroup_keyed_tenors" + dup,
218+
"[pytorch generic] fallback" + dup,
217219
labels,
218220
batch_size,
219221
n_dense + n_sparse,
@@ -224,7 +226,7 @@ def main(
224226
profile,
225227
)
226228
bench(
227-
"KeyedTensor.regroup" + dup,
229+
"[Prod] KeyedTensor.regroup" + dup,
228230
labels,
229231
batch_size,
230232
n_dense + n_sparse,
@@ -235,7 +237,7 @@ def main(
235237
profile,
236238
)
237239
bench(
238-
"KTRegroupAsDict" + dup,
240+
"[Module] KTRegroupAsDict" + dup,
239241
labels,
240242
batch_size,
241243
n_dense + n_sparse,
@@ -248,7 +250,7 @@ def main(
248250
profile,
249251
)
250252
bench(
251-
"permute_multi_embs" + dup,
253+
"[2 Ops] permute_multi_embs" + dup,
252254
labels,
253255
batch_size,
254256
n_dense + n_sparse,
@@ -258,6 +260,29 @@ def main(
258260
{"keyed_tensors": kts, "groups": groups},
259261
profile,
260262
)
263+
bench(
264+
"[1 Op] KT_regroup" + dup,
265+
labels,
266+
batch_size,
267+
n_dense + n_sparse,
268+
device_type,
269+
run_backward,
270+
regroup_kts,
271+
{"keyed_tensors": kts, "groups": groups},
272+
profile,
273+
)
274+
if not duplicates:
275+
bench(
276+
"[Old Prod] permute_pooled_embs" + dup,
277+
labels,
278+
batch_size,
279+
n_dense + n_sparse,
280+
device_type,
281+
run_backward,
282+
_fbgemm_permute_pooled_embs,
283+
{"keyed_tensors": kts, "groups": groups},
284+
profile,
285+
)
261286

262287

263288
if __name__ == "__main__":

0 commit comments

Comments
 (0)