Skip to content

Commit

Permalink
inference part
Browse files Browse the repository at this point in the history
  • Loading branch information
zzhhjjj committed Sep 5, 2024
1 parent b15fbc7 commit 5ded28c
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions src/nanotron/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,14 +394,24 @@ def __init__(
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
# TODO(kunhao): We want to have only one version per device and not one version per layer.
self.rotary_embedding = RotaryEmbedding(
dim=self.d_qk,
end=config.max_position_embeddings,
theta=config.rope_theta,
)
if config.rope_interleaved:
self.rotary_embedding = RotaryEmbedding(
dim=self.d_qk,
end=config.max_position_embeddings,
theta=config.rope_theta,
)
else:
self.rotary_embedding = LlamaRotaryEmbedding(
dim=self.d_qk,
end=config.max_position_embeddings,
theta=config.rope_theta,
)
self.rope_interleaved = config.rope_interleaved

# NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, base=config.rope_theta, interleaved=True)
self.flash_rotary_embedding = FlashRotaryEmbedding(
dim=self.d_qk, base=config.rope_theta, interleaved=config.rope_interleaved
)

self.o_proj = TensorParallelRowLinear(
config.num_attention_heads * self.d_qk,
Expand Down Expand Up @@ -480,8 +490,14 @@ def forward(
# Compute rotary embeddings
# Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache
old_rotary_embed_end = self.rotary_embedding.end
query_states = self.rotary_embedding(query_states, position_ids=position_ids)
key_states = self.rotary_embedding(key_states, position_ids=position_ids)
# interleaved version.
if self.rope_interleaved:
query_states = self.rotary_embedding(query_states, position_ids=position_ids)
key_states = self.rotary_embedding(key_states, position_ids=position_ids)
# non interleaved version.
else:
cos, sin = self.rotary_embedding(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if "key" not in store:
# First inference iteration (Prefill)
Expand Down Expand Up @@ -620,7 +636,7 @@ def forward(
cache_seqlens=position_offsets.contiguous(),
softmax_scale=softmax_scale,
causal=True,
rotary_interleaved=False, # GPT-NeoX style
rotary_interleaved=False, # the value is not used unless rotary_cos/sin is provided. https://github.com/Dao-AILab/flash-attention
)

store.update(
Expand Down

0 comments on commit 5ded28c

Please sign in to comment.