fix(parsing): normalize nested column names (#542)
* fix(parsing): normalize nested column names

* apply suggestions from PR review: simplify regexp

Co-authored-by: Helio Machado <[email protected]>

* apply suggestion from PR review: simplify normalization func

Co-authored-by: Helio Machado <[email protected]>

* fix flaky listings test since pyarrow got upgraded

---------

Co-authored-by: Helio Machado <[email protected]>
shcheklein and 0x2b3bfa0 authored Oct 28, 2024
1 parent 0eabe20 commit 7146527
Showing 6 changed files with 157 additions and 21 deletions.
17 changes: 2 additions & 15 deletions src/datachain/lib/arrow.py
@@ -1,4 +1,3 @@
import re
from collections.abc import Sequence
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Any, Optional
@@ -13,6 +12,7 @@
from datachain.lib.model_store import ModelStore
from datachain.lib.signal_schema import SignalSchema
from datachain.lib.udf import Generator
from datachain.lib.utils import normalize_col_names

if TYPE_CHECKING:
from datasets.features.features import Features
@@ -128,7 +128,7 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = None
signal_schema = _get_datachain_schema(schema)
if signal_schema:
return signal_schema.values
columns = _convert_col_names(col_names) # type: ignore[arg-type]
columns = list(normalize_col_names(col_names).keys()) # type: ignore[arg-type]
hf_schema = _get_hf_schema(schema)
if hf_schema:
return {
@@ -143,19 +143,6 @@ def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = None
return output


def _convert_col_names(col_names: Sequence[str]) -> list[str]:
default_column = 0
converted_col_names = []
for column in col_names:
column = column.lower()
column = re.sub("[^0-9a-z_]+", "", column)
if not column:
column = f"c{default_column}"
default_column += 1
converted_col_names.append(column)
return converted_col_names


def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type: # noqa: PLR0911
"""Convert pyarrow types to basic types."""
from datetime import datetime
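For illustration, a minimal sketch of the renaming that schema_to_output now performs via normalize_col_names (mirroring the updated test_parquet_convert_column_names below; not part of the diff):

    import pyarrow as pa

    from datachain.lib.arrow import schema_to_output

    # Names are lowercased, and runs of non-alphanumeric characters collapse
    # to a single underscore, with leading/trailing underscores stripped.
    schema = pa.schema([("UpperCaseCol", pa.int32()), ("with spaces", pa.int32())])
    assert list(schema_to_output(schema)) == ["uppercasecol", "with_spaces"]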
12 changes: 10 additions & 2 deletions src/datachain/lib/data_model.py
@@ -2,9 +2,10 @@
from datetime import datetime
from typing import ClassVar, Union, get_args, get_origin

from pydantic import BaseModel, create_model
from pydantic import BaseModel, Field, create_model

from datachain.lib.model_store import ModelStore
from datachain.lib.utils import normalize_col_names

StandardType = Union[
type[int],
@@ -60,7 +61,14 @@ def is_chain_type(t: type) -> bool:


def dict_to_data_model(name: str, data_dict: dict[str, DataType]) -> type[BaseModel]:
fields = {name: (anno, ...) for name, anno in data_dict.items()}
# Get a map of normalized_name -> original_name
columns = normalize_col_names(list(data_dict.keys()))
# Reverse it for convenience: original_name -> normalized_name
columns = {v: k for k, v in columns.items()}

fields = {
columns[name]: (anno, Field(alias=name)) for name, anno in data_dict.items()
}
return create_model(
name,
__base__=(DataModel,), # type: ignore[call-overload]
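As a sketch of what the Field(alias=...) wiring enables: a model built this way exposes normalized attribute names while still validating rows keyed by the original names. Plain pydantic is shown for brevity (dict_to_data_model additionally sets __base__), with illustrative field names taken from the nested-JSON test below:

    from pydantic import Field, create_model

    # Roughly what dict_to_data_model produces for one nested record:
    # normalized field names, original names kept as validation aliases.
    Name = create_model(
        "Name",
        first_select=(str, Field(alias="first-SELECT")),
        l_as_t=(str, Field(alias="l--as@t")),
    )

    row = Name.model_validate({"first-SELECT": "Alice", "l--as@t": "Smith"})
    assert row.first_select == "Alice" and row.l_as_t == "Smith"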
30 changes: 30 additions & 0 deletions src/datachain/lib/utils.py
@@ -1,4 +1,6 @@
import re
from abc import ABC, abstractmethod
from collections.abc import Sequence


class AbstractUDF(ABC):
@@ -28,3 +30,31 @@ def __init__(self, message):
class DataChainColumnError(DataChainParamsError):
def __init__(self, col_name, msg):
super().__init__(f"Error for column {col_name}: {msg}")


def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
gen_col_counter = 0
new_col_names = {}
org_col_names = set(col_names)

for org_column in col_names:
new_column = org_column.lower()
new_column = re.sub("[^0-9a-z]+", "_", new_column)
new_column = new_column.strip("_")

generated_column = new_column

while (
not generated_column.isidentifier()
or generated_column in new_col_names
or (generated_column != org_column and generated_column in org_col_names)
):
if new_column:
generated_column = f"c{gen_col_counter}_{new_column}"
else:
generated_column = f"c{gen_col_counter}"
gen_col_counter += 1

new_col_names[generated_column] = org_column

return new_col_names
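To illustrate the helper's contract, a small sketch with inputs chosen to mirror the unit tests below: the returned dict maps each normalized name back to its original, and names that would be empty or invalid identifiers get a generated cN prefix.

    from datachain.lib.utils import normalize_col_names

    assert normalize_col_names(["UpperCase", "dot.notation.col", "1start", ""]) == {
        "uppercase": "UpperCase",
        "dot_notation_col": "dot.notation.col",
        "c0_1start": "1start",  # starts with a digit -> "cN_" prefix added
        "c1": "",  # empty after normalization -> generated "cN" name
    }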
14 changes: 11 additions & 3 deletions tests/unit/lib/test_arrow.py
@@ -168,13 +168,21 @@ def test_parquet_convert_column_names():
("dot.notation.col", pa.int32()),
("with-dashes", pa.int32()),
("with spaces", pa.int32()),
("with-multiple--dashes", pa.int32()),
("with__underscores", pa.int32()),
("__leading__underscores", pa.int32()),
("trailing__underscores__", pa.int32()),
]
)
assert list(schema_to_output(schema)) == [
"uppercasecol",
"dotnotationcol",
"withdashes",
"withspaces",
"dot_notation_col",
"with_dashes",
"with_spaces",
"with_multiple_dashes",
"with_underscores",
"leading_underscores",
"trailing_underscores",
]


35 changes: 34 additions & 1 deletion tests/unit/lib/test_datachain.py
@@ -36,6 +36,18 @@
"city": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"],
}

DF_DATA_NESTED_NOT_NORMALIZED = {
"nAmE": [
{"first-SELECT": "Alice", "l--as@t": "Smith"},
{"l--as@t": "Jones", "first-SELECT": "Bob"},
{"first-SELECT": "Charlie", "l--as@t": "Brown"},
{"first-SELECT": "David", "l--as@t": "White"},
{"first-SELECT": "Eva", "l--as@t": "Black"},
],
"AgE": [25, 30, 35, 40, 45],
"citY": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"],
}

DF_OTHER_DATA = {
"last_name": ["Smith", "Jones"],
"country": ["USA", "Russia"],
@@ -272,7 +284,9 @@ def test_listings(test_session, tmp_dir):
assert listing.expires
assert listing.version == 1
assert listing.num_objects == 1
assert listing.size == 2912
# The exact number is unreliable here since it depends on the PyArrow version
assert listing.size > 1000
assert listing.size < 5000
assert listing.status == 4


@@ -988,6 +1002,25 @@ def test_parse_tabular_format(tmp_dir, test_session):
assert df1.equals(df)


def test_parse_nested_json(tmp_dir, test_session):
df = pd.DataFrame(DF_DATA_NESTED_NOT_NORMALIZED)
path = tmp_dir / "test.jsonl"
path.write_text(df.to_json(orient="records", lines=True))
dc = DataChain.from_storage(path.as_uri(), session=test_session).parse_tabular(
format="json"
)
# Field names are normalized, values are preserved,
# e.g. nAmE -> name, l--as@t -> l_as_t, etc.
df1 = dc.select("name", "age", "city").to_pandas()

assert df1["name"]["first_select"].to_list() == [
d["first-SELECT"] for d in df["nAmE"].to_list()
]
assert df1["name"]["l_as_t"].to_list() == [
d["l--as@t"] for d in df["nAmE"].to_list()
]


def test_parse_tabular_partitions(tmp_dir, test_session):
df = pd.DataFrame(DF_DATA)
path = tmp_dir / "test.parquet"
70 changes: 70 additions & 0 deletions tests/unit/lib/test_utils.py
@@ -5,6 +5,7 @@
from pydantic import BaseModel

from datachain.lib.convert.python_to_sql import python_to_sql
from datachain.lib.utils import normalize_col_names
from datachain.sql.types import JSON, Array, String


@@ -56,3 +57,72 @@ def test_convert_type_to_datachain_array(typ, expected):
def test_convert_type_to_datachain_error(typ):
with pytest.raises(TypeError):
python_to_sql(typ)


def test_normalize_column_names():
res = normalize_col_names(
[
"UpperCase",
"_underscore_start",
"double__underscore",
"1start_with_number",
"не_ascii_start",
" space_start",
"space_end ",
"dash-end-",
"-dash-start",
"--multiple--dash--",
"-_ mix_ -dash_ -",
"__2digit_after_uderscore",
"",
"_-_- _---_ _",
"_-_- _---_ _1",
]
)
assert list(res.keys()) == [
"uppercase",
"underscore_start",
"double_underscore",
"c0_1start_with_number",
"ascii_start",
"space_start",
"space_end",
"dash_end",
"dash_start",
"multiple_dash",
"mix_dash",
"c1_2digit_after_uderscore",
"c2",
"c3",
"c4_1",
]


def test_normalize_column_names_case_repeat():
res = normalize_col_names(["UpperCase", "UpPerCase"])

assert list(res.keys()) == ["uppercase", "c0_uppercase"]


def test_normalize_column_names_exists_after_normalize():
res = normalize_col_names(["1digit", "c0_1digit"])

assert list(res.keys()) == ["c1_1digit", "c0_1digit"]


def test_normalize_column_names_normalized_repeat():
res = normalize_col_names(["column", "_column"])

assert list(res.keys()) == ["column", "c0_column"]


def test_normalize_column_names_normalized_case_repeat():
res = normalize_col_names(["CoLuMn", "_column"])

assert res == {"column": "CoLuMn", "c0_column": "_column"}


def test_normalize_column_names_repeat_generated_after_normalize():
res = normalize_col_names(["c0_CoLuMn", "_column", "column"])

assert res == {"c0_column": "c0_CoLuMn", "c1_column": "_column", "column": "column"}
