Add multicast tensor #346

Open · joydddd wants to merge 1 commit into main from joydddd/stack/17

Conversation

joydddd (Contributor) commented Jul 22, 2025

joydddd added a commit that referenced this pull request on Jul 22, 2025
stack-info: PR: #346, branch: joydddd/stack/17
joydddd force-pushed the joydddd/stack/17 branch from 1e986c5 to 0bcfcca on July 22, 2025 02:55
facebook-github-bot added the CLA Signed label (this label is managed by the Meta Open Source bot) on Jul 22, 2025
joydddd force-pushed the joydddd/stack/17 branch from 0bcfcca to 5609bbf on July 22, 2025 02:57
joydddd added a commit that referenced this pull request on Jul 22, 2025
stack-info: PR: #346, branch: joydddd/stack/17
joydddd marked this pull request as ready for review on July 22, 2025 04:03
joydddd requested review from jansel, yf225, drisspg and oulgen, then removed the request for jansel, yf225, drisspg and oulgen on July 22, 2025 04:04
joydddd marked this pull request as draft on July 22, 2025 04:07
joydddd added a commit that referenced this pull request on Jul 22, 2025
stack-info: PR: #346, branch: joydddd/stack/17
joydddd force-pushed the joydddd/stack/17 branch from 5609bbf to 1749db5 on July 22, 2025 19:11
joydddd requested review from jansel, oulgen, yf225 and drisspg, then removed the request for jansel on July 22, 2025 19:11
joydddd marked this pull request as ready for review on July 22, 2025 19:11
from .._compiler.variable_origin import Origin


class MulticastTensor(NamedTuple):
Contributor:

If we don't inherit from NamedTuple does this still work?

joydddd (Author):

Yep, removed the inheritance.

joydddd (Author):

Actually, no, we might want to keep the NamedTuple inheritance. fake_fn for hl ops is called during both type propagation and device_ir tracing. For type propagation, multicast tensors are passed in as the original MulticastTensor type, while in device_ir tracing we call prepare_args to unpack the MulticastTensor into a tuple before calling fake_fn. It is nicer to make MulticastTensor a NamedTuple so that in both cases it is a tuple.
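A minimal sketch of that point, assuming hypothetical field names (tensor_like, dev_ptrs, matching the snippet commented on below) and a toy fake_fn: because NamedTuple instances are tuples, the same positional unpacking works whether fake_fn receives the original MulticastTensor (type propagation) or the plain tuple produced by prepare_args (device_ir tracing).

from typing import NamedTuple

import torch


class MulticastTensor(NamedTuple):
    tensor_like: torch.Tensor  # example tensor supplying shape/stride/dtype
    dev_ptrs: torch.Tensor  # one raw device pointer per underlying buffer


def fake_fn(arg):
    # Positional unpacking works for both MulticastTensor and the unpacked tuple.
    tensor_like, dev_ptrs = arg
    return torch.empty(dev_ptrs.shape + tensor_like.shape, dtype=tensor_like.dtype)


mt = MulticastTensor(torch.empty(8, 8), torch.zeros(4, dtype=torch.int64))
assert fake_fn(mt).shape == fake_fn(tuple(mt)).shape == (4, 8, 8)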

joydddd (Author):

Unless we find a better way to deal with the MulticastTensor constructor in device_ir, in which case we wouldn't need to unpack it into a tuple.

@@ -289,6 +295,134 @@ def codegen_store(
)


class MulticastIndexingStrategy:
    @staticmethod
Contributor:

Can you add more detail on the semantics of this indexing strategy?

return state.device_function.indexing_strategy.codegen_store(
    state, tensor, [*subscript], value, extra_mask
)
if isinstance(tensor, tuple):
Contributor:

I still don't totally follow why we convert to a tuple instead of keeping the multicast tensor type.

joydddd (Author):

The MulticastTensor constructor is not a device function and should not show up in the device IR fx graph (we have no lowering path for it). Therefore, to keep the tracer from seeing it, we unpack it into a tuple before entering the tracer.

Any ideas on how to do this in a nicer way that avoids type-checking for a tuple? Maybe add a typename item to the tuple and check for that?
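For context, a hypothetical sketch of the unpacking step described above (prepare_args is named in the discussion; the body here is only an illustration, not the PR's actual implementation): flatten any MulticastTensor argument into a plain tuple right before tracing, so the tracer never sees the constructor.

def prepare_args(args: tuple) -> tuple:
    # Illustration only: MulticastTensor is a NamedTuple, so tuple(a) yields
    # (tensor_like, dev_ptrs) without invoking any device-side op.
    return tuple(tuple(a) if isinstance(a, MulticastTensor) else a for a in args)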

jansel (Contributor) left a comment:

Will multicast tensors work with arbitrary operations (for example inductor ones)?

Are the semantics always that we repeat an op once for every tensor?

I worry that needing to add an if multicast case to every op will add a lot of complexity. Could we implement this as an FX graph pass that duplicates every op once per each sub-tensor?

def has_multicast_tensor_with_rdim(self, graph: torch.fx.Graph) -> bool:
    """Check if a graph contains multicast tensors with rdim inputs."""

    def is_multicast_with_rdim(node: torch.fx.Node) -> bool:
Contributor:

Move to global scope?

joydddd (Author):

Not sure what you mean by moving to global scope? We check whether self.rdim of the roller matches any of the multicast dims. @yf225 can you take a look?

Comment on lines +37 to +40
tensor_like: torch.Tensor
dev_ptrs: torch.Tensor
Contributor:

Can you talk about the motivation for this representation?

joydddd (Author) commented Jul 23, 2025

Will multicast tensors work with arbitrary operations (for example inductor ones)?

Are the semantics always that we repeat an op once for every tensor?

I worry that needing to add an if multicast case to every op will add a lot of complexity. Could we implement this as an FX graph pass that duplicates every op once per each sub-tensor?

Multicast tensors are meant to be host tensors that can only be accessed by memory operations (namely hl.load, hl.store, hl.atomic_add, hl.signal & hl.wait). For these cases we handle indexing, masks, and offset calculation manually. Arbitrary operations (for example inductor ones) apply only to device tensors, which are returned by these memory operations.

Any pointers to where I should add a check to make sure multicast tensors are not used for non-memory ops?

Can you talk about the motivation for this representation?

tensor_like: torch.Tensor
dev_ptrs: torch.Tensor

The primary use case for this is a shared tensor on symmetric memory, i.e. each device holds its own copy of an a_shared tensor. MulticastTensor creates a virtual concatenation of these tensors without materializing a copy. "Symmetric" means each distinct tensor has the same shape, stride, and dtype, and therefore the same index calculation. tensor_like is the example tensor used for indexing; in the symmetric-memory case this is usually the local tensor.

I'm considering adding more APIs to create MulticastTensors, e.g. multicast(dev_ptr, shape, dtype, stride, ...), which would create an empty (fake) host tensor inside Helion as the indexing reference.

Future extension: a multimem pointer (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=multimem#data-movement-and-conversion-instructions-multimem) could serve as an alternative backend for MulticastTensor instead of a dev_ptrs tensor. However, Triton does not currently support it, and we don't know what the interface would look like.
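To make the representation concrete, a hedged sketch of how the pieces could be assembled (single-process stand-in buffers instead of real symmetric memory; the pointer dtype is an assumption, and only hl.multicast_like comes from this PR's tests):

import torch

world_size = 4
# One buffer per "device", all with identical shape/stride/dtype ("symmetric").
buffers = [torch.empty(1024, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)]

# dev_ptrs holds one raw device pointer per buffer; tensor_like is any one of the
# buffers and is only used for shape/stride/dtype and index calculation.
dev_ptrs = torch.tensor([b.data_ptr() for b in buffers], dtype=torch.uint64, device="cuda")  # dtype assumed
tensor_like = buffers[0]

# Inside a Helion kernel, hl.multicast_like(tensor_like, dev_ptrs[:]) then acts like a
# virtual (world_size, 1024) tensor over all buffers, with no copy materialized.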

jansel (Contributor) commented Jul 23, 2025

But if you load a multicast tensor, don't you get a multicast device tensor? (Which you can run ops on.)

joydddd (Author) commented Jul 23, 2025


No, hl.load(multicast, index...) returns a device tensor (or kernel tensor? something living in registers / smem) of type torch.Tensor. A slice is loaded from each buffer and the slices are stacked together to form the hl.load return value. In this case multicast_t[1] is a normal 1D tensor of shape (num_buffers,), and you can do all normal tensor ops on it.
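A plain-PyTorch sketch of those load semantics (reference behavior only, not Helion's implementation): indexing the multicast tensor loads the same slice from every buffer and stacks the results along a new leading dimension of size num_buffers.

import torch

# Stand-ins for the per-device buffers behind a multicast tensor.
buffers = [torch.arange(8) + 100 * i for i in range(4)]


def multicast_getitem(bufs: list[torch.Tensor], index) -> torch.Tensor:
    # Load the same slice from each buffer, then stack: (num_buffers, *slice_shape).
    return torch.stack([b[index] for b in bufs])


val = multicast_getitem(buffers, 1)  # analogous to multicast_t[1]
assert val.shape == (4,) and val.tolist() == [1, 101, 201, 301]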

jansel (Contributor) commented Jul 25, 2025

Ah I misunderstood.

How does this work with multiple dimensions? Can you have a 3D multicast tensor? Or are we requiring 1D tensors on each device, so the entire thing acts like a 2D tensor.

Does the same thing happen in reverse on store?

joydddd (Author) commented Jul 25, 2025

How does this work with multiple dimensions? Can you have a 3D multicast tensor? Or are we requiring 1D tensors on each device, so the entire thing acts like a 2D tensor.

It works with any number of dimensions for dev_ptrs & the example tensor.
Say the dev_ptrs are 2D and the example tensor is 3D. Then the entire thing acts like a 2D + 3D -> 5D tensor.

Does the same thing happen in reverse on store?

Yep, the same thing happens in reverse on store. For 1D dev_ptrs & a 1D example-tensor tile, the store value must be 2D.
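A short plain-PyTorch sketch of both rules (reference semantics only, not Helion's code): the dev_ptrs dims are prepended to the example-tensor dims, and a store writes one slice of the value into each buffer.

import torch

# 1D dev_ptrs of length M over 1D buffers of length N: behaves like an (M, N) tensor.
M, N, T = 4, 32, 8
buffers = [torch.zeros(N) for _ in range(M)]

tile = slice(0, T)
value = torch.randn(M, T)  # store value must be 2D: (num_buffers, tile)
for i, buf in enumerate(buffers):  # the reverse of the stacked load
    buf[tile] = value[i]

assert torch.equal(buffers[2][tile], value[2])
# With 2D dev_ptrs (M1, M2) and a 3D example tensor (N1, N2, N3), the virtual
# tensor would instead behave as 5D: (M1, M2, N1, N2, N3).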

stack-info: PR: #346, branch: joydddd/stack/17
joydddd force-pushed the joydddd/stack/17 branch from bf0db57 to 8ee7615 on July 25, 2025 17:49
Comment on lines +43 to +58
def test_multicast_load_2d_tensors(self):
    @helion.kernel
    def multicast_load_kernel(
        dev_ptrs: torch.Tensor,
        example_tensor: torch.Tensor,
    ) -> torch.Tensor:
        M = dev_ptrs.size(0)
        N1, N2 = example_tensor.size()
        out = torch.empty(M, N1, N2, dtype=torch.bfloat16, device=dev_ptrs.device)

        for tile1, tile2 in hl.tile([N1, N2]):
            ptr_tile = dev_ptrs[:]
            tensors = hl.multicast_like(example_tensor, ptr_tile)
            out[:, tile1, tile2] = tensors[tile1, tile2]
        return out

joydddd (Author):

e.g. the example tensors are 2D.

Comment on lines +73 to +88
def test_multicast_load_2d_dev_ptrs(self):
    @helion.kernel
    def multicast_load_kernel_2d(
        dev_ptrs: torch.Tensor,
        example_tensor: torch.Tensor,
    ) -> torch.Tensor:
        M1, M2 = dev_ptrs.size()
        N = example_tensor.size(0)
        out = torch.empty(M1, M2, N, dtype=torch.bfloat16, device=dev_ptrs.device)

        for tile in hl.tile(N, block_size=4):
            ptr_tile = dev_ptrs[:, :]
            tensors = hl.multicast_like(example_tensor, ptr_tile)
            out[:, :, tile] = tensors[tile]
        return out

joydddd (Author):

e.g. the dev_ptrs are 2D.

Comment on lines +198 to +199
x_tile = x[tile]
tensors[tile] = x_tile[None, :]
joydddd (Author):

Broadcast x_tile to all tensors at store.

Comment on lines +230 to +231
x = hl.arange(M)
tensors[i] = x
joydddd (Author):

Does the reverse of the stack at store.
