Add multicast tensor #346
Conversation
stack-info: PR: #346, branch: joydddd/stack/17
Force-pushed from 1e986c5 to 0bcfcca
Force-pushed from 0bcfcca to 5609bbf
stack-info: PR: #346, branch: joydddd/stack/17
Force-pushed from 5609bbf to 1749db5
from .._compiler.variable_origin import Origin


class MulticastTensor(NamedTuple):
If we don't inherit from NamedTuple does this still work?
Yep, removed the inheritance.
Actually, no, we might want to keep the NamedTuple inheritance. fake_fn for hl ops is called during both type propagation and device_ir tracing. For type propagation, multicast tensors are passed in as the original MulticastTensor type, while in device_ir tracing we call prepare_args to unpack the MulticastTensor into tuples before calling fake_fn. It is nicer to make MulticastTensor a NamedTuple so that in both cases it is a tuple.
Unless we find a better way to deal with the MulticastTensor constructor in device_ir, so that we don't need to unpack it into a tuple.
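For reference, a minimal sketch of the trade-off discussed above (field names taken from the diff in this PR; the constructor call is illustrative, not the PR's code): a NamedTuple behaves as a structured type during type propagation while still unpacking as a plain tuple for prepare_args and the tracer.

```python
from typing import NamedTuple

import torch


class MulticastTensor(NamedTuple):
    tensor_like: torch.Tensor  # host tensor describing shape/dtype
    dev_ptrs: torch.Tensor     # pointer table, one entry per device


mt = MulticastTensor(torch.empty(16), torch.zeros(4, dtype=torch.int64))
tensor_like, dev_ptrs = mt    # unpacks like any tuple
assert isinstance(mt, tuple)  # tuple-based checks in the tracer still pass
```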
Force-pushed from 1749db5 to fd02b59
Force-pushed from fd02b59 to ba94f3f
@@ -289,6 +295,134 @@ def codegen_store(
    )


class MulticastIndexingStrategy:
    @staticmethod
Can you add more detail on the semantics of this indexing strategy?
return state.device_function.indexing_strategy.codegen_store(
    state, tensor, [*subscript], value, extra_mask
)
if isinstance(tensor, tuple):
I still don't totally follow why we convert to a tuple instead of keeping the multicast tensor type.
The MulticastTensor constructor is not a device function and should not show up in the device IR fx graph (we have no lowering path for it). Therefore, to keep the tracer from seeing it, we unpack it into a tuple before entering the tracer.
Any ideas on how to do this in a nicer way to avoid type checking for a tuple? Maybe add a typename item in the tuple and check for that?
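A rough sketch of the "typename item" idea floated above (the names here are made up for illustration, not part of the PR): tag the unpacked tuple with a sentinel so codegen can tell it apart from an ordinary tuple instead of relying on a bare isinstance(tensor, tuple) check.

```python
# Hypothetical sentinel-tagging scheme; _MULTICAST_TAG and these helpers
# are illustrative names, not the PR's actual API.
_MULTICAST_TAG = "helion_multicast_tensor"


def pack_multicast(tensor_like, dev_ptrs) -> tuple:
    # Unpack MulticastTensor into a tagged tuple before it reaches the tracer.
    return (_MULTICAST_TAG, tensor_like, dev_ptrs)


def is_multicast(obj: object) -> bool:
    # Cheap structural check usable at codegen time.
    return isinstance(obj, tuple) and len(obj) == 3 and obj[0] == _MULTICAST_TAG
```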
Will multicast tensors work with arbitrary operations (for example inductor ones)?
Are the semantics always that we repeat an op once for every tensor?
I worry that needing to add an `if multicast` case to every op will add a lot of complexity. Could we implement this as an FX graph pass that duplicates every op once per sub-tensor?
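For what it's worth, a hedged sketch of that kind of pass (the node.meta["multicast"] convention is invented here purely for illustration): each op that consumes a multicast node is cloned once per sub-tensor and the per-device results are stacked back together.

```python
import torch
import torch.fx as fx


def duplicate_multicast_ops(graph: fx.Graph) -> None:
    """Clone every op that consumes a multicast node, once per sub-tensor.

    Assumes a made-up convention: a multicast node lists its per-device
    sub-tensor nodes in node.meta["multicast"].
    """
    for node in list(graph.nodes):
        if node.op != "call_function":
            continue
        mcast_inputs = [
            a for a in node.args if isinstance(a, fx.Node) and "multicast" in a.meta
        ]
        if not mcast_inputs:
            continue
        src = mcast_inputs[0]
        with graph.inserting_before(node):
            clones = []
            for sub in src.meta["multicast"]:  # one fx.Node per device tensor
                args = tuple(sub if a is src else a for a in node.args)
                clones.append(graph.call_function(node.target, args, dict(node.kwargs)))
            # Re-pack the per-device results so downstream users see one tensor.
            stacked = graph.call_function(torch.stack, (clones,))
        node.replace_all_uses_with(stacked)
        graph.erase_node(node)
```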
def has_multicast_tensor_with_rdim(self, graph: torch.fx.Graph) -> bool:
    """Check if a graph contains multicast tensors with rdim inputs."""

    def is_multicast_with_rdim(node: torch.fx.Node) -> bool:
Move to global scope?
Not sure what you mean by moving to global scope? We check whether self.rdim of the roller matches any of the multicast dims. @yf225 can you take a look?
tensor_like: torch.Tensor
dev_ptrs: torch.Tensor
Can you talk about the motivation for this representation?
Multicast tensors are meant to be host tensors that can only be accessed by memory operations (namely hl.load, hl.store, hl.atomic_add, hl.signal & hl.wait). For these cases we handle indexing, masks and offset calculation manually. Arbitrary operations (for example inductor ones) apply only to device tensors returned from these memory operations. Any pointers to where I should check to make sure multicast tensors are not used for non-memory ops?
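One possible shape of such a check, purely as a hedged sketch (the function, its call site, and the name-based type test are assumptions, not code from this PR):

```python
# Hypothetical guard during type propagation: only the memory ops listed
# above may receive a MulticastTensor argument.
MEMORY_OPS = {"load", "store", "atomic_add", "signal", "wait"}


def check_multicast_usage(op_name: str, args: tuple) -> None:
    if op_name in MEMORY_OPS:
        return
    # In real code this would be isinstance(a, MulticastTensor); the name
    # check just keeps the sketch self-contained.
    if any(type(a).__name__ == "MulticastTensor" for a in args):
        raise TypeError(
            f"MulticastTensor arguments are only supported by memory ops, got {op_name!r}"
        )
```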
Primary use case for this is a shared tensor on symmetric memory, i.e. each device holds a version of the tensor. I'm considering adding more APIs to create MulticastTensors. Future extensions: a multimem pointer (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=multimem#data-movement-and-conversion-instructions-multimem) could be an alternative backend for MulticastTensor instead of a dev_ptrs tensor. However, Triton does not currently support that and we don't know the interface.
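To illustrate the dev_ptrs representation (a hedged sketch; make_dev_ptrs and the int64 pointer dtype are assumptions, not this PR's API): each entry is the raw address of one peer's copy of the tensor, e.g. buffers handed out by a symmetric-memory allocator.

```python
import torch


def make_dev_ptrs(peer_buffers: list[torch.Tensor], device: torch.device) -> torch.Tensor:
    # Collect each peer's raw data pointer into a 1D tensor of addresses.
    ptrs = [t.data_ptr() for t in peer_buffers]
    return torch.tensor(ptrs, dtype=torch.int64, device=device)
```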
But if you load a multicast tensor, don't you get a multicast device tensor? (Which you can run ops on.)
Force-pushed from ba94f3f to 7cc53a9
Force-pushed from 7cc53a9 to bf0db57
Ah, I misunderstood. How does this work with multiple dimensions? Can you have a 3D multicast tensor? Or are we requiring 1D tensors on each device, so the entire thing acts like a 2D tensor? Does the same thing happen in reverse on store?
It works with any number of dimensions for dev_ptrs & the example tensor.
Yep, it happens in reverse on store: for a 1D dev_ptrs & a 1D example tensor tile, the store value must be 2D.
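A store-direction sketch of that shape rule, reusing the hl.multicast_like API from the tests in this PR (the kernel itself is illustrative, not one of the PR's tests): with 1D dev_ptrs and a 1D tile, the stored value is 2D and is broadcast across devices.

```python
import torch

import helion
import helion.language as hl


@helion.kernel
def multicast_store_kernel(
    dev_ptrs: torch.Tensor,        # (M,) pointers, one per device
    example_tensor: torch.Tensor,  # (N,) tensor describing each device's copy
    x: torch.Tensor,               # (N,) values to replicate to every device
) -> None:
    N = example_tensor.size(0)
    for tile in hl.tile(N):
        ptr_tile = dev_ptrs[:]
        tensors = hl.multicast_like(example_tensor, ptr_tile)
        x_tile = x[tile]                 # (BLOCK_N,)
        tensors[tile] = x_tile[None, :]  # broadcast to (M, BLOCK_N) on store
```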
stack-info: PR: #346, branch: joydddd/stack/17
Force-pushed from bf0db57 to 8ee7615
def test_multicast_load_2d_tensors(self):
    @helion.kernel
    def multicast_load_kernel(
        dev_ptrs: torch.Tensor,
        example_tensor: torch.Tensor,
    ) -> torch.Tensor:
        M = dev_ptrs.size(0)
        N1, N2 = example_tensor.size()
        out = torch.empty(M, N1, N2, dtype=torch.bfloat16, device=dev_ptrs.device)

        for tile1, tile2 in hl.tile([N1, N2]):
            ptr_tile = dev_ptrs[:]
            tensors = hl.multicast_like(example_tensor, ptr_tile)
            out[:, tile1, tile2] = tensors[tile1, tile2]
        return out
e.g. tensors are 2D.
def test_multicast_load_2d_dev_ptrs(self):
    @helion.kernel
    def multicast_load_kernel_2d(
        dev_ptrs: torch.Tensor,
        example_tensor: torch.Tensor,
    ) -> torch.Tensor:
        M1, M2 = dev_ptrs.size()
        N = example_tensor.size(0)
        out = torch.empty(M1, M2, N, dtype=torch.bfloat16, device=dev_ptrs.device)

        for tile in hl.tile(N, block_size=4):
            ptr_tile = dev_ptrs[:, :]
            tensors = hl.multicast_like(example_tensor, ptr_tile)
            out[:, :, tile] = tensors[tile]
        return out
e.g. dev_ptrs are 2D
x_tile = x[tile]
tensors[tile] = x_tile[None, :]
Broadcast x_tile to all tensors at store.
x = hl.arange(M)
tensors[i] = x
Do the reverse of stack at store.
Stacked PRs:
Add multicast tensor