fix(parsing): normalize nested column names #542

Merged · 4 commits · Oct 28, 2024

17 changes: 2 additions & 15 deletions src/datachain/lib/arrow.py
@@ -1,4 +1,3 @@
-import re
 from collections.abc import Sequence
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Any, Optional
@@ -13,6 +12,7 @@
 from datachain.lib.model_store import ModelStore
 from datachain.lib.signal_schema import SignalSchema
 from datachain.lib.udf import Generator
+from datachain.lib.utils import normalize_col_names

 if TYPE_CHECKING:
     from datasets.features.features import Features
@@ -128,7 +128,7 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = None
     signal_schema = _get_datachain_schema(schema)
     if signal_schema:
         return signal_schema.values
-    columns = _convert_col_names(col_names)  # type: ignore[arg-type]
+    columns = list(normalize_col_names(col_names).keys())  # type: ignore[arg-type]
     hf_schema = _get_hf_schema(schema)
     if hf_schema:
         return {
@@ -143,19 +143,6 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = None
     return output


-def _convert_col_names(col_names: Sequence[str]) -> list[str]:
-    default_column = 0
-    converted_col_names = []
-    for column in col_names:
-        column = column.lower()
-        column = re.sub("[^0-9a-z_]+", "", column)
-        if not column:
-            column = f"c{default_column}"
-            default_column += 1
-        converted_col_names.append(column)
-    return converted_col_names
-
-
 def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type:  # noqa: PLR0911
     """Convert pyarrow types to basic types."""
     from datetime import datetime

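For context on the behavioral change here: the removed `_convert_col_names` deleted disallowed characters outright, fusing words together, while the new `normalize_col_names` replaces runs of them with a single underscore and resolves name collisions. A minimal sketch of the difference, where `old_convert` is a hypothetical stand-in for the deleted helper and the expected values come from the updated tests below:

```python
import re

from datachain.lib.utils import normalize_col_names


def old_convert(column: str) -> str:
    # Hypothetical stand-in for the removed _convert_col_names logic:
    # lowercase, then strip every character outside [0-9a-z_].
    return re.sub("[^0-9a-z_]+", "", column.lower())


assert old_convert("dot.notation.col") == "dotnotationcol"  # words fused together
assert list(normalize_col_names(["dot.notation.col"])) == ["dot_notation_col"]
```
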
12 changes: 10 additions & 2 deletions src/datachain/lib/data_model.py
@@ -2,9 +2,10 @@
 from datetime import datetime
 from typing import ClassVar, Union, get_args, get_origin

-from pydantic import BaseModel, create_model
+from pydantic import BaseModel, Field, create_model

 from datachain.lib.model_store import ModelStore
+from datachain.lib.utils import normalize_col_names

 StandardType = Union[
     type[int],
@@ -60,7 +61,14 @@ def is_chain_type(t: type) -> bool:


 def dict_to_data_model(name: str, data_dict: dict[str, DataType]) -> type[BaseModel]:
-    fields = {name: (anno, ...) for name, anno in data_dict.items()}
+    # Get a map of normalized_name -> original_name...
+    columns = normalize_col_names(list(data_dict.keys()))
+    # ...and reverse it, for convenience, to original_name -> normalized_name.
+    columns = {v: k for k, v in columns.items()}
+
+    fields = {
+        columns[name]: (anno, Field(alias=name)) for name, anno in data_dict.items()
+    }
     return create_model(
         name,
         __base__=(DataModel,),  # type: ignore[call-overload]

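The `Field(alias=...)` wiring is what lets the generated model keep accepting raw records keyed by the original, non-normalized column names. A minimal sketch of that pydantic mechanism, using a hypothetical `Row` model rather than the model `dict_to_data_model` actually builds:

```python
from pydantic import BaseModel, Field


class Row(BaseModel):
    # Normalized field name; the original column name becomes the alias.
    first_select: str = Field(alias="first-SELECT")


# Raw data keyed by the original name validates via the alias...
row = Row.model_validate({"first-SELECT": "Alice"})
# ...while code reads the normalized attribute.
assert row.first_select == "Alice"
```
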
30 changes: 30 additions & 0 deletions src/datachain/lib/utils.py
@@ -1,4 +1,6 @@
+import re
 from abc import ABC, abstractmethod
+from collections.abc import Sequence


 class AbstractUDF(ABC):
@@ -28,3 +30,31 @@ def __init__(self, message):
 class DataChainColumnError(DataChainParamsError):
     def __init__(self, col_name, msg):
         super().__init__(f"Error for column {col_name}: {msg}")
+
+
+def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
+    gen_col_counter = 0
+    new_col_names = {}
+    org_col_names = set(col_names)
+
+    for org_column in col_names:
+        new_column = org_column.lower()
+        new_column = re.sub("[^0-9a-z]+", "_", new_column)
+        new_column = new_column.strip("_")
+
+        generated_column = new_column
+
+        while (
+            not generated_column.isidentifier()
+            or generated_column in new_col_names
+            or (generated_column != org_column and generated_column in org_col_names)
+        ):
+            if new_column:
+                generated_column = f"c{gen_col_counter}_{new_column}"
+            else:
+                generated_column = f"c{gen_col_counter}"
+            gen_col_counter += 1
+
+        new_col_names[generated_column] = org_column
+
+    return new_col_names
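Since `normalize_col_names` returns a mapping rather than a list, the direction is worth spelling out: keys are the normalized names, values are the originals. A small usage sketch with made-up inputs; the outputs follow from the function above:

```python
from datachain.lib.utils import normalize_col_names

mapping = normalize_col_names(["UpperCase", "_column", "column", ""])
assert mapping == {
    "uppercase": "UpperCase",  # lowercased, already a valid identifier
    "c0_column": "_column",    # strips to "column", which an input already claims
    "column": "column",        # identical names keep their slot
    "c1": "",                  # empty after normalization -> generated name
}
```
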
14 changes: 11 additions & 3 deletions tests/unit/lib/test_arrow.py
@@ -168,13 +168,21 @@ def test_parquet_convert_column_names():
             ("dot.notation.col", pa.int32()),
             ("with-dashes", pa.int32()),
             ("with spaces", pa.int32()),
+            ("with-multiple--dashes", pa.int32()),
+            ("with__underscores", pa.int32()),
+            ("__leading__underscores", pa.int32()),
+            ("trailing__underscores__", pa.int32()),
         ]
     )
     assert list(schema_to_output(schema)) == [
         "uppercasecol",
-        "dotnotationcol",
-        "withdashes",
-        "withspaces",
+        "dot_notation_col",
+        "with_dashes",
+        "with_spaces",
+        "with_multiple_dashes",
+        "with_underscores",
+        "leading_underscores",
+        "trailing_underscores",
     ]

35 changes: 34 additions & 1 deletion tests/unit/lib/test_datachain.py
@@ -36,6 +36,18 @@
     "city": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"],
 }

+DF_DATA_NESTED_NOT_NORMALIZED = {
+    "nAmE": [
+        {"first-SELECT": "Alice", "l--as@t": "Smith"},
+        {"l--as@t": "Jones", "first-SELECT": "Bob"},
+        {"first-SELECT": "Charlie", "l--as@t": "Brown"},
+        {"first-SELECT": "David", "l--as@t": "White"},
+        {"first-SELECT": "Eva", "l--as@t": "Black"},
+    ],
+    "AgE": [25, 30, 35, 40, 45],
+    "citY": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"],
+}
+
 DF_OTHER_DATA = {
     "last_name": ["Smith", "Jones"],
     "country": ["USA", "Russia"],
@@ -272,7 +284,9 @@ def test_listings(test_session, tmp_dir):
     assert listing.expires
     assert listing.version == 1
    assert listing.num_objects == 1
-    assert listing.size == 2912
+    # The exact number is unreliable here since it depends on the PyArrow version
+    assert listing.size > 1000
+    assert listing.size < 5000
     assert listing.status == 4

@@ -988,6 +1002,25 @@ def test_parse_tabular_format(tmp_dir, test_session):
     assert df1.equals(df)


+def test_parse_nested_json(tmp_dir, test_session):
+    df = pd.DataFrame(DF_DATA_NESTED_NOT_NORMALIZED)
+    path = tmp_dir / "test.jsonl"
+    path.write_text(df.to_json(orient="records", lines=True))
+    dc = DataChain.from_storage(path.as_uri(), session=test_session).parse_tabular(
+        format="json"
+    )
+    # Field names are normalized, values are preserved,
+    # e.g. nAmE -> name, l--as@t -> l_as_t, etc.
+    df1 = dc.select("name", "age", "city").to_pandas()
+
+    assert df1["name"]["first_select"].to_list() == [
+        d["first-SELECT"] for d in df["nAmE"].to_list()
+    ]
+    assert df1["name"]["l_as_t"].to_list() == [
+        d["l--as@t"] for d in df["nAmE"].to_list()
+    ]
+
+
 def test_parse_tabular_partitions(tmp_dir, test_session):
     df = pd.DataFrame(DF_DATA)
     path = tmp_dir / "test.parquet"

70 changes: 70 additions & 0 deletions tests/unit/lib/test_utils.py
@@ -5,6 +5,7 @@
 from pydantic import BaseModel

 from datachain.lib.convert.python_to_sql import python_to_sql
+from datachain.lib.utils import normalize_col_names
 from datachain.sql.types import JSON, Array, String


@@ -56,3 +57,72 @@ def test_convert_type_to_datachain_array(typ, expected):
 def test_convert_type_to_datachain_error(typ):
     with pytest.raises(TypeError):
         python_to_sql(typ)
+
+
+def test_normalize_column_names():
+    res = normalize_col_names(
+        [
+            "UpperCase",
+            "_underscore_start",
+            "double__underscore",
+            "1start_with_number",
+            "не_ascii_start",
+            " space_start",
+            "space_end ",
+            "dash-end-",
+            "-dash-start",
+            "--multiple--dash--",
+            "-_ mix_ -dash_ -",
+            "__2digit_after_underscore",
+            "",
+            "_-_- _---_ _",
+            "_-_- _---_ _1",
+        ]
+    )
+    assert list(res.keys()) == [
+        "uppercase",
+        "underscore_start",
+        "double_underscore",
+        "c0_1start_with_number",
+        "ascii_start",
+        "space_start",
+        "space_end",
+        "dash_end",
+        "dash_start",
+        "multiple_dash",
+        "mix_dash",
+        "c1_2digit_after_underscore",
+        "c2",
+        "c3",
+        "c4_1",
+    ]
+
+
+def test_normalize_column_names_case_repeat():
+    res = normalize_col_names(["UpperCase", "UpPerCase"])
+
+    assert list(res.keys()) == ["uppercase", "c0_uppercase"]
+
+
+def test_normalize_column_names_exists_after_normalize():
+    res = normalize_col_names(["1digit", "c0_1digit"])
+
+    assert list(res.keys()) == ["c1_1digit", "c0_1digit"]
+
+
+def test_normalize_column_names_normalized_repeat():
+    res = normalize_col_names(["column", "_column"])
+
+    assert list(res.keys()) == ["column", "c0_column"]
+
+
+def test_normalize_column_names_normalized_case_repeat():
+    res = normalize_col_names(["CoLuMn", "_column"])
+
+    assert res == {"column": "CoLuMn", "c0_column": "_column"}
+
+
+def test_normalize_column_names_repeat_generated_after_normalize():
+    res = normalize_col_names(["c0_CoLuMn", "_column", "column"])
+
+    assert res == {"c0_column": "c0_CoLuMn", "c1_column": "_column", "column": "column"}