
Commit

clean code
zzhhjjj committed Sep 2, 2024
1 parent f40fa05 commit b15fbc7
Showing 6 changed files with 14 additions and 1,258 deletions.
3 changes: 3 additions & 0 deletions src/nanotron/config/models_config.py
@@ -48,6 +48,9 @@ class LlamaConfig:
rms_norm_eps: float = 1e-6
rope_scaling: Optional[dict] = None
rope_theta: float = 10000.0
rope_interleaved: bool = (
False # The default used to be True, but you have to set it to False to load Llama3 checkpoints
)
tie_word_embeddings: bool = False
use_cache: bool = True
vocab_size: int = 32000
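For reference, a minimal sketch of how the new flag could be applied when loading a Llama3 checkpoint (the module path is inferred from the file location above, the helper function is hypothetical, and the remaining LlamaConfig fields are assumed to be filled in elsewhere, e.g. from a YAML config):

```python
from nanotron.config.models_config import LlamaConfig  # path assumed from src/nanotron/config/models_config.py

def patch_for_llama3(config: LlamaConfig) -> LlamaConfig:
    # Hypothetical helper: Llama3 checkpoints converted from Transformers use the
    # non-interleaved rotary layout, so the flag must be off before loading weights.
    config.rope_interleaved = False
    return config
```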
13 changes: 8 additions & 5 deletions src/nanotron/models/llama.py
@@ -139,7 +139,6 @@ def init_rotary_embeddings(self):
inv_freq = 1.0 / (
self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu") / self.dim)
) # important to compute on CPU
# inv_freq = apply_scaling(inv_freq) # if LLaMA 3.1
self.register_buffer(
"inv_freq", torch.empty(self.dim // 2, dtype=torch.float, device="cuda"), persistent=False
)
@@ -148,9 +147,6 @@ def init_rotary_embeddings(self):
) # make it float32 before copy to avoid precision loss during copy_
self.inv_freq.copy_(inv_freq)

saved_inv_freq = torch.load("/fsx/haojun/LLaMA/.cache/activation_values/inv_freq.pt")
assert torch.equal(self.inv_freq.cpu(), saved_inv_freq), "inv_freq mismatch."

@torch.no_grad()
def forward(
self,
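As a standalone illustration of the inv_freq computation kept in this hunk (theta and dim are placeholder values; nanotron's buffer registration and copy_ are omitted):

```python
import torch

theta, dim = 10000.0, 128  # placeholder values; nanotron reads them from LlamaConfig
# Computed on CPU in float32 first, as in the code above, then copied into the
# CUDA buffer, to avoid precision loss in the exponentiation.
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float, device="cpu") / dim))
print(inv_freq.shape)  # torch.Size([64]), i.e. dim // 2 frequencies
```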
@@ -807,7 +803,14 @@ def __init__(
module_input_keys={"input_ids", "input_mask"},
module_output_keys={"input_embeds"},
)

log_rank(f"Initialize RoPE Theta = {config.rope_theta}", logger=logger, level=logging.INFO, rank=0)
if config.rope_interleaved:
log_rank(
"The RoPE interleaved version differs from the Transformers implementation. It's better to set rope_interleaved=False if you need to convert the weights to Transformers",
logger=logger,
level=logging.INFO,
rank=0,
)
self.decoder = nn.ModuleList(
[
PipelineBlock(
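To make the warning above concrete, here is a standalone sketch (not nanotron code) of how the two RoPE conventions pair hidden dimensions differently, which is why checkpoints destined for Transformers should use rope_interleaved=False:

```python
import torch

x = torch.arange(8)  # one attention head of size 8

# Interleaved convention: rotate adjacent pairs (x0, x1), (x2, x3), ...
pairs_interleaved = x.view(-1, 2)

# Transformers (non-interleaved) convention: rotate (x0, x4), (x1, x5), ...
half = x.shape[-1] // 2
pairs_split = torch.stack((x[:half], x[half:]), dim=-1)

print(pairs_interleaved.tolist())  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(pairs_split.tolist())        # [[0, 4], [1, 5], [2, 6], [3, 7]]
```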
