-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refacto generate #199
Refacto generate #199
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR :) Some errors, also ParallelContext is missing self.is_pipeline_last_stage
🤔
(/fsx/phuc/temp/data_stages_env/env) phuc_nguyen@ip-26-0-174-186:/fsx/phuc/temp/data_stages_env/nanotron$ torchrun --nproc_per_node=1 run_generate.py --ckpt-path checkpoints/10/
Traceback (most recent call last):
File "/fsx/phuc/temp/data_stages_env/nanotron/run_generate.py", line 313, in <module>
main()
File "/fsx/phuc/temp/data_stages_env/nanotron/run_generate.py", line 72, in main
config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix())
File "/fsx/phuc/temp/data_stages_env/nanotron/src/nanotron/config/config.py", line 462, in get_config_from_file
config = get_config_from_dict(
File "/fsx/phuc/temp/data_stages_env/nanotron/src/nanotron/config/config.py", line 423, in get_config_from_dict
return from_dict(
File "/fsx/phuc/temp/data_stages_env/env/lib/python3.10/site-packages/dacite/core.py", line 64, in from_dict
value = _build_value(type_=field_type, data=field_data, config=config)
File "/fsx/phuc/temp/data_stages_env/env/lib/python3.10/site-packages/dacite/core.py", line 95, in _build_value
data = _build_value_for_union(union=type_, data=data, config=config)
File "/fsx/phuc/temp/data_stages_env/env/lib/python3.10/site-packages/dacite/core.py", line 113, in _build_value_for_union
return _build_value(type_=types[0], data=data, config=config)
File "/fsx/phuc/temp/data_stages_env/env/lib/python3.10/site-packages/dacite/core.py", line 99, in _build_value
data = from_dict(data_class=type_, data=data, config=config)
File "/fsx/phuc/temp/data_stages_env/env/lib/python3.10/site-packages/dacite/core.py", line 58, in from_dict
raise UnexpectedDataError(keys=extra_fields)
dacite.exceptions.UnexpectedDataError: can not match "monitor_model_states" to any data class field
[2024-06-18 13:09:41,102] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 571644) of binary: /fsx/phuc/temp/data_stages_env/env/bin/python
key_states = self.rotary_embedding(key_states, position_ids=position_ids) | ||
|
||
# Rotate half rotary_embedding | ||
# cos, sin = self.rotary_embedding(value_states, position_ids) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove comments
level=logging.INFO, | ||
rank=0, | ||
# Sample new token | ||
if parallel_context.is_pipeline_last_stage: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't find parallel_context.is_pipeline_last_stage
?
Closing for a new one #202 |
It doesn't support micro-batches yet but I don't think it is necessary to have it for our use case (which is just sanity checking). Plus this make the code easier to understand
TODO: clean rotary for inference