
Question about Pipeline parallelism #431

Open
vermouth1992 opened this issue Jun 27, 2024 · 5 comments

Labels: question (Further information is requested)
@vermouth1992

Just wondering: does the current PipelineStage API support variable-length input shapes like Megatron does? https://github.com/NVIDIA/Megatron-LM/blob/e33c8f78a35765d5aa37475a144da60e8a2349d1/megatron/core/model_parallel_config.py#L212 This is particularly useful for packed inputs where all the padding has been removed.

@tianyu-l tianyu-l added the question Further information is requested label Jun 27, 2024
@H-Huang
Member

H-Huang commented Jun 27, 2024

The PipelineStage API currently does not support variable-length input shapes. The shape is static and based on the input_args argument. We have an open issue tracking a similar feature (lazy shape inference) which could potentially be used to support variable-length inputs (pytorch/pytorch#126216).

I wanted to understand more about the use case though. Can you elaborate on "This is particularly useful for packed inputs where all the padding has been removed"? Is there an example, and how commonly is this used?
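
For context, here is a minimal sketch of how the shapes get fixed today, assuming the manual PipelineStage construction from torch.distributed.pipelining (the model chunk, sizes, and process-group setup below are illustrative placeholders, and the exact argument list may differ across versions):

```python
import torch
import torch.nn as nn
from torch.distributed.pipelining import PipelineStage

# Illustrative placeholders; a real setup would use this rank's model chunk
# and an already-initialized process group.
device = torch.device("cuda")
stage_module = nn.Sequential(nn.Embedding(32000, 512), nn.Linear(512, 512)).to(device)
stage_idx, num_stages = 0, 4
microbatch_size, seq_len = 2, 1024

# The example microbatch passed via input_args fixes the send/recv buffer shapes
# for this stage; every subsequent microbatch must then have exactly this shape.
example_input = torch.zeros(microbatch_size, seq_len, dtype=torch.long, device=device)

stage = PipelineStage(
    stage_module,
    stage_idx,
    num_stages,
    device,
    input_args=(example_input,),
)
```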

@vermouth1992
Author

vermouth1992 commented Jun 28, 2024

The current models implemented in the huggingface transformers library use input_ids and attention_mask as inputs. For SFT and RLHF, a large portion of the input will be padding. Computing those padding tokens through the MLP and attention is a huge waste. Instead, we can unpad the data (e.g. https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py#L98) so that all the sequences inside a batch are packed and can be separated via cu_seqlens.
For the MLP and layernorm, we don't need to change the model since there is no inter-token computation. For attention, flash-attention supports passing packed inputs: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1005
In this case, every micro-batch has a different input shape during each forward. Currently, this can be worked around by passing padded input_ids and attention_mask and calling unpad inside the model of each pipeline stage, at some cost in performance.
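
For illustration, a pure-PyTorch sketch of what that unpad step produces (this mirrors the idea of flash_attn's bert_padding.unpad_input, whose exact return signature varies across flash-attention versions):

```python
import torch

def unpad(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    """Pack a padded batch [B, S, H] into [total_nnz, H] plus cu_seqlens.

    attention_mask is [B, S] with 1 for real tokens and 0 for padding.
    """
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                # [B]
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )                                                                      # [B + 1]
    packed = hidden_states.reshape(-1, hidden_states.shape[-1])[indices]   # [total_nnz, H]
    return packed, indices, cu_seqlens

# total_nnz (and hence the packed shape) varies from microbatch to microbatch.
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
x = torch.randn(2, 4, 8)
packed, indices, cu_seqlens = unpad(x, mask)
print(packed.shape, cu_seqlens)  # torch.Size([5, 8]) tensor([0, 3, 5], dtype=torch.int32)
```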

@wconstab
Contributor

The root issue here is that a recv operation needs to know its size before starting. So every stage besides 0 needs to know the size of the new microbatch before passing the data down the pipeline.

A couple of ways this could be done:

  1. We require every pipeline stage to receive data from the data loader locally. This puts extra load on the data loader but may be OK. Then we need some way to infer the input shape of the current stage as a function of the microbatch shape. Then we can directly allocate recv buffers for each batch.
  2. We add an extra communication step to our schedules that passes the output shape from our forward to the next stage, which then allocates a buffer and performs the recv. This extra communication will add overhead, not only due to nccl time, but possibly also additional bubble time if it prevents us from starting the recv operation as early.

In the current situation, you pay extra communication cost for the padding, but then you unpad the data before using it. Is that correct? We'd have to understand how bad this overhead is compared to the above proposals and their overhead and complexity. Do you have some data?
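
To make option 2 concrete, here is a rough sketch of the extra shape exchange (hypothetical helper names, assuming activations are packed as [total_nnz, hidden_size]; this is not the actual PipelineStage implementation):

```python
import torch
import torch.distributed as dist

# Assumption: only the first activation dim (total_nnz) is dynamic;
# hidden_size is known from the model config.

def send_packed(act: torch.Tensor, dst: int) -> None:
    # Communicate the dynamic size first, then the payload.
    nnz = torch.tensor([act.shape[0]], dtype=torch.int64, device=act.device)
    dist.send(nnz, dst)
    dist.send(act, dst)

def recv_packed(src: int, hidden_size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    nnz = torch.empty(1, dtype=torch.int64, device=device)
    dist.recv(nnz, src)
    # The buffer can only be allocated once the size arrives, which is where the
    # extra nccl round trip (and possibly bubble time) comes from: the recv of
    # the activation itself cannot be pre-posted.
    buf = torch.empty(nnz.item(), hidden_size, dtype=dtype, device=device)
    dist.recv(buf, src)
    return buf
```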

@kwen2501
Contributor

kwen2501 commented Jul 2, 2024

The Megatron code has the following comment:

> Setting this communicates the size of tensors during pipeline parallelism communication, because of this extra overhead it should only be set if the sequence length varies by microbatch within a global batch.

My understanding of the above comment is that the microbatch size is still constant, but the number of microbatches could change.

Does that apply to your case?

If it applies, it should be doable by reconfiguring the num_microbatches of PipelineSchedule dynamically. (We don't provide such an API now but it should be doable.)

The other challenge is making such a reconfiguration on all pipeline ranks, because they need to agree on it -- do you think you could make such an API call on all ranks? Otherwise, as the Megatron comment suggests, there will be overhead for rank 0 to inform all other ranks.

Cc @wconstab @H-Huang

@vermouth1992
Author

Actually, the microbatch size varies across microbatches within a batch. In Megatron, the output of each TransformerLayer has shape [total_nnz, hidden_size], where total_nnz is the total number of tokens inside the microbatch after removing all the padding.

Assume we have a batch of shape [batch_size, seqlen, hidden_size]. We split it along the batch dimension into a list of microbatches of shape [batch_size // n_microbatches, seqlen, hidden_size]. Then we perform the unpad operation, and each microbatch becomes [total_nnz, hidden_size], where total_nnz is the total number of tokens inside that microbatch after removing all the padding.

> Support for variable sequence lengths across microbatches. Setting this communicates the size of tensors during pipeline parallelism communication, because of this extra overhead it should only be set if the sequence length varies by microbatch within a global batch.
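
Spelled out as a shape walkthrough (a hedged sketch with made-up sizes, not Megatron's actual code):

```python
import torch

batch_size, seqlen, hidden_size, n_microbatches = 8, 1024, 4096, 4

input_ids = torch.randint(0, 32000, (batch_size, seqlen))
attention_mask = torch.randint(0, 2, (batch_size, seqlen))  # 1 = real token, 0 = padding

# Split along the batch dimension: each microbatch is [batch_size // n_microbatches, seqlen].
for mb_ids, mb_mask in zip(input_ids.chunk(n_microbatches),
                           attention_mask.chunk(n_microbatches)):
    total_nnz = int(mb_mask.sum())
    # After embedding and unpadding, each TransformerLayer activation inside this
    # microbatch is [total_nnz, hidden_size]; total_nnz differs per microbatch.
    print(tuple(mb_ids.shape), "->", (total_nnz, hidden_size))
```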
