
Commit

Merge branch 'main' into mem_fix_async
AleHD committed Jul 23, 2024
2 parents 6d85d03 + 5f82f7a commit 81e7a54
Showing 3 changed files with 41 additions and 5 deletions.
12 changes: 12 additions & 0 deletions examples/mamba/README.md
@@ -18,6 +18,18 @@ pip install -r requirements.txt

> https://wandb.ai/bouteille/test/reports/Mamba-loss--Vmlldzo2OTgwNDM5
## Bug related to nanotron
Encountered the following issue when running `train_mamba.sh`:
```
causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv
```
Solved this by reinstalling the packages:
```
pip uninstall mamba-ssm
pip install causal_conv1d==1.1.1
pip install mamba-ssm --no-cache-dir
```
See https://github.com/state-spaces/mamba/issues/169 for more context.


## Credits
Credits to the following repositories from which the code was adapted:
- https://github.com/state-spaces/mamba
2 changes: 2 additions & 0 deletions src/nanotron/config/parallelism_config.py
@@ -23,6 +23,7 @@ class ParallelismArgs:
pp_engine: Pipeline engine to use between "1f1b" and "afab"
tp_mode: TP mode to use between "all_reduce" and "reduce_scatter": all_reduce is normal, reduce_scatter activates sequence parallelism
tp_linear_async_communication: Whether to use async communication in TP linear layers
recompute_layer: Whether to recompute each Transformer layer to save memory.
"""

dp: int
@@ -31,6 +32,7 @@ class ParallelismArgs:
pp_engine: Optional[PipelineEngine] = None
tp_mode: Optional[TensorParallelLinearMode] = None
tp_linear_async_communication: Optional[bool] = None
recompute_layer: bool = False

tp_recompute_allgather: bool = True

32 changes: 27 additions & 5 deletions src/nanotron/models/llama.py
@@ -18,6 +18,7 @@

import torch
from torch import nn
from torch.utils.checkpoint import CheckpointFunction

from nanotron import distributed as dist
from nanotron import logging
@@ -619,12 +620,14 @@ def __init__(

self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)

def forward(

self.recompute_layer = parallel_config.recompute_layer

def _core_forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
) -> List[Union[torch.Tensor, TensorPointer]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

@@ -637,12 +640,31 @@ def forward(
hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
hidden_states = hidden_states + residual

return hidden_states, output["sequence_mask"]

def _checkpointed_forward(
self,
hidden_states: torch.Tensor,
sequence_mask: torch.Tensor,
) -> List[torch.Tensor]:
return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask)

def forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:

if self.recompute_layer and not isinstance(hidden_states, TensorPointer):
hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask)
else:
hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask)

return {
"hidden_states": hidden_states,
"sequence_mask": output["sequence_mask"],
"sequence_mask": sequence_mask,
}


class Embedding(nn.Module, AttachableStore):
def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]):
super().__init__()
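
The pattern introduced above — a private `_core_forward` holding the layer math, optionally wrapped in an activation-checkpointed call — can be illustrated outside nanotron with the public `torch.utils.checkpoint.checkpoint` API, which relies on the same `CheckpointFunction` machinery when `use_reentrant=True`. The toy module and sizes below are invented for illustration; this is a sketch of the recompute-the-layer idea, not the actual `LlamaDecoderLayer`.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Stand-in for a decoder layer; not nanotron code."""

    def __init__(self, hidden_size: int, recompute_layer: bool = False):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.recompute_layer = recompute_layer

    def _core_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The actual layer math; these intermediate activations are what
        # recomputation avoids keeping alive until the backward pass.
        return hidden_states + self.mlp(self.norm(hidden_states))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.recompute_layer and hidden_states.requires_grad:
            # Discard intermediate activations now; rerun _core_forward in backward.
            return checkpoint(self._core_forward, hidden_states, use_reentrant=True)
        return self._core_forward(hidden_states)


# Usage: activations inside _core_forward are recomputed during backward().
x = torch.randn(2, 16, 64, requires_grad=True)
block = ToyBlock(hidden_size=64, recompute_layer=True)
block(x).sum().backward()
```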
