diff --git a/pyproject.toml b/pyproject.toml
index 6a0cfb83..9794ab78 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,7 @@ dependencies = [
     "safetensors",
     "dacite",
     "tqdm",
+    "datasets",
 ]
 
 [tool.setuptools.packages.find]
@@ -53,6 +54,12 @@ nanosets = [
     "numba",
 ]
 
+s3 = [
+    "boto3",
+    "s3fs",
+    "s5cmd",
+]
+
 [build-system]
 requires = [
     "setuptools",
diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index de0fa3c0..0744dd69 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -2,6 +2,7 @@
 import os
 from dataclasses import dataclass, fields
 from pathlib import Path
+from datasets.download.streaming_download_manager import xPath
 from typing import List, Optional, Type, Union
 
 import dacite
@@ -91,6 +92,22 @@ def __post_init__(self):
         self.hf_dataset_splits = "train"
 
 
+@dataclass
+class S3UploadArgs:
+    """Arguments related to uploading checkpoints to S3"""
+
+    upload_s3_path: xPath
+    remove_after_upload: bool
+    s5cmd_numworkers: Optional[int]
+    s5cmd_concurrency: Optional[int]
+    s5cmd_path: Optional[xPath]
+
+    def __post_init__(self):
+        if isinstance(self.upload_s3_path, str):
+            self.upload_s3_path = xPath(self.upload_s3_path)
+        if isinstance(self.s5cmd_path, str):
+            self.s5cmd_path = xPath(self.s5cmd_path)
+
 @dataclass
 class NanosetDatasetsArgs:
     dataset_folder: Union[str, dict, List[str]]
@@ -146,14 +163,14 @@ class CheckpointsArgs:
     checkpoint_interval: int
     save_initial_state: Optional[bool] = False
     save_final_state: Optional[bool] = False
-    resume_checkpoint_path: Optional[Path] = None
+    resume_checkpoint_path: Optional[xPath] = None
     checkpoints_path_is_shared_file_system: Optional[bool] = False
 
     def __post_init__(self):
         if isinstance(self.checkpoints_path, str):
-            self.checkpoints_path = Path(self.checkpoints_path)
+            self.checkpoints_path = xPath(self.checkpoints_path)
         if isinstance(self.resume_checkpoint_path, str):
-            self.resume_checkpoint_path = Path(self.resume_checkpoint_path)
+            self.resume_checkpoint_path = xPath(self.resume_checkpoint_path)
 
 
 @dataclass
@@ -338,6 +355,7 @@ class Config:
     data_stages: Optional[List[DatasetStageArgs]] = None
     profiler: Optional[ProfilerArgs] = None
     lighteval: Optional[LightEvalConfig] = None
+    s3_upload: Optional[S3UploadArgs] = None
 
     @classmethod
     def create_empty(cls):
@@ -345,6 +363,10 @@ def create_empty(cls):
         return cls(**{f.name: None for f in cls_fields})
 
     def __post_init__(self):
+
+        if self.s3_upload is not None:
+            self.s3_upload.__post_init__()
+
         # Some final sanity checks across separate arguments sections:
         if self.profiler is not None and self.profiler.profiler_export_path is not None:
             assert self.tokens.train_steps < 10
diff --git a/src/nanotron/s3_checkpoints/__init__.py b/src/nanotron/s3_checkpoints/__init__.py
new file mode 100644
index 00000000..0b32a02a
--- /dev/null
+++ b/src/nanotron/s3_checkpoints/__init__.py
@@ -0,0 +1,4 @@
+from .fsspec import check_path_is_local, fs_copy, fs_open
+from .s3_mover import S3Mover
+
+__all__ = ["S3Mover", "fs_open", "fs_copy", "check_path_is_local"]
= ["S3Mover", "fs_open", "fs_copy", "check_path_is_local"] \ No newline at end of file diff --git a/src/nanotron/s3_checkpoints/fsspec.py b/src/nanotron/s3_checkpoints/fsspec.py new file mode 100644 index 00000000..01786489 --- /dev/null +++ b/src/nanotron/s3_checkpoints/fsspec.py @@ -0,0 +1,38 @@ +import contextlib +from pathlib import Path +from typing import Tuple, Union + +import fsspec +from fsspec.implementations import local + + +def get_filesystem_and_path(path: Path, storage_options=None) -> Tuple[fsspec.AbstractFileSystem, str]: + # Use supported filesystems in `fsspec`. If you need another one, please use `fsspec.registry.register_implementation` + # DO NOT USE `mode` argument as it adds a suffix `0.part` when using `mode="w"`. + fs, _, paths = fsspec.core.get_fs_token_paths(str(path), storage_options=storage_options) + assert len(paths) == 1 + return fs, paths[0] + + +@contextlib.contextmanager +def fs_open( + file: Union[str, Path], + mode="r", +): + # TODO @thomasw21: pass storage options + fs, path = get_filesystem_and_path(file) + with fs.open(path, mode=mode) as f: + yield f + + +def fs_copy( + input_file: Union[str, Path], + output_file: Union[str, Path], +): + """Copy file from input to output (possibly on s3/other fs)""" + with fs_open(input_file, mode="rb") as fi, fs_open(output_file, mode="wb") as fo: + fo.write(fi.read()) + + +def check_path_is_local(path: Path, storage_options=None) -> bool: + return isinstance(get_filesystem_and_path(path=path, storage_options=storage_options)[0], local.LocalFileSystem) diff --git a/src/nanotron/s3_checkpoints/s3_mover.py b/src/nanotron/s3_checkpoints/s3_mover.py new file mode 100644 index 00000000..73c9b793 --- /dev/null +++ b/src/nanotron/s3_checkpoints/s3_mover.py @@ -0,0 +1,439 @@ +import glob +import json +import os +import subprocess +import time +from datetime import datetime +from enum import Enum +from typing import Optional, Union + +import torch +from datasets.download.streaming_download_manager import xPath +from filelock import FileLock, Timeout +from nanotron import distributed as dist +from nanotron import logging +from nanotron.distributed import ProcessGroup +from nanotron.logging import human_format + +logger = logging.get_logger(__name__) + + +class S3Mover: + #TODO @eliebak update the doc to state that it also the function use to download it to the disk with start_downloading + """Take care of uploading a checkpoint to S3 in the background and remove it from the disk. + + Args: + local_path: Path to the checkpoints on the local disk + s3_path: Path to the checkpoints on S3 + remove_after_upload: If True, remove the checkpoint from the disk after uploading it to S3 + s5cmd_numworkers: Number of workers to use for the s5cmd command + s5cmd_concurrency: Concurrency to use for the s5cmd command + s5cmd_path: Path to the s5cmd command + dummy: If True, don't actually upload/remove/etc anything. Useful for simpler multi-processing node and only uploading from one process. + + Usage: + # Create a mover - use dummy=True for all the process that shouldn't do anything (e.g. all but one per node) + mover = S3Mover(local_path=/scratch/my-checkpoints, + s3_path=s3://my-bucket/my-checkpoints, + remove_after_upload=True, + s5cmd_numworkers=96, + s5cmd_concurrency=10, + s5cmd_path=/admin/user/my/bin/s5cmd, + dummy=False) + + while training: + # from times to times update the state + mover_status = mover.update() + ... 
+
+        # When saving a checkpoint, check if the previous checkpoint has been uploaded and removed
+        # in a distributed setting
+    """
+
+    class S3MoverState(Enum):
+        IDLE = "IDLE"
+        UPLOADING = "UPLOADING"
+        DOWNLOADING = "DOWNLOADING"
+        REMOVING_CHECKPOINT = "REMOVING_CHECKPOINT"
+
+    class DummyPopen:
+        def __init__(self, *args, **kwargs):
+            pass
+
+        def poll(self):
+            return 0
+
+        def communicate(self):
+            return ("", "")
+
+    def __init__(
+        self,
+        local_path: xPath,
+        s3_path: xPath,
+        # duplicate_checkpoint_path: Optional[xPath] = None,
+        post_upload_callback: Optional[callable] = None,
+        remove_after_upload: Optional[bool] = True,
+        s5cmd_numworkers: Optional[int] = None,
+        s5cmd_concurrency: Optional[int] = None,
+        s5cmd_path: Optional[str] = None,
+        s5cmd_credentials: Optional[str] = None,
+        clean_up_local_on_start: bool = False,
+        dummy: bool = False,
+        s3_region: str = "us-east-1",
+    ):
+        self.process: Optional[Union[subprocess.Popen, S3Mover.DummyPopen]] = None
+        self.remove_after_upload = remove_after_upload
+        self.s5cmd_numworkers = s5cmd_numworkers
+        self.s5cmd_concurrency = s5cmd_concurrency
+        self.s5cmd_path = s5cmd_path if s5cmd_path is not None else "s5cmd"
+        self.s5cmd_credentials = s5cmd_credentials
+        self.lock_file = None
+        self.dummy = dummy
+        self.s3_region = s3_region
+        self.post_upload_callback = post_upload_callback
+        self.post_upload_callback_outputs = None
+
+        local_path = str(local_path)
+        if not local_path.startswith("/scratch/"):
+            self._warning(f"The local path is not on the scratch drive: {local_path}")
+        if not local_path.endswith("/"):
+            local_path += "/"
+
+        s3_path = str(s3_path)
+        if not s3_path.endswith("/"):
+            s3_path += "/"
+
+        self.local_path = local_path
+        self.s3_path = s3_path
+
+        s3_bucket, s3_prefix = s3_path.replace("s3://", "").split("/", maxsplit=1)
+        self.s3_path_direct_link = f"https://s3.console.aws.amazon.com/s3/buckets/{s3_bucket}?region={self.s3_region}&prefix={s3_prefix}&showversions=false"
+
+        self._reset_state()
+        if clean_up_local_on_start:
+            self._start_removing()
+
+    def _warning(self, message):
+        if self.dummy:
+            return
+        logger.warning(message)
+
+    def _info(self, message):
+        if self.dummy:
+            return
+        logger.info(message)
+
+    def _reset_state(self):
+        self.state = self.S3MoverState.IDLE
+        self.num_uploaded_files = 0
+        if self.lock_file is not None:
+            self._release_lock()
+        self.lock_file = None
+        self.stdout = ""
+        self.start_time: Optional[datetime] = None
+        self.cmd = ""
+
+    def _popen(self, cmd: list):
+        self.stdout = ""
+        self.start_time = datetime.now()
+        self.cmd = cmd
+        if self.dummy:
+            return self.DummyPopen(cmd)
+        else:
+            process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+            os.set_blocking(process.stdout.fileno(), False)
+            return process
+
+    def _acquire_lock(self, file_path: str) -> bool:
+        if self.dummy:
+            return True
+        if file_path.endswith("/"):
+            lock_file_path = file_path[:-1] + ".lock"
+        else:
+            lock_file_path = file_path + ".lock"
+        self.lock_file = FileLock(lock_file_path)
+        try:
+            self.lock_file.acquire(timeout=1)
+        except Timeout:
+            message = f"[S3] The checkpoint files {lock_file_path} are currently locked by another process."
" + self._warning(message) + return False + return True + + def get_state_as_int(self) -> int: + """Return the state as an int""" + if self.state == self.S3MoverState.IDLE: + return 0 + elif self.state == self.S3MoverState.UPLOADING: + return 1 + elif self.state == self.S3MoverState.DOWNLOADING: + return 2 + elif self.state == self.S3MoverState.REMOVING_CHECKPOINT: + return 3 + else: + return -1 + + def _release_lock(self): + if self.dummy: + return + if self.lock_file is not None and self.lock_file.is_locked: + self.lock_file.release() + + def get_current_stdout(self) -> str: + """Return the current stdout of the process if any""" + if self.process is None or isinstance(self.process, self.DummyPopen): + return "" + try: + stdout = self.process.stdout.read() + except ValueError: + stdout = "" # The buffer is already closed: "ValueError: read of closed file" + if stdout: + self.stdout += stdout.decode() + return self.stdout + + def wait_for_completion(self): + while self.state != self.S3MoverState.IDLE: + _ = self.update() + time.sleep(0.5) + + def distributed_wait_for_completion(self, group: Optional[ProcessGroup] = None): + """Wait for the previous checkpoint to be fully uploaded and removed in a distributed setting. + Will wait for all process to be ready + """ + if group is None: + group = dist.torch_dist.distributed_c10d._get_default_group() + + test_tensor = torch.tensor([self.is_previous_save_finished()], device=torch.device("cuda")) + test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(group.size())] + dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False) + dist.barrier() + all_saved = sum(bool(tensor.item()) for tensor in test_tensor_list) + if all_saved != group.size() and self.state != self.S3MoverState.IDLE: + self._warning( + f"Waiting previous checkpoint saving is finished - S3Mover {dist.get_rank(group)} still in {self.state} state.", + ) + while all_saved != group.size(): + stdout = self.get_current_stdout() + stdout_lines = [lst for lst in stdout.split("\n") if lst] + if self.state != self.S3MoverState.IDLE: + self._warning( + f"[S3] Waiting {self.state.value}: {all_saved} / {group.size()}. 
+                )
+            # Sync all our saves over NCCL; we could do a dist barrier later, but this helps us not lose NCCL connections down the line
+            test_tensor = torch.tensor([self.is_previous_save_finished()], device=torch.device("cuda"))
+            test_tensor_list = [torch.zeros_like(test_tensor) for _ in range(group.size())]
+            dist.all_gather(test_tensor_list, test_tensor, group=group, async_op=False)
+            dist.barrier()
+            all_saved = sum(bool(tensor.item()) for tensor in test_tensor_list)
+            time.sleep(1)
+
+    def is_previous_save_finished(self) -> bool:
+        """Return True if a potential previous checkpoint has been fully uploaded to S3
+        and removed from the drive
+        """
+        self.update()
+        return self.state == self.S3MoverState.IDLE
+
+    def _start_downloading(self, sub_folder: Optional[str] = None) -> bool:
+        self._warning(
+            f"[S3] Downloading checkpoint in background from {self.s3_path} to {self.local_path} (direct link: {self.s3_path_direct_link})"
+        )
+        cmd = [self.s5cmd_path, "--json"]
+        if self.s5cmd_credentials is not None:
+            cmd += ["--credentials-file", self.s5cmd_credentials]
+        if self.s5cmd_numworkers is not None:
+            cmd += ["--numworkers", str(self.s5cmd_numworkers)]
+        cmd += ["cp"]
+        if self.s5cmd_concurrency is not None:
+            cmd += ["--concurrency", str(self.s5cmd_concurrency)]
+        cmd += [self.s3_path + "*", self.local_path]
+
+        self.process = self._popen(cmd)
+        self.state = self.S3MoverState.DOWNLOADING
+
+        return True
+
+    def _post_downloading(self) -> bool:
+        self.get_current_stdout()
+        s5cmd_results = [json.loads(i) for i in self.stdout.split("\n") if i]
+        total_files = len([i for i in s5cmd_results if i["success"]])
+        total_not_downloaded_files = len([i for i in s5cmd_results if not i["success"]])
+        if total_not_downloaded_files == 0:
+            all_upload = "all files"
+            success = True
+        else:
+            all_upload = "not all files"
+            success = False
+        total_size = sum(i["object"]["size"] for i in s5cmd_results if "size" in i["object"])
+        total_time = (datetime.now() - self.start_time).total_seconds()
+        self._warning(
+            f"[S3] Successfully downloaded {total_files} files for a total of {human_format(total_size)}B in {total_time} "
+            f"sec ({all_upload}) from S3 at {self.s3_path} to {self.local_path} "
+            f"(direct link: {self.s3_path_direct_link})"
+        )
+        return success
+
+    def _start_uploading(
+        self,
+    ) -> bool:
+        if not os.path.exists(self.full_local_path):
+            message = f"[S3] Checkpoint {self.full_local_path} does not exist, cannot upload to S3"
+            self._warning(message)
+            return False
+
+        # Get a file lock on the first file
+        local_files = glob.glob(self.full_local_path + "/**/*.*", recursive=True)
+
+        locked = self._acquire_lock(local_files[0])
+        if not locked:
+            return False
+
+        self._warning(
+            f"[S3] Uploading checkpoint in background from {self.full_local_path} to {self.full_s3_path} (direct link: {self.s3_path_direct_link})"
+        )
+        cmd = [self.s5cmd_path, "--json"]
+        if self.s5cmd_credentials is not None:
+            cmd += ["--credentials-file", self.s5cmd_credentials]
+        if self.s5cmd_numworkers is not None:
+            cmd += ["--numworkers", str(self.s5cmd_numworkers)]
+        cmd += ["cp", "--exclude", "*.lock", "--exclude", "*.lock.*"]
+        if self.s5cmd_concurrency is not None:
+            cmd += ["--concurrency", str(self.s5cmd_concurrency)]
+        cmd += [self.full_local_path, self.full_s3_path]
+
+        self.process = self._popen(cmd)
+        self.state = self.S3MoverState.UPLOADING
+
+        return True
+
+    def _post_uploading(self) -> bool:
+        self.get_current_stdout()
+        s5cmd_results = [json.loads(i) for i in self.stdout.split("\n") if i]
+        local_files = glob.glob(self.full_local_path + "/**/*.?*", recursive=True)
+        total_files = len([i for i in s5cmd_results if i["success"]])
+        self.num_uploaded_files = total_files
+        if len(local_files) == total_files:
+            all_upload = "all files"
+            success = True
+        else:
+            all_upload = f"not all files: {len(local_files)} out of {total_files}"
+            success = False
+        total_size = sum(i["object"]["size"] for i in s5cmd_results if "size" in i["object"])
+        total_time = (datetime.now() - self.start_time).total_seconds()
+        self._warning(
+            f"[S3] Successfully uploaded {total_files} files for a total of {human_format(total_size)}B in {total_time} sec "
+            f"({all_upload}) from {self.full_local_path} to S3 at {self.full_s3_path} "
+            f"(direct link: {self.s3_path_direct_link})"
+        )
+        if self.post_upload_callback:
+            self.post_upload_callback_outputs = self.post_upload_callback(uploaded_files=s5cmd_results)
+        self._release_lock()
+        return success
+
+    def _start_removing(self) -> bool:
+        top_dir_in_local_checkpoint = [dir for dir in glob.glob(self.local_path + "/*") if os.path.isdir(dir)]
+        names_dir = [os.path.basename(dir) for dir in top_dir_in_local_checkpoint]
+        if len(names_dir) == 0:
+            # If the local path is already empty, or if we have already started duplicating in another process, we skip with a no-op
+            self._warning("[S3] Local checkpoint folder is empty, skipping removal")
+            cmd = ["echo", "'skipping'"]
+            self.process = self._popen(cmd)
+            self.state = self.S3MoverState.REMOVING_CHECKPOINT
+            return True
+
+        self._warning(f"[S3] Removing checkpoint in background: {names_dir}")
+        locked = self._acquire_lock(top_dir_in_local_checkpoint[0])
+        if not locked:
+            return False
+        cmd = ["rm", "-rfv"] + top_dir_in_local_checkpoint
+        self.process = self._popen(cmd)
+        self.state = self.S3MoverState.REMOVING_CHECKPOINT
+        return True
+
+    def _post_removing(self) -> bool:
+        self.get_current_stdout()
+        local_files = [
+            loc_f
+            for loc_f in self.stdout.split("\n")
+            if "directory" not in loc_f.lower() and loc_f and ".lock" not in loc_f
+        ]
+        if len(local_files) == self.num_uploaded_files:
+            all_removed = "all files"
+            success = True
+        else:
+            all_removed = "not all files"
+            success = False
+        self._release_lock()
+        total_time = (datetime.now() - self.start_time).total_seconds()
+        self._warning(
+            f"[S3] Successfully removed {len(local_files)} local files ({all_removed}) from {self.local_path} (uploaded to {self.s3_path_direct_link}) in {total_time}"
+        )
+        return success
+
+    def update(self) -> Tuple[str, str]:
+        """Update the state of the mover: UPLOADING => REMOVING_CHECKPOINT => IDLE (or DOWNLOADING => IDLE)
+
+        Returns:
+            (str, str): The state and the stdout of the process if any
+        """
+        if self.process is None:
+            self._reset_state()
+            return self.state.value, self.stdout
+
+        return_code = self.process.poll()
+        if return_code is None:
+            # Still running
+            return self.state.value, self.stdout
+        if return_code != 0:
+            self.get_current_stdout()
+            self._warning(
+                f"[S3] Error running command {self.cmd} during process {self.state.value}, "
+                f"return code {return_code}, return message {self.stdout}"
+            )
+            return self.state.value, self.stdout
+        if self.state == self.S3MoverState.DOWNLOADING:
+            self._post_downloading()
+            self._reset_state()
+        elif self.state == self.S3MoverState.UPLOADING:
+            self._post_uploading()
+            if self.remove_after_upload:
+                self._start_removing()
+            else:
+                self._reset_state()
+        elif self.state == self.S3MoverState.REMOVING_CHECKPOINT:
+            self._post_removing()
+            self._reset_state()
+
+        return self.state.value, self.stdout
+
+    def start_uploading(self, sub_folder=None):
+        """Start uploading the last saved checkpoint to S3 in the background.
+
+        After running this method, you should call `update` regularly so the
+        state advances to uploading and then removing.
+
+        For a blocking upload, call `wait_for_completion` or `distributed_wait_for_completion` after calling this method.
+        """
+        self.update()
+        if self.state != self.S3MoverState.IDLE:
+            message = "[S3] Cannot move to S3 as the previous checkpoint has not been uploaded and removed"
+            self._warning(message)
+            return False
+        self.full_local_path = self.local_path + (f"/{sub_folder}" if sub_folder else "")
+        self.full_s3_path = self.s3_path + (f"/{sub_folder}" if sub_folder else "")
+        return self._start_uploading()
+
+    def start_downloading(self):
+        """Start downloading a checkpoint from S3 in the background.
+
+        After running this method, you should call `update` regularly to advance the
+        state.
+
+        For a blocking download, call `wait_for_completion` or `distributed_wait_for_completion` after calling this method.
+        """
+        self.update()
+        if self.state != self.S3MoverState.IDLE:
+            message = f"[S3] Cannot download from S3 as the state is not IDLE but {self.state.value}"
+            self._warning(message)
+            return False
+        return self._start_downloading()
diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py
index 286008ac..e9ed2572 100644
--- a/src/nanotron/serialize/main.py
+++ b/src/nanotron/serialize/main.py
@@ -1,6 +1,9 @@
 from pathlib import Path
 from typing import Optional, cast
 
+from datasets.download.streaming_download_manager import xPath
+import os
+from nanotron.s3_checkpoints import S3Mover, check_path_is_local, fs_open
 import torch
 from torch import nn
 from torch.nn.parallel import DistributedDataParallel
@@ -241,7 +244,7 @@ def load(
     return checkpoint_metadata
 
 
-def parse_ckpt_path(config: Config) -> Optional[Path]:
+def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Optional[Path]:
     """Parse checkpoint path from config and download checkpoint from S3 if needed.
 
     Args:
@@ -251,33 +254,71 @@
         Path to checkpoint or None if no checkpoint.
""" load_from_candidate = config.checkpoints.resume_checkpoint_path - if load_from_candidate is None: - return None - - latest_meta_path: Path = config.checkpoints.resume_checkpoint_path / "latest.txt" - if latest_meta_path.exists(): - with open(config.checkpoints.resume_checkpoint_path / "latest.txt", mode="r") as fi: - # TODO @thomasw21: make a better structure system so that we get typing correct - load_from_candidate = int(fi.read()) - checkpoint_path = config.checkpoints.resume_checkpoint_path / str(load_from_candidate) - - elif (config.checkpoints.resume_checkpoint_path / MODEL_CONFIG_FILE_NAME).exists(): - # we assume that the checkpoint path is a path to a checkpoint - checkpoint_path = config.checkpoints.resume_checkpoint_path + if load_from_candidate is not None: + if check_path_is_local(load_from_candidate): + latest_meta_path: xPath = config.checkpoints.resume_checkpoint_path / "latest.txt" + if latest_meta_path.exists(): + with fs_open(config.checkpoints.resume_checkpoint_path / "latest.txt", mode="r") as fi: + # TODO @thomasw21: make a better structure system so that we get typing correct + load_from_candidate = int(fi.read()) + checkpoint_path = config.checkpoints.resume_checkpoint_path / str(load_from_candidate) + + elif (config.checkpoints.resume_checkpoint_path / "model_config.json").exists(): + # we assume that the checkpoint path is a path to a checkpoint + checkpoint_path = config.checkpoints.resume_checkpoint_path + + else: + log_rank( + f"No previous checkpoint found in: {latest_meta_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + return None - else: - log_rank( - f"No previous checkpoint found in: {latest_meta_path}", - logger=logger, - level=logging.INFO, - rank=0, - ) - return None + log_rank( + f"Loading checkpoint from {checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + else: + latest_meta_path = config.checkpoints.resume_checkpoint_path / "latest.txt" + if latest_meta_path.exists(): + # if latest.txt exists, we assume that the checkpoint path is a path to a folder containing the checkpoint + with fs_open(latest_meta_path, mode="r") as fi: + latest_iteration = int(fi.read()) + s3_path = config.checkpoints.resume_checkpoint_path / str(latest_iteration) # load_path + checkpoint_path = config.checkpoints.checkpoints_path / str(latest_iteration) # save_path + elif config.checkpoints.resume_checkpoint_path.exists(): + # we assume that the checkpoint path is a path to a checkpoint + s3_path = config.checkpoints.resume_checkpoint_path # load_path + checkpoint_path = config.checkpoints.checkpoints_path / load_from_candidate.name # save_path + else: + log_rank( + f"No previous checkpoint found in: {config.checkpoints.resume_checkpoint_path}\n Initializing from scratch.", + logger=logger, + level=logging.WARNING, + rank=0, + ) + return None + log_rank( + f"Downloading checkpoint from S3 in {checkpoint_path} ", + logger=logger, + level=logging.WARNING, + rank=0, + ) + # Download checkpoint from S3 + s3_mover = S3Mover( + local_path=os.path.join(checkpoint_path), + s3_path=os.path.join(s3_path), + s5cmd_numworkers=config.s3_upload.s5cmd_numworkers, + s5cmd_concurrency=config.s3_upload.s5cmd_concurrency, + s5cmd_path=config.s3_upload.s5cmd_path, + dummy=bool(int(os.environ.get("LOCAL_RANK", None)) != 0), + ) + s3_mover.distributed_wait_for_completion(parallel_context.world_pg) + s3_mover.start_downloading() + s3_mover.distributed_wait_for_completion(parallel_context.world_pg) - log_rank( - f"Loading checkpoint from {checkpoint_path}", - 
-        logger=logger,
-        level=logging.INFO,
-        rank=0,
-    )
-    return checkpoint_path
+        return checkpoint_path
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index bef629c1..9725d45b 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -19,6 +19,7 @@
     cast,
 )
 
+from nanotron.s3_checkpoints import S3Mover, check_path_is_local
 import torch
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader
@@ -148,12 +149,14 @@ def __init__(
             data_parallel_size=self.config.parallelism.dp,
             expert_parallel_size=self.config.parallelism.expert_parallel_size,
         )
-
+        self.pre_init()
         # Set log levels
         set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging)
+
+
 
         # Log benchmark info
         if os.environ.get("NANOTRON_BENCHMARK", "0") == "1":
             log_throughput(self.config, self.parallel_context)
@@ -255,10 +258,25 @@
         self.post_init()
 
     def pre_init(self):
-        pass
+        self.init_checkpoint_path = parse_ckpt_path(config=self.config, parallel_context=self.parallel_context)
 
     def post_init(self):
-        pass
+        # S3 Mover and save initial state
+        if self.config.s3_upload is not None:
+            # Only local rank 0 should upload
+            dummy = bool(int(os.environ.get("LOCAL_RANK", "0")) != 0)
+            self.s3_mover = S3Mover(
+                local_path=self.config.checkpoints.checkpoints_path,
+                s3_path=self.config.s3_upload.upload_s3_path,
+                # duplicate_checkpoint_path=self.config.checkpoints.resume_checkpoint_path,
+                remove_after_upload=self.config.s3_upload.remove_after_upload,
+                s5cmd_numworkers=self.config.s3_upload.s5cmd_numworkers,
+                s5cmd_concurrency=self.config.s3_upload.s5cmd_concurrency,
+                s5cmd_path=self.config.s3_upload.s5cmd_path,
+                dummy=dummy,
+            )
+        else:
+            self.s3_mover = None
 
     def pre_training(self, *args, **kwargs):
         self._print_training_plan()
@@ -281,11 +299,15 @@ def pre_training(self, *args, **kwargs):
         )
 
     def post_train_step(self):
-        pass
 
-    def post_training(self):
-        pass
+        # Update our background upload/removal of checkpoints
+        if self.s3_mover is not None:
+            self.s3_mover.update()
 
+    def post_training(self):
+        if self.s3_mover is not None:
+            self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg)
+
     def _print_training_plan(self):
         if hasattr(self.config, "data_stages") and self.config.data_stages is not None:
             stages_info = "".join(
@@ -689,20 +711,21 @@ def _init_model_instance(self) -> NanotronModel:
     def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel:
         unwrapped_model = model.module if isinstance(model, DistributedDataParallel) else model
 
-        # Load or initialize model weights
-        self.init_checkpoint_path = parse_ckpt_path(config=self.config)
+        # Load or initialize model weights
         reloaded_from_checkpoint = False
         if self.init_checkpoint_path is not None:
-            # Reload from a training checkpoint
-            log_rank(f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0)
-            self.param_shard_metadata = load_weights(
-                model=unwrapped_model, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path
-            )
-            reloaded_from_checkpoint = True
+            # Load from a pre-existing checkpoint
+            if check_path_is_local(self.init_checkpoint_path):
+                # Reload from a training checkpoint
+                log_rank(f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0)
+                self.param_shard_metadata = load_weights(
+                    model=unwrapped_model, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path
+                )
+                reloaded_from_checkpoint = True
         if not reloaded_from_checkpoint:
             log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO, rank=0)
             if isinstance(self.config.model.init_method, ExistingCheckpointInit):
-                # Initialize model from an pretrained model checkpoint
+                # Initialize model from a pretrained model checkpoint (without optimizer, lr_scheduler...)
                 self.param_shard_metadata = load_weights(
                     model=unwrapped_model,
                     parallel_context=self.parallel_context,
@@ -830,11 +853,18 @@ def setup_log_writers(
 
         return loggerwriter
 
-    def pre_save_checkpoint(self):
-        pass
+    def pre_save_checkpoint(self) -> Path:
+        if self.s3_mover is not None:
+            self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg)
+            if self.s3_mover.post_upload_callback_outputs is not None:
+                slurm_job_id, slurm_log = self.s3_mover.post_upload_callback_outputs
+                self.log_object({"job_id": slurm_job_id, "log": slurm_log}, "slurm_eval")
 
     def post_save_checkpoint(self):
-        pass
+        # Upload to S3
+        if self.s3_mover is not None:
+            self.s3_mover.start_uploading()
+
 
     def save_checkpoint(self) -> Path:
         self.pre_save_checkpoint()
diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py
index b3831801..cb187f77 100644
--- a/src/nanotron/utils.py
+++ b/src/nanotron/utils.py
@@ -3,6 +3,7 @@
 import os
 import random
 import socket
+import re
 from contextlib import ExitStack, contextmanager
 from typing import ContextManager, List, Optional