Question about custom CUDA operators for tensor parallelism #434

Open
vermouth1992 opened this issue Jun 28, 2024 · 5 comments
Labels
question Further information is requested

Comments

@vermouth1992

We are currently trying to apply torchtitan to MoE models. MoE models require grouped_gemm (https://github.com/fanshiqing/grouped_gemm). The GroupedGemm ops basically follow the same sharding rules as ColumnLinear and RowLinear. Is there any way to make custom ops DTensor-compatible? Thanks a lot for the help!

@yifuwang
Contributor

I think there are two high-level approaches:

1. Make the custom ops DTensor-compatible, i.e. register sharding propagation rules for them so they can take and return DTensors directly.
2. Keep the custom ops working on plain torch.Tensor and convert DTensors to local tensors at the op boundary (wrapping the outputs back into DTensors afterward).

I think both approaches are equally viable if the permute input and output have the same sharding. If you are performing expert parallelism across devices, the second approach is probably easier at the moment.

@fegin
Contributor

fegin commented Jun 28, 2024

For the second method mentioned by @yifuwang, both ColwiseParallel and RowwiseParallel have options to convert their outputs to local tensors. Both also convert the input to a DTensor if it comes in as a plain torch.Tensor. So you will be able to combine non-DTensor-compatible ops with ColwiseParallel and RowwiseParallel.
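
A minimal sketch of this pattern (not from the thread; names like `ExpertMLP` and `my_custom_op` are hypothetical placeholders): with the default `use_local_output=True`, `ColwiseParallel` hands its output onward as a plain `torch.Tensor`, and `RowwiseParallel` wraps a plain-tensor input back into a DTensor with its expected `input_layouts`, so a DTensor-unaware op can sit in between.

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


def my_custom_op(h: torch.Tensor) -> torch.Tensor:
    # Stand-in for a custom CUDA op that only understands plain torch.Tensor
    # (e.g. a grouped GEMM / permute kernel).
    return torch.nn.functional.silu(h)


class ExpertMLP(nn.Module):
    def __init__(self, hidden: int, ffn: int):
        super().__init__()
        self.w1 = nn.Linear(hidden, ffn, bias=False)  # column-parallel half
        self.w2 = nn.Linear(ffn, hidden, bias=False)  # row-parallel half

    def forward(self, x):
        h = self.w1(x)        # local shard, since use_local_output defaults to True
        h = my_custom_op(h)   # sees a plain torch.Tensor, never a DTensor
        return self.w2(h)     # RowwiseParallel re-wraps the input as a DTensor


tp_mesh = init_device_mesh("cuda", (8,))  # assuming 8-way tensor parallelism
mlp = parallelize_module(
    ExpertMLP(hidden=4096, ffn=14336),
    tp_mesh,
    {
        "w1": ColwiseParallel(),   # output left as a local tensor by default
        "w2": RowwiseParallel(),   # expects Shard(-1) input, all-reduces output
    },
)
```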

@awgu added the question (Further information is requested) label on Jun 29, 2024
@kwen2501
Contributor

Curious, regarding this figure in your repo, which value(s) do you want to DTensor'ize? The input? Each of the experts? Or the whole group of experts represented as one big DTensor (with each expert being a subset of its rows or columns)? Would appreciate your comment.

@vermouth1992
Author

vermouth1992 commented Jun 30, 2024

We would like both the input and the expert weights to be DTensors, just like the FFN layer in dense models. Specifically, the gate can be parallelized via SequenceParallel, and each expert can use the standard Column + Row Parallel. However, that is less efficient in computation because we have to loop over all the experts. By using GroupedGemm, the expert weights can be concatenated together and we only need to perform a single big GroupedGemm.

The weights of all the experts are sharded for TP as follows:

  • w1/w3: [num_experts, hidden_size, ffn_size] -> [num_experts, hidden_size, ffn_size // tp]
  • w2: [num_experts, ffn_size, hidden_size] -> [num_experts, ffn_size // tp, hidden_size]

We can use a technique similar to ColumnParallel for w1/w3 and RowParallel for w2 (essentially a batched version). But the GroupedGemm operators need to know the TP sharding propagation rule, just as torch.matmul does.
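
As a rough sketch (not part of the original comment; the sizes are placeholders, and the import path assumes a recent PyTorch where `torch.distributed.tensor` is public), this "batched" column/row sharding of the grouped weights could be expressed with `distribute_tensor`:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

tp_mesh = init_device_mesh("cuda", (8,))  # assuming 8-way tensor parallelism

num_experts, hidden, ffn = 8, 4096, 14336
w1 = torch.empty(num_experts, hidden, ffn)
w2 = torch.empty(num_experts, ffn, hidden)

# Batched column parallel: shard the ffn dimension of w1/w3.
w1_dt = distribute_tensor(w1, tp_mesh, [Shard(2)])  # local: [E, hidden, ffn // tp]
# Batched row parallel: shard the ffn dimension of w2.
w2_dt = distribute_tensor(w2, tp_mesh, [Shard(1)])  # local: [E, ffn // tp, hidden]
```

For the GroupedGemm op to consume these DTensors directly, DTensor would still need a sharding propagation rule for that op, which is the missing piece mentioned above.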

Note that this is still tensor parallelism. We can further perform expert parallelism by distributing the experts across expert-parallel (EP) groups, which further shards the weights to:

  • w1/w3: [num_experts, hidden_size, ffn_size] -> [num_experts // ep, hidden_size, ffn_size // tp]
  • w2: [num_experts, ffn_size, hidden_size] -> [num_experts // ep, ffn_size // tp, hidden_size]

In expert parallelism, we simply dispatch the tokens routed to a particular expert to the EP group that owns it and discard the rest. The inputs can be hard to express as a DTensor because the sharding is uneven across the EP group. But inside each EP group, we can still use DTensor to represent the inputs.

The implementation in Megatron-LM, which uses raw torch.Tensor and torch.distributed, is a good reference: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/experts.py#L25.
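
For illustration only (hypothetical mesh and weight sizes, not from the comment), the combined EP × TP sharding described above could be written against a 2-D DeviceMesh:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Assuming 32 GPUs arranged as 4-way expert parallel x 8-way tensor parallel.
mesh_2d = init_device_mesh("cuda", (4, 8), mesh_dim_names=("ep", "tp"))

num_experts, hidden, ffn = 64, 4096, 14336
w1 = torch.empty(num_experts, hidden, ffn)
w2 = torch.empty(num_experts, ffn, hidden)

# Shard experts over the "ep" mesh dim and the ffn dimension over the "tp" dim.
w1_dt = distribute_tensor(w1, mesh_2d, [Shard(0), Shard(2)])  # [E//ep, hidden, ffn//tp]
w2_dt = distribute_tensor(w2, mesh_2d, [Shard(0), Shard(1)])  # [E//ep, ffn//tp, hidden]
```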

@XilunWu
Contributor

XilunWu commented Jul 1, 2024

You can also try DTensor local_map, which is how we enabled FusedRMSNorm in torchtitan (#404). This is the second approach in @yifuwang's comment.
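
A hedged sketch of what that could look like for a grouped GEMM (the `grouped_gemm_local` body and the placements are hypothetical; `local_map` lives under `torch.distributed.tensor.experimental` in recent PyTorch and its exact signature may differ across versions):

```python
import torch
from torch.distributed.tensor import Partial, Replicate, Shard
from torch.distributed.tensor.experimental import local_map


def grouped_gemm_local(x, w1, w2, tokens_per_expert):
    # Runs on plain local torch.Tensor shards; the custom GroupedGemm kernels
    # (e.g. the ops from the grouped_gemm package) would be called here.
    ...


# Wrap the local function so it accepts/returns DTensors. in_placements describes
# how each DTensor argument is sharded (None = not a DTensor); out_placements says
# how to re-wrap the local output -- Partial() because the row-parallel half
# produces a partial sum that still needs a reduction.
grouped_gemm_tp = local_map(
    grouped_gemm_local,
    out_placements=[Partial()],
    in_placements=([Replicate()], [Shard(2)], [Shard(1)], None),
)
```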
