Skip to content

Commit

Permalink
♻️ Use YAML tags instead of unsafe loader
Browse files Browse the repository at this point in the history
  • Loading branch information
shnizzedy committed Sep 5, 2024
1 parent 7391745 commit 9c34a1e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 18 deletions.
53 changes: 38 additions & 15 deletions CPAC/utils/bids_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import os
import re
import sys
from typing import Any, Callable
from typing import Any, Callable, Optional
from warnings import warn

from botocore.exceptions import BotoCoreError
Expand Down Expand Up @@ -1002,15 +1002,33 @@ def insert_entity(resource, key, value):
return "_".join([*new_entities[0], f"{key}-{value}", *new_entities[1], suff])


def load_yaml_config(config_filename: str, aws_input_creds, safe_load: bool = False):
def apply_modifications(
yaml_contents: str, modifications: Optional[list[Callable[[str], str]]]
) -> str:
"""Apply modification functions to YAML contents"""
if modifications:
for modification in modifications:
yaml_contents = modification(yaml_contents)
return yaml_contents


def load_yaml_config(
config_filename: str,
aws_input_creds,
modifications: Optional[list[Callable[[str], str]]] = None,
) -> dict | list | str:
"""Load a YAML config file, possibly from AWS, with modifications applied.
`modifications` should be a list of functions that take a single string argument (the loaded YAML contents) and return a single string argument (the modified YAML contents).
"""
if config_filename.lower().startswith("data:"):
try:
header, encoded = config_filename.split(",", 1)
config_content = b64decode(encoded)
if safe_load:
return yaml.safe_load(config_content)
return yaml.load(config_content, Loader=yaml.loader.BaseLoader)
except:
_header, encoded = config_filename.split(",", 1)
config_content = apply_modifications(
b64decode(encoded).decode("utf-8"), modifications
)
return yaml.safe_load(config_content)
except Exception:
msg = f"Error! Could not find load config from data URI {config_filename}"
raise SpecifiedBotoCoreError(msg=msg)

Expand All @@ -1034,9 +1052,8 @@ def load_yaml_config(config_filename: str, aws_input_creds, safe_load: bool = Fa
config_filename = os.path.realpath(config_filename)

try:
if safe_load:
return yaml.safe_load(open(config_filename, "r"))
return yaml.load(open(config_filename, "r"), Loader=yaml.loader.BaseLoader)
with open(config_filename, "r") as _f:
return yaml.safe_load(apply_modifications(_f.read(), modifications))
except IOError:
msg = f"Error! Could not find config file {config_filename}"
raise FileNotFoundError(msg)
Expand Down Expand Up @@ -1145,6 +1162,13 @@ def _check_value_type(
)


def coerce_data_config_strings(contents: str) -> str:
"""Coerge `subject_id` and `unique_id` to be strings."""
for key in ["subject_id: ", "unique_id: "]:
contents = contents.replace(key, f"{key}!!str ")
return contents


def load_cpac_data_config(data_config_file, participant_labels, aws_input_creds):
"""
Loads the file as a check to make sure it is available and readable.
Expand All @@ -1162,10 +1186,9 @@ def load_cpac_data_config(data_config_file, participant_labels, aws_input_creds)
-------
list
"""
sub_list: list[dict[str, str]] = load_yaml_config(data_config_file, aws_input_creds)

if _check_value_type(sub_list, ["subject_id", "unique_id"], int, any):
sub_list = load_yaml_config(data_config_file, aws_input_creds, safe_load=False)
sub_list: list[dict[str, str]] = load_yaml_config(
data_config_file, aws_input_creds, modifications=[coerce_data_config_strings]
)

if participant_labels:
sub_list = [
Expand Down
4 changes: 1 addition & 3 deletions CPAC/utils/tests/test_bids_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,7 @@ def test_load_data_config_with_ints() -> None:
"utils/tests/configs/github_2144.yml"
)
# make sure there are ints in the test data
assert _check_value_type(
load_yaml_config(str(data_config_file), None, safe_load=True)
)
assert _check_value_type(load_yaml_config(str(data_config_file), None))
# make sure there aren't ints when it's loaded through the loader
assert not _check_value_type(
load_cpac_data_config(str(data_config_file), None, None)
Expand Down

0 comments on commit 9c34a1e

Please sign in to comment.