diff --git a/run_generate.py b/run_generate.py index f389770d..ad116d73 100644 --- a/run_generate.py +++ b/run_generate.py @@ -7,9 +7,7 @@ torchrun --nproc_per_node=4 run_generate.py ---ckpt-path checkpoints/test/4 ``` """ - import argparse -import os from pathlib import Path import torch @@ -21,12 +19,14 @@ ParallelismArgs, get_config_from_file, ) +from nanotron.distributed import get_global_rank from nanotron.generation.decode import ( - GenerationInput, - TokenizerConfig, - decode_text, - decode_tokenized, + GenerationInputs, + GenerationStates, + run_one_inference_step, ) +from nanotron.generation.generate_store import Store +from nanotron.generation.sampler import BasicSampler, GreedySampler, SamplerType, TopKSampler, TopPSampler from nanotron.logging import log_rank, set_ranks_logging_level from nanotron.models import build_model from nanotron.parallel import ParallelContext @@ -34,7 +34,6 @@ from nanotron.parallel.pipeline_parallel.engine import ( OneForwardOneBackwardPipelineEngine, ) -from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.random import ( RandomStates, @@ -50,6 +49,7 @@ except ImportError: AutoTokenizer = None + logger = logging.get_logger(__name__) @@ -57,9 +57,10 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--ckpt-path", type=Path, required=True, help="Checkpoint path") parser.add_argument("--dp", type=int, default=1) - parser.add_argument("--pp", type=int, default=0) - parser.add_argument("--tp", type=int, default=0) + parser.add_argument("--pp", type=int, default=1) + parser.add_argument("--tp", type=int, default=1) parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate") + parser.add_argument("--use-cache", action="store_true", help="Use cache for generation") return parser.parse_args() @@ -73,9 +74,9 @@ def main(): tokenizer_path = config.tokenizer.tokenizer_name_or_path parallel_config = ParallelismArgs( - dp=args.dp or config.parallelism.dp, - pp=args.pp or config.parallelism.pp, - tp=args.tp or config.parallelism.tp, + dp=args.dp, + pp=args.pp, + tp=args.tp, pp_engine=OneForwardOneBackwardPipelineEngine(), tp_mode=TensorParallelLinearMode.ALL_REDUCE, tp_linear_async_communication=False, @@ -163,86 +164,147 @@ def main(): tokenizer.add_special_tokens({"pad_token": "[PAD]"}) tokenizer.padding_side = "left" tokenizer.truncation_side = "left" # TODO @nouamane: do we want this? + dummy_inputs = [ "The future of AI is", # "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", - "def fib(n)", + # "def fib(n)", # 'Here is an extract from a webpage: "Have you ever experienced heel pain after a heavy physical activity, or even right after a long period of standing? If you regard this as something usual and normal, then think again. Miscalled as heel pain, plantar fasciitis causes these frequent mild pains experienced in the soles of the feet. It is the inflammation and enlargement the plantar fascia tissue that is located in the heels of the feet, stretching to the base of the toes. This tissue is responsible for absorbing shock in the feet and for supporting the arches. It also plays a vital role in foot movements during walking and standing. 
Many factors such as excessive walking, standing, and running trigger heel pain and plantar fasciitis. A sudden increase in intensity of activities, increase in weight, and abrupt change of footwear also cause the swelling of the ligament. Non-supportive footwear lacking arch cushions and improper and worn out running or training can also lead to the problem. It is also most evident among those". Write an extensive and detailed course unit suitable for a textbook targeted at college students, related to the given extract, within the context of "Medicine". Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: - Rigor: Ensure in-depth coverage of the concepts/sections. - Engagement: Write with an academic, professional and engaging tone that captivates interest. - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history. Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.', # "Advancements in technology will lead to", # "Tomorrow's world is shaped by", ] - outputs = decode_text( - input_iter=(GenerationInput(text=text) for text in dummy_inputs), - tokenizer=tokenizer, - # TODO @thomasw21: From ModelWithLoss extract the model. - model=model.model, - parallel_context=parallel_context, - max_new_tokens=args.max_new_tokens, - max_micro_batch_size=2, - generation_config=GenerationArgs(sampler="greedy", use_cache=True), - tokenizer_config=TokenizerConfig(max_input_length=None), - is_bench=os.environ.get("USE_BENCH", "0") == "1", + log_rank(f"Using cache for generation: {args.use_cache}", logger=logger, level=logging.INFO, rank=0) + + # NOTE: This doesn't support micro-batches and batch inference + device = torch.cuda.current_device() + generation_config = GenerationArgs(sampler="greedy", use_cache=args.use_cache) + logits_are_batch_first = True + + if generation_config: + if isinstance(generation_config.sampler, str): + sampler_type = SamplerType(generation_config.sampler.upper()) + else: + sampler_type = generation_config.sampler + else: + sampler_type = SamplerType.GREEDY + + tokenized_prompts = tokenizer( + dummy_inputs, + return_tensors="pt", + return_attention_mask=True, + padding=True, ) - for output in outputs: - input_ids = output.input_ids - generated_ids = output.generation_ids - if isinstance(input_ids, TensorPointer): - assert isinstance(generated_ids, TensorPointer) - continue - assert isinstance(generated_ids, torch.Tensor) - - log_rank( - f"input: {tokenizer.decode(input_ids, clean_up_tokenization_spaces=False)[:1000]}", - logger=logger, - level=logging.INFO, - rank=0, - ) + tokenized_prompts["input_ids"] = tokenized_prompts["input_ids"].to(device) + tokenized_prompts["attention_mask"] = tokenized_prompts["attention_mask"].to(dtype=torch.bool, device=device) + + store = Store() + batch_prompts = None + + for i in range(args.max_new_tokens): + + if generation_config.use_cache: + # Prepare the batch prompts + batch_prompts = GenerationStates( + new_input_ids=tokenized_prompts["input_ids"] + if i == 0 + else tokenized_prompts["input_ids"][:, -1].unsqueeze(0), + new_input_mask=tokenized_prompts["attention_mask"] + if i == 0 + else tokenized_prompts["attention_mask"][:, -1].unsqueeze(0), + store=store, + generation_ids=tokenized_prompts["input_ids"], + 
generation_mask=tokenized_prompts["attention_mask"], + ) + else: + batch_prompts = GenerationInputs( + input_ids=tokenized_prompts["input_ids"], + input_masks=tokenized_prompts["attention_mask"], + ) - log_rank( - f"generation: {tokenizer.decode(generated_ids[len(input_ids) :], clean_up_tokenization_spaces=False)}", - logger=logger, - level=logging.INFO, - rank=0, + logits = run_one_inference_step( + model, batch_prompts, parallel_context, device, use_cache=generation_config.use_cache, store=store ) - log_rank( - "--------------------------------------------------", - logger=logger, - level=logging.INFO, - rank=0, + # Sample new token + if parallel_context.is_pipeline_last_stage: + assert logits is not None and isinstance(logits, torch.Tensor) + + # Get sampler + if sampler_type == SamplerType.GREEDY: + sampler = GreedySampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.TOP_K: + sampler = TopKSampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.TOP_P: + sampler = TopPSampler(pg=parallel_context.tp_pg) + elif sampler_type == SamplerType.BASIC: + sampler = BasicSampler(pg=parallel_context.tp_pg) + else: + raise NotImplementedError(f"Sampler type {sampler_type} is not implemented") + + if logits_are_batch_first: + logits = logits.transpose(0, 1) + + # Predict next token + next_token = sampler(sharded_logits=logits[:, -1]) + + # Extend the tokenized prompts to insert the new token + tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token], dim=-1) + tokenized_prompts["attention_mask"] = torch.cat( + [ + tokenized_prompts["attention_mask"], + torch.ones((tokenized_prompts["attention_mask"].shape[0], 1), dtype=torch.bool, device=device), + ], + dim=-1, + ) + else: + # Extend the tokenized prompts to receive the new token + tokenized_prompts["input_ids"] = torch.zeros( + (tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), + dtype=torch.int64, + device=device, + ) + tokenized_prompts["attention_mask"] = torch.zeros( + ( + tokenized_prompts["attention_mask"].shape[0], + tokenized_prompts["attention_mask"].shape[1] + 1, + ), + dtype=torch.bool, + device=device, + ) + + # Broadcast the new token to all the pipeline stages + dist.broadcast( + tokenized_prompts["input_ids"], + src=get_global_rank( + group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank + ), + group=parallel_context.pp_pg, ) - else: - outputs = decode_tokenized( - input_ids=torch.zeros(1, 1).to(dtype=torch.int64, device="cuda"), - input_mask=torch.ones(1, 1).to(dtype=torch.bool, device="cuda"), - model=model.model, - parallel_context=parallel_context, - generation_config=GenerationArgs(sampler="greedy", use_cache=True), - max_micro_batch_size=1, - max_new_tokens=12, - returns_logits=False, - ) - for output in outputs: - input_ids = output.input_ids - generated_ids = output.generation_ids - if isinstance(input_ids, TensorPointer): - assert isinstance(generated_ids, TensorPointer) - continue - assert isinstance(generated_ids, torch.Tensor) - log_rank( - f"generation: {generated_ids[len(input_ids) :]}", - logger=logger, - level=logging.INFO, - rank=0, + dist.broadcast( + tokenized_prompts["attention_mask"], + src=get_global_rank( + group=parallel_context.pp_pg, group_rank=parallel_context.pipeline_parallel_last_rank + ), + group=parallel_context.pp_pg, ) - log_rank( - "--------------------------------------------------", - logger=logger, - level=logging.INFO, - rank=0, - ) + # Decode the generated text + if 
dist.get_rank() == 0: + for i, prompt in enumerate(dummy_inputs): + if generation_config.use_cache: + tokenized_outputs = torch.cat( + [tokens.view(1, -1) for tokens in batch_prompts.generation_ids], dim=1 + ) + outputs = tokenizer.decode(tokenized_outputs[0], clean_up_tokenization_spaces=False) + else: + tokenized_outputs = tokenized_prompts["input_ids"][ + i, tokenized_prompts["input_ids"].shape[1] - args.max_new_tokens : + ] + outputs = tokenizer.decode(tokenized_outputs, clean_up_tokenization_spaces=False) + + log_rank(f"Input: {prompt}", logger=logger, level=logging.INFO, rank=0) + log_rank(f"Output: {outputs}", logger=logger, level=logging.INFO, rank=0) dist.barrier() diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 57225243..2630e1d6 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -48,6 +48,9 @@ class LlamaConfig: rms_norm_eps: float = 1e-6 rope_scaling: Optional[dict] = None rope_theta: float = 10000.0 + rope_interleaved: bool = ( + True # The default value has been True, but for loading Llama3 checkpoints you have to set it to False + ) tie_word_embeddings: bool = False use_cache: bool = True vocab_size: int = 32000 diff --git a/src/nanotron/generation/decode.py b/src/nanotron/generation/decode.py index 6ab71fad..dc021085 100644 --- a/src/nanotron/generation/decode.py +++ b/src/nanotron/generation/decode.py @@ -772,6 +772,67 @@ def generator(): ) +@torch.inference_mode() +def run_one_inference_step(model, batch, parallel_context, device, use_cache, store): + if dist.get_world_size(group=parallel_context.pp_pg) == 1: + if use_cache: + with attach_store(model=model, store=store): + return model.model(batch.new_input_ids, batch.new_input_mask) + return model.model(batch.input_ids, batch.input_masks) + + pipeline_state = PipelineEvalBatchState() + with attach_pipeline_state_to_model(model=model, pipeline_state=pipeline_state): + batch_size = batch.new_input_ids.shape[0] if use_cache else batch.input_ids.shape[0] + seq_len = batch.new_input_ids.shape[1] if use_cache else batch.input_ids.shape[1] + + # Preallocate memory for output logits. 
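Stepping back to the generation loop added in `run_generate.py` above: with `--use-cache` the whole prompt goes through the model only on step 0, every later step feeds just the newest token, and the `Store` carries the key/value state between steps. The following is a minimal, single-process sketch of that feeding pattern, not nanotron's implementation: `toy_forward`, `embed` and `head` are invented stand-ins for `run_one_inference_step` and the model, and the "cache" here is simply the accumulated hidden states rather than a real KV cache.

```python
import torch

torch.manual_seed(0)
vocab_size, hidden = 100, 16

# Made-up stand-ins for the nanotron model: an embedding plus a linear head.
embed = torch.nn.Embedding(vocab_size, hidden)
head = torch.nn.Linear(hidden, vocab_size)

@torch.inference_mode()
def toy_forward(input_ids, cache):
    """Toy replacement for run_one_inference_step; the 'cache' is just the
    hidden states seen so far, not a real KV cache."""
    h = embed(input_ids)                      # [batch, new_len, hidden]
    if cache is not None:
        h = torch.cat([cache, h], dim=1)      # prepend everything already processed
    return head(h), h                         # logits [batch, seq, vocab], new cache

prompt = torch.randint(0, vocab_size, (1, 5))  # fake tokenized prompt, batch size 1
input_ids, cache = prompt, None
max_new_tokens, use_cache = 8, True

for i in range(max_new_tokens):
    if use_cache:
        # Step 0 feeds the whole prompt; later steps feed only the newest token,
        # mirroring the `i == 0` branch in the run_generate.py loop above.
        step_input = input_ids if i == 0 else input_ids[:, -1:]
        logits, cache = toy_forward(step_input, cache)
    else:
        logits, _ = toy_forward(input_ids, None)

    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy sampling
    input_ids = torch.cat([input_ids, next_token], dim=-1)

print(input_ids.shape)  # torch.Size([1, 13]): 5 prompt tokens + 8 generated
```

Greedy sampling reduces to an `argmax` over the last position here; in the patch the same choice is delegated to `GreedySampler`, which additionally handles logits sharded across tensor-parallel ranks.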
+ logits = None + if parallel_context.is_pipeline_last_stage: + logits = torch.empty((seq_len, batch_size, model.config.vocab_size), dtype=torch.float32, device=device) + + if use_cache: + batch2use = GenerationStates( + new_input_ids=batch.new_input_ids + if parallel_context.is_pipeline_first_stage + else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), + new_input_mask=batch.new_input_mask + if parallel_context.is_pipeline_first_stage + else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), + store=store, + generation_ids=batch.generation_ids, + generation_mask=batch.generation_mask, + ) + with attach_store(model=model, store=store): + output_tensor = model.model(batch2use.new_input_ids, batch2use.new_input_mask) + else: + batch2use = GenerationInputs( + input_ids=batch.input_ids + if parallel_context.is_pipeline_first_stage + else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), + input_masks=batch.input_masks + if parallel_context.is_pipeline_first_stage + else TensorPointer(group_rank=parallel_context.pipeline_parallel_prev_rank), + ) + + output_tensor = model.model(batch2use.input_ids, batch2use.input_masks) + + nb_send = len(pipeline_state.microbatches_activations_to_send) + assert nb_send <= 2 + for _ in range(nb_send): + # Send activations to the next stage + # Send attention_mask to the next stage + pipeline_state.run_communication() + + # Copy logits. + if parallel_context.is_pipeline_last_stage: + logits = output_tensor + + # Wait for all the communication to complete. + dist.barrier(group=parallel_context.world_pg) + + return logits + + # Distributed utilities def broadcast_tensors( tensors: List[Union[torch.Tensor, TensorPointer]], group_src: int, group: Optional[ProcessGroup] = None diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index ca8894b9..75367953 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,7 @@ # limitations under the License. """PyTorch LLaMa model.""" -from typing import Dict, Optional, Union, List +from typing import Dict, Optional, Union import torch from torch import nn @@ -46,75 +46,71 @@ logger = logging.get_logger(__name__) +def rotate_half(x, interleaved): + if interleaved: + x1, x2 = x[..., ::2], x[..., 1::2] + else: + split = x.shape[-1] // 2 + x1, x2 = x[..., :split], x[..., split:] + return torch.cat((-x2, x1), dim=-1) + + class RotaryEmbedding(nn.Module): - def __init__(self, dim: int, end: int, theta: float = 10000.0): + def __init__(self, dim: int, end: int, theta: float = 500000.0): super().__init__() - assert dim % 2 == 0 self.dim = dim self.end = end self.theta = theta - # TODO @nouamane: Figure out why we can't set `DTypeInvariantTensor` ... 
- # TODO @thomasw21: Complex buffers break DDP, instead we store float and view them as complex - self.freqs_cis: torch.Tensor - self._initialized_buffer = False + self.init_rotary_embeddings() def init_rotary_embeddings(self): - if self._initialized_buffer is True: - # Buffer if already initialized - return - self.register_buffer( - "freqs_cis", - torch.empty(self.end, self.dim // 2, 2, dtype=torch.float, device="cuda"), - persistent=False, - ) - assert self.freqs_cis.device.type == "cuda" - # TODO @nouamane: One we figure out how to do the DTypeInvariantTensor, this can be removed and changed to an assert - if self.freqs_cis.dtype != torch.float: - self.freqs_cis = self.freqs_cis.to(torch.float) - assert self.freqs_cis.dtype == torch.float - freqs = 1.0 / ( - self.theta - ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda")[: (self.dim // 2)] / self.dim) - ) - t = torch.arange(self.end, device="cuda") - freqs = torch.outer(t, freqs).float() - complex_freqs = torch.polar(torch.ones_like(freqs), freqs) - freqs = torch.view_as_real(complex_freqs) - self.freqs_cis.copy_(freqs) - self._initialized_buffer = True + inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda") / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + @torch.no_grad() def forward( self, x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk] position_ids: Optional[torch.LongTensor], # [batch_size, seq_length] ): - batch_size, seq_length, num_heads, inner_dim = x.shape - while ( - position_ids is not None and position_ids[-1, -1] >= self.end - ) or seq_length >= self.end: # TODO @nouamane: check if this causes cpu-gpu sync - self.end *= 2 - self._initialized_buffer = False - if self._initialized_buffer is False: - print(f"Initializing rotary embeddings with end={self.end}") - self.init_rotary_embeddings() - dtype = x.dtype - assert inner_dim % 2 == 0 - x = x.view( - batch_size, seq_length, num_heads, inner_dim // 2, 2 - ) # [batch_size, q_length, num_heads, inner_dim] - if x.dtype == torch.bfloat16: - x = x.float() - complex_x = torch.view_as_complex(x) # [batch_size, q_length, num_heads, inner_dim // 2] - if position_ids is None: - freqs_cis = self.freqs_cis[None, :seq_length, None, :] - else: - # TODO(kunhao): Should None follow the num_heads dimension? 
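The removed `init_rotary_embeddings` precomputed a complex `freqs_cis` buffer, while the new code keeps only `inv_freq` and rebuilds `cos`/`sin` from the position ids in `forward` (continued just below). A minimal check, assuming only PyTorch and arbitrary `dim`/`end`/`theta` values, that both formulations produce the same rotation angles:

```python
import torch

dim, end, theta = 64, 128, 10000.0

# New path: inv_freq registered once, angles rebuilt from position ids in forward().
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
t = torch.arange(end, dtype=torch.float)          # positions 0 .. end-1
freqs = torch.outer(t, inv_freq)                  # [end, dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

# Old path: the removed buffer stored e^{i * freqs} as (real, imag) pairs.
complex_freqs = torch.polar(torch.ones_like(freqs), freqs)
old = torch.view_as_real(complex_freqs)           # [end, dim // 2, 2]

assert torch.allclose(cos[:, : dim // 2], old[..., 0])   # real part == cos(freqs)
assert torch.allclose(sin[:, : dim // 2], old[..., 1])   # imag part == sin(freqs)
print("same angles, different layout")
```

The angles are identical; what differs between the old complex multiply and the new `cos`/`sin` plus `rotate_half` application is which elements of the head dimension get paired together, which is exactly why `rope_interleaved` has to match the checkpoint.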
- if position_ids[-1, -1] < 0 or position_ids[-1, -1] >= self.end: # Quick test hopefully - raise ValueError(f"Position ids must be in the range [0, {self.end}), but got {position_ids}") - freqs_cis = self.freqs_cis[position_ids][:, :, None, :] - complex_freqs = torch.view_as_complex(freqs_cis) - x_out = torch.view_as_real(complex_x * complex_freqs).view(batch_size, seq_length, num_heads, inner_dim) - return x_out.type(dtype) + # x: [bs, num_attention_heads, seq_len, head_size] + # print("rotary") + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def apply_rotary_pos_emb(q, k, cos, sin, interleaved, unsqueeze_dim=2): + """Applies Rotary Position Embedding to the query and key tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q, interleaved) * sin) + k_embed = (k * cos) + (rotate_half(k, interleaved) * sin) + return q_embed, k_embed class GLUActivation(nn.Module): @@ -322,9 +318,12 @@ def __init__( end=config.max_position_embeddings, theta=config.rope_theta, ) + self.rope_interleaved = config.rope_interleaved # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet) - self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, base=config.rope_theta, interleaved=True) + self.flash_rotary_embedding = FlashRotaryEmbedding( + dim=self.d_qk, base=config.rope_theta, interleaved=config.rope_interleaved + ) self.o_proj = TensorParallelRowLinear( config.num_attention_heads * self.d_qk, @@ -403,8 +402,12 @@ def forward( # Compute rotary embeddings # Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache old_rotary_embed_end = self.rotary_embedding.end - query_states = self.rotary_embedding(query_states, position_ids=position_ids) - key_states = self.rotary_embedding(key_states, position_ids=position_ids) + # query_states = self.rotary_embedding(query_states, position_ids=position_ids) + # key_states = self.rotary_embedding(key_states, position_ids=position_ids) + cos, sin = self.rotary_embedding(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, interleaved=self.rope_interleaved + ) if "key" not in store: # First inference iteration (Prefill) diff --git a/src/nanotron/parallel/context.py b/src/nanotron/parallel/context.py index e04e26f5..ceab9f9c 100644 --- a/src/nanotron/parallel/context.py +++ b/src/nanotron/parallel/context.py @@ -1,5 +1,5 @@ import os -from typing import Literal, Tuple, Annotated +from typing import Literal, Tuple import numpy as np import torch @@ -62,6 +62,25 @@ def __init__( self._init_parallel_groups() + self.pipeline_parallel_last_rank = self.pipeline_parallel_size - 1 + self.is_pipeline_first_stage = self.pp_pg.rank() == 0 + self.is_pipeline_last_stage = self.pp_pg.rank() == self.pipeline_parallel_last_rank + self.pipeline_parallel_next_rank = ( + None + if self.is_pipeline_last_stage + else self.get_global_rank( + self.expert_pg.rank(), self.pp_pg.rank() + 1, self.dp_pg.rank(), self.tp_pg.rank() + ) + ) + + self.pipeline_parallel_prev_rank = ( + None + if self.is_pipeline_first_stage + else self.get_global_rank( + self.expert_pg.rank(), self.pp_pg.rank() - 1, self.dp_pg.rank(), self.tp_pg.rank() + ) + ) + def _init_parallel_groups(self): """Initialize 3D parallelism's all process groups.""" dist.barrier() @@ -152,4 +171,4 @@ def get_global_rank( :return: numpy.int64, The global rank. """ - return self.world_rank_matrix[ep_rank, pp_rank, dp_rank, tp_rank] \ No newline at end of file + return self.world_rank_matrix[ep_rank, pp_rank, dp_rank, tp_rank]
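The `ParallelContext` additions at the end resolve a stage's previous and next pipeline neighbours to global ranks by looking them up on the `(expert, pipeline, data, tensor)` rank grid. The sketch below mimics that lookup with a hypothetical 1x4x1x2 layout; `ToyContext` and its `pipeline_neighbors` helper are made up for illustration and only assume that ranks are laid out the way `world_rank_matrix` indexes them.

```python
import numpy as np

class ToyContext:
    """Stand-in for ParallelContext with ep=1, pp=4, dp=1, tp=2 (world size 8)."""

    def __init__(self, ep=1, pp=4, dp=1, tp=2):
        self.pipeline_parallel_size = pp
        # Ranks arranged on an (expert, pipeline, data, tensor) grid, like world_rank_matrix.
        self.world_rank_matrix = np.arange(ep * pp * dp * tp).reshape(ep, pp, dp, tp)

    def get_global_rank(self, ep_rank, pp_rank, dp_rank, tp_rank):
        return self.world_rank_matrix[ep_rank, pp_rank, dp_rank, tp_rank]

    def pipeline_neighbors(self, ep_rank, pp_rank, dp_rank, tp_rank):
        last = self.pipeline_parallel_size - 1
        prev_rank = None if pp_rank == 0 else self.get_global_rank(ep_rank, pp_rank - 1, dp_rank, tp_rank)
        next_rank = None if pp_rank == last else self.get_global_rank(ep_rank, pp_rank + 1, dp_rank, tp_rank)
        return prev_rank, next_rank

ctx = ToyContext()
# Middle pipeline stage (pp_rank=1) of the tp_rank=1 shard keeps its dp/tp coordinates:
print(ctx.pipeline_neighbors(0, 1, 0, 1))   # (1, 5)
```

`run_one_inference_step` relies on exactly these attributes to decide whether a stage receives real tensors or `TensorPointer` placeholders aimed at `pipeline_parallel_prev_rank`, and the driver loop in `run_generate.py` uses `pipeline_parallel_last_rank` with `get_global_rank` as the broadcast source for each newly sampled token.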