diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 80522d62..e7510e58 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -919,7 +919,7 @@ def get_block_compute_costs(self):
             # CausalSelfAttention (qkv proj + attn out) + MLP
             LlamaDecoderLayer: 2 * model_config.num_attention_heads * d_qkv * model_config.hidden_size  # Q output projection
             + 2 * model_config.num_key_value_heads * d_qkv * model_config.hidden_size  # KV
-            + 3 * d_ff * model_config.hidden_size  # for the MLP (3 because of the gated mechanism),
+            + 3 * d_ff * model_config.hidden_size,  # for the MLP (3 because of the gated mechanism)
             # This is the last lm_head
             TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
         }
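
Note on the fix: in the removed line, the comma that should terminate the LlamaDecoderLayer entry sat inside the trailing comment, so the dict literal in get_block_compute_costs had no separator before the TensorParallelColumnLinear entry (Python requires a comma between dict entries, so the module would not parse). The patched line moves the comma onto the expression; the cost formula itself is unchanged. Below is a minimal standalone sketch of the corrected formula, using hypothetical Llama-2-7B-style hyperparameters (hidden_size=4096, 32 attention heads, 32 KV heads, intermediate_size=11008, vocab_size=32000) in place of Nanotron's model config object:

```python
# Minimal sketch of the corrected per-block cost formula, not the Nanotron method itself.
# Hyperparameters are hypothetical Llama-2-7B-style values chosen for illustration;
# the real method reads them from the Nanotron model config.
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 32  # equal to num_attention_heads here, i.e. no grouped-query attention
intermediate_size = 11008
vocab_size = 32000

d_qkv = hidden_size // num_attention_heads  # per-head dimension
d_ff = intermediate_size

block_compute_costs = {
    # CausalSelfAttention (qkv proj + attn out) + MLP
    "LlamaDecoderLayer": 2 * num_attention_heads * d_qkv * hidden_size  # Q + attention output projections
    + 2 * num_key_value_heads * d_qkv * hidden_size  # K and V projections
    + 3 * d_ff * hidden_size,  # MLP (3 matrices because of the gated mechanism)
    # This is the last lm_head
    "TensorParallelColumnLinear": vocab_size * hidden_size,
}

print(block_compute_costs)
# {'LlamaDecoderLayer': 202375168, 'TensorParallelColumnLinear': 131072000}
```

Only the comma placement changes in the patch; the relative per-block weights produced by the expression are identical before and after, once the file parses.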