Enable GPU/TPU training with XManager.
PiperOrigin-RevId: 549429009
Uncertainty Baselines Team authored and copybara-github committed Sep 5, 2023
1 parent f5f6f50 commit 3c78dd6
Showing 3 changed files with 50 additions and 56 deletions.
experimental/shoshin/data.py (21 changes: 8 additions & 13 deletions)
@@ -163,19 +163,14 @@ def filter_subgroup(x, label=subgroup_label):

 def apply_batch(dataloader, batch_size):
   """Apply batching to dataloader."""
-  dataloader.train_splits = [
-      data.batch(batch_size) for data in dataloader.train_splits
-  ]
-  dataloader.val_splits = [
-      data.batch(batch_size) for data in dataloader.val_splits
-  ]
-  num_splits = len(dataloader.train_splits)
-  train_ds = gather_data_splits(
-      list(range(num_splits)), dataloader.train_splits)
-  val_ds = gather_data_splits(list(range(num_splits)), dataloader.val_splits)
-  dataloader.train_ds = train_ds
-  dataloader.eval_ds['val'] = val_ds
-  for (k, v) in dataloader.eval_ds.items():
+  # TODO(jihyeonlee): support drop remainder in data split and apply here
+  dataloader.train_ds = dataloader.train_ds.batch(
+      batch_size, drop_remainder=True
+  )
+  dataloader.eval_ds['val'] = dataloader.eval_ds['val'].batch(
+      batch_size, drop_remainder=True
+  )
+  for k, v in dataloader.eval_ds.items():
     if k != 'val':
       dataloader.eval_ds[k] = v.batch(batch_size)
   return dataloader
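Note on the change above: batching now happens on the already-gathered train_ds and eval_ds with drop_remainder=True, so every training batch has the same fixed size; a static batch dimension is what TPU (XLA) compilation generally expects. A minimal, self-contained sketch of what drop_remainder does (the ten-element dataset is hypothetical, not from this repo):

import tensorflow as tf

ds = tf.data.Dataset.range(10)

# drop_remainder=True keeps only full batches: [0,1,2,3], [4,5,6,7].
print([b.numpy().tolist() for b in ds.batch(4, drop_remainder=True)])

# The default also keeps the short final batch: ..., [8, 9].
print([b.numpy().tolist() for b in ds.batch(4)])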
experimental/shoshin/train_tf.py (28 changes: 15 additions & 13 deletions)
@@ -49,6 +49,7 @@
 def main(_) -> None:
   config = FLAGS.config
   base_config.check_flags(config)
+  strategy = train_tf_lib.create_strategy(tpu_bns=FLAGS.tpu)
 
   if FLAGS.keep_logs and not config.training.log_to_xm:
     if not tf.io.gfile.exists(config.output_dir):
@@ -192,19 +193,20 @@ def main(_) -> None:
   # Apply batching (must apply batching only after filtering)
   dataloader = data.apply_batch(dataloader, config.data.batch_size)
 
-  _ = train_tf_lib.train_and_evaluate(
-      train_as_ensemble=config.train_stage_2_as_ensemble,
-      dataloader=dataloader,
-      model_params=model_params,
-      num_splits=config.data.num_splits,
-      ood_ratio=config.data.ood_ratio,
-      output_dir=output_dir,
-      experiment_name=experiment_name,
-      save_model_checkpoints=config.training.save_model_checkpoints,
-      save_best_model=config.training.save_best_model,
-      early_stopping=config.training.early_stopping,
-      ensemble_dir=FLAGS.ensemble_dir,
-      example_id_to_bias_table=example_id_to_bias_table)
+  with strategy.scope():
+    _ = train_tf_lib.train_and_evaluate(
+        train_as_ensemble=config.train_stage_2_as_ensemble,
+        dataloader=dataloader,
+        model_params=model_params,
+        num_splits=config.data.num_splits,
+        ood_ratio=config.data.ood_ratio,
+        output_dir=output_dir,
+        experiment_name=experiment_name,
+        save_model_checkpoints=config.training.save_model_checkpoints,
+        save_best_model=config.training.save_best_model,
+        early_stopping=config.training.early_stopping,
+        ensemble_dir=FLAGS.ensemble_dir,
+        example_id_to_bias_table=example_id_to_bias_table)
 
 
 if __name__ == '__main__':
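The shape of this change (build the strategy once at the top of main() and enter strategy.scope() around training) matches the tf.distribute contract that variables such as model weights and optimizer slots must be created under the scope of the strategy that will run them. A minimal sketch of that pattern with a placeholder Keras model, not the shoshin model itself:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # all visible GPUs, else CPU

with strategy.scope():
  # Variable creation (layers, optimizer) must happen inside the scope.
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.compile(optimizer='adam', loss='mse')

# Training itself may run outside the scope; fit() handles distribution.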
experimental/shoshin/train_tf_lib.py (57 changes: 27 additions & 30 deletions)
@@ -352,20 +352,19 @@ def evaluate_model(
   logging.info(results)
 
 
-def _create_strategy(
+def create_strategy(
     tpu_bns: Optional[str] = '') -> tf.distribute.Strategy:
-  """Creates distribution strategy used in training.
+  """Creates distribution strategy based on machine available.
 
   Args:
-    tpu_bns: The bns address of the first TPU worker.
+    tpu_bns: A string of the Headless TPU Worker's BNS address.
 
   Returns:
     tf.distribute.Strategy
   """
-  use_tpu = False
-  if tpu_bns:  # Use tpu if tpu_bns is specified
-    use_tpu = True
-  else:  # Use tpu if tpu is available
+  # Set use_tpu as true if tpu bns is set or tpu is available.
+  use_tpu = bool(tpu_bns)
+  if not use_tpu:
     visible_devices = tf.config.get_visible_devices()
     for device in visible_devices:
       if device.device_type == 'TPU':
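The hunk above cuts off just after the TPU-detection loop, so the branch that actually constructs the strategy is not shown in this diff. As a hedged reconstruction of the function's likely overall shape (the TPUClusterResolver connection and initialization calls are assumptions, not part of this commit):

from typing import Optional

import tensorflow as tf

def create_strategy_sketch(
    tpu_bns: Optional[str] = '') -> tf.distribute.Strategy:
  """Sketch: TPUStrategy if a TPU is named or visible, else MirroredStrategy."""
  use_tpu = bool(tpu_bns)
  if not use_tpu:
    # Same detection idea as the diff: look for a visible TPU device.
    use_tpu = any(
        d.device_type == 'TPU' for d in tf.config.get_visible_devices())
  if use_tpu:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_bns)
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)
  return tf.distribute.MirroredStrategy()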
@@ -402,29 +401,27 @@ def init_model(
   Returns:
     Initialized TwoHeadedOutputModel model.
   """
-  strategy = _create_strategy(tpu_bns=model_params.tpu_bns)
-  with strategy.scope():
-    model_class = models.get_model(model_params.model_name)
-    base_model = model_class(model_params=model_params)
-
-    two_head_model = TwoHeadedOutputModel(
-        model=base_model,
-        num_subgroups=model_params.num_subgroups,
-        subgroup_sizes=model_params.subgroup_sizes,
-        worst_group_label=model_params.worst_group_label,
-        train_bias=model_params.train_bias,
-        name=experiment_name,
-        do_reweighting=model_params.do_reweighting,
-        reweighting_signal=model_params.reweighting_signal,
-        reweighting_lambda=model_params.reweighting_lambda,
-        error_percentile_threshold=model_params
-        .reweighting_error_percentile_threshold)
-
-    if model_params.train_bias or model_params.do_reweighting:
-      if example_id_to_bias_table:
-        two_head_model.update_id_to_bias_table(example_id_to_bias_table)
-
-    two_head_model = compile_model(two_head_model, model_params)
+  model_class = models.get_model(model_params.model_name)
+  base_model = model_class(model_params=model_params)
+
+  two_head_model = TwoHeadedOutputModel(
+      model=base_model,
+      num_subgroups=model_params.num_subgroups,
+      subgroup_sizes=model_params.subgroup_sizes,
+      worst_group_label=model_params.worst_group_label,
+      train_bias=model_params.train_bias,
+      name=experiment_name,
+      do_reweighting=model_params.do_reweighting,
+      reweighting_signal=model_params.reweighting_signal,
+      reweighting_lambda=model_params.reweighting_lambda,
+      error_percentile_threshold=model_params
+      .reweighting_error_percentile_threshold)
+
+  if model_params.train_bias or model_params.do_reweighting:
+    if example_id_to_bias_table:
+      two_head_model.update_id_to_bias_table(example_id_to_bias_table)
+
+  two_head_model = compile_model(two_head_model, model_params)
   return two_head_model
 
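With the scope removed from init_model, device placement is now the caller's responsibility: the strategy created in train_tf.py's main() wraps train_and_evaluate, which in turn calls init_model. A hypothetical direct call site under the new contract (model_params is assumed to be defined; the full init_model signature is not shown in this diff):

strategy = create_strategy(tpu_bns='')  # empty BNS falls back to GPU/CPU
with strategy.scope():
  two_head_model = init_model(
      model_params=model_params,
      experiment_name='my_experiment',
      example_id_to_bias_table=None)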
