Question about Pipeline parallelism #431
> I wanted to understand more about the use case though. Can you elaborate on it?
The current models implemented in the Hugging Face transformers library use input_ids and attention_mask as inputs. For SFT and RLHF, a large portion of the inputs will be padding. Pushing those padding tokens through the MLP and attention layers is a huge waste. Instead, we can unpad the data (e.g. https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py#L98) so that all the sequences inside a batch are packed and can be separated via cu_seqlens.
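For reference, the packing idea can be sketched in plain PyTorch. This is a minimal stand-in for what the linked `bert_padding.py` does (the real library also returns the indices needed to re-pad later), not its actual implementation:

```python
import torch

def unpad(hidden_states, attention_mask):
    """Pack a padded batch into [total_nnz, hidden_size] plus cu_seqlens."""
    # attention_mask: [batch_size, seqlen], 1 for real tokens, 0 for padding
    seqlens = attention_mask.sum(dim=1, dtype=torch.int32)
    # cu_seqlens: cumulative sequence lengths, prefixed with 0, so that
    # sequence i occupies packed[cu_seqlens[i]:cu_seqlens[i + 1]]
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
    packed = hidden_states[attention_mask.bool()]  # [total_nnz, hidden_size]
    return packed, cu_seqlens

x = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 0]])
packed, cu_seqlens = unpad(x, mask)
print(packed.shape)  # torch.Size([7, 8]): total_nnz = 3 + 4
print(cu_seqlens)    # tensor([0, 3, 7], dtype=torch.int32)
```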
The root issue here is that a recv operation needs to know its size before starting, so every stage besides stage 0 needs to learn the size of each new microbatch before the data can move down the pipeline. A couple of ways this could be done (one is sketched below):
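For illustration only, here is a minimal sketch of one such option: a two-phase exchange in which the shape travels ahead of the payload. The function names and the fixed `NDIM` are assumptions for the example, not an existing API:

```python
import torch
import torch.distributed as dist

NDIM = 2  # assumed: both sides agree on the payload rank, e.g. [total_nnz, hidden_size]

def send_variable(tensor, dst):
    # Phase 1: a tiny message carrying the shape.
    dist.send(torch.tensor(tensor.shape, dtype=torch.int64), dst=dst)
    # Phase 2: the activation itself.
    dist.send(tensor.contiguous(), dst=dst)

def recv_variable(src, dtype=torch.float32):
    # Phase 1: learn the size before allocating the receive buffer.
    shape = torch.empty(NDIM, dtype=torch.int64)
    dist.recv(shape, src=src)
    # Phase 2: receive into a correctly sized buffer.
    buf = torch.empty(*shape.tolist(), dtype=dtype)
    dist.recv(buf, src=src)
    return buf
```

The cost is an extra small message (and a synchronization point) per microbatch, which is exactly the overhead being weighed against just paying for the padding.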
In the current situation, you pay extra communication cost for the padding, but then you unpad the data before using it. Is that correct? We'd have to understand how bad this overhead is compared to the above proposals and their overhead and complexity. Do you have some data?
The Megatron code has the following comment:
My understanding of the above comment is that the microbatch size is still constant, but the number of microbatches could change. Does that apply to your case? If it does, it should be doable by reconfiguring the num_microbatches of the PipelineSchedule dynamically. (We don't provide such an API now, but it should be doable.) The other challenge is making such a reconfiguration on all pipeline ranks, because they need to agree on it; do you think you could make such an API call on all ranks? Otherwise, as the Megatron comment suggests, there will be an overhead for rank 0 to inform all other ranks.
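A minimal sketch of the rank-0-informs-everyone variant, assuming such a reconfiguration hook existed (the `set_num_microbatches` call below is hypothetical; as noted above, no such API exists today):

```python
import torch
import torch.distributed as dist

def agree_on_num_microbatches(n_from_rank0: int) -> int:
    # Rank 0's value wins; every pipeline rank ends up with the same number,
    # at the cost of one extra broadcast per reconfiguration.
    n = torch.tensor([n_from_rank0], dtype=torch.int64)
    dist.broadcast(n, src=0)
    return int(n.item())

# n_mb = agree_on_num_microbatches(compute_num_microbatches(batch))
# schedule.set_num_microbatches(n_mb)  # hypothetical API, for illustration only
```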
Actually, the microbatch size varies across microbatches within a batch. In Megatron, the output of each TransformerLayer has shape [total_nnz, hidden_size], where total_nnz is the total number of tokens inside the microbatch after removing all the padding. Assume we have a batch of shape [batch_size, seqlen, hidden_size]. We split it along the batch dimension into a list of microbatches of shape [batch_size // n_microbatches, seqlen, hidden_size]. Then we perform the unpad operation, and each microbatch becomes [total_nnz, hidden_size].
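As a concrete example with made-up numbers: with batch_size = 8, seqlen = 512, and n_microbatches = 4, each microbatch starts as [2, 512, hidden_size]; if one microbatch contains 300 + 200 real tokens and the next 512 + 100, their unpadded shapes are [500, hidden_size] and [612, hidden_size]. Since total_nnz differs per microbatch, the recv size on each downstream stage is not constant.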
Just wondering: does the current PipelineStage API support variable-length input shapes like Megatron does? https://github.com/NVIDIA/Megatron-LM/blob/e33c8f78a35765d5aa37475a144da60e8a2349d1/megatron/core/model_parallel_config.py#L212 This is particularly useful for packed inputs where all the padding is removed.