forked from thinking-machines-lab/batch_invariant_ops
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_batch_invariance.py
More file actions
36 lines (28 loc) · 1.1 KB
/
test_batch_invariance.py
File metadata and controls
36 lines (28 loc) · 1.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from batch_invariant_ops import set_batch_invariant_mode
# Make 'cuda' the default device for tensors created without an explicit
# device (e.g. the torch.linspace calls in the check below).
torch.set_default_device(device='cuda')

# Enter and immediately leave batch-invariant mode once, up front, so that
# any one-time output produced on first activation (the original author
# describes it as logging; confirm in batch_invariant_ops) is emitted here
# instead of being interleaved with the comparison results.
with set_batch_invariant_mode(True):
    pass
def test_batch_invariance(batch_size: int = 2048, dim: int = 4096) -> bool:
    """Check whether row 0 of a ``torch.mm`` result depends on the batch size.

    Row 0 of ``a @ b`` is computed twice: once from a 1-row slice of ``a``,
    and once from the full ``batch_size``-row product (then sliced to its
    first row). The two results are compared for exact equality and the
    maximum absolute difference is printed.

    Tensors are created on torch's current default device.

    Args:
        batch_size: Number of rows of ``a`` used for the full-batch product.
        dim: Inner and output dimension; ``b`` is ``dim x dim``.

    Returns:
        True if both computations of row 0 are bitwise identical, else False.
        NOTE(review): pytest collects this function by name but does not
        assert on the return value, so a False result does not fail a
        pytest run; the bool is intended for the script-style driver.

    Raises:
        ValueError: If ``batch_size`` or ``dim`` is less than 1.
    """
    if batch_size < 1 or dim < 1:
        raise ValueError(
            f"batch_size and dim must be >= 1, got batch_size={batch_size}, dim={dim}"
        )

    a = torch.linspace(-100, 100, batch_size * dim).reshape(batch_size, dim)
    b = torch.linspace(-100, 100, dim * dim).reshape(dim, dim)

    # Method 1: a batch containing only the first row (M=1 matmul).
    out1 = torch.mm(a[:1], b)
    # Method 2: full-batch matmul, then keep only the first row.
    out2 = torch.mm(a, b)[:1]

    # Exact comparison on purpose: any nonzero difference means the value of
    # row 0 depended on how many other rows were in the batch.
    diff = (out1 - out2).abs().max().item()
    print(f"Difference: {diff}")
    return diff == 0
# Run the same check twice: first with stock PyTorch kernels, where the
# one-row and full-batch results are expected to possibly differ, then with
# batch-invariant mode enabled. After the loop, `is_deterministic` holds the
# result of the batch-invariant run.
# NOTE(review): the printed label says "Deterministic", but what the check
# measures is batch invariance (same row, different batch sizes).
for _invariant, _header in (
    (False, "Standard PyTorch:"),
    (True, "\nBatch-Invariant Mode:"),
):
    print(_header)
    with set_batch_invariant_mode(_invariant):
        is_deterministic = test_batch_invariance()
        print(f"Deterministic: {is_deterministic}")