handle nan and inf float values #249

Merged 5 commits on Aug 9, 2024
12 changes: 10 additions & 2 deletions src/datachain/lib/dc.py
@@ -1396,7 +1396,8 @@ def from_csv(
         dc = DataChain.from_csv("s3://mybucket/dir")
         ```
         """
-        from pyarrow.csv import ParseOptions, ReadOptions
+        from pandas.io.parsers.readers import STR_NA_VALUES
+        from pyarrow.csv import ConvertOptions, ParseOptions, ReadOptions
         from pyarrow.dataset import CsvFileFormat

         chain = DataChain.from_storage(path, **kwargs)
@@ -1420,7 +1421,14 @@ def from_csv(

         parse_options = ParseOptions(delimiter=delimiter)
         read_options = ReadOptions(column_names=column_names)
-        format = CsvFileFormat(parse_options=parse_options, read_options=read_options)
+        convert_options = ConvertOptions(
+            strings_can_be_null=True, null_values=STR_NA_VALUES
+        )

Comment on lines +1424 to +1426 (Collaborator Author): This provides pandas-like null handling for csv files

+        format = CsvFileFormat(
+            parse_options=parse_options,
+            read_options=read_options,
+            convert_options=convert_options,
+        )
         return chain.parse_tabular(
             output=output,
             object_name=object_name,
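Note: the effect of the new `ConvertOptions` is pandas-style null detection at the pyarrow level. A minimal standalone sketch (the in-memory buffer and variable names are illustrative, not from the PR):

```python
import io

import pyarrow.csv as pc
from pandas.io.parsers.readers import STR_NA_VALUES

# STR_NA_VALUES is the set of tokens pandas treats as missing
# ("", "NA", "N/A", "NULL", "nan", ...). Passing it to pyarrow's CSV
# reader makes those cells come back as nulls instead of literal strings.
data = io.BytesIO(b"name,height\nalice,70\nbob,NA\ncarol,\n")
table = pc.read_csv(
    data,
    convert_options=pc.ConvertOptions(
        strings_can_be_null=True, null_values=list(STR_NA_VALUES)
    ),
)
print(table.column("height"))  # 70, null, null
```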
14 changes: 7 additions & 7 deletions src/datachain/sql/types.py
@@ -174,7 +174,7 @@ def on_read_convert(self, value, dialect):


 class Float(SQLType):
-    impl = types.INTEGER
+    impl = types.FLOAT

     @property
     def python_type(self):
@@ -328,13 +328,13 @@ def uint64(self, value):
         return value

     def float(self, value):
-        return value
+        return float("nan") if value is None else value

     def float32(self, value):
-        return value
+        return self.float(value)

     def float64(self, value):
-        return value
+        return self.float(value)

     def array(self, value, item_type, dialect):
         if value is None or item_type is None:
@@ -415,13 +415,13 @@ def uint64(self):
         return None

     def float(self):
-        return None
+        return float("nan")

     def float32(self):
-        return None
+        return self.float()

     def float64(self):
-        return None
+        return self.float()

     def array(self):
         return None
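Note: the converter and default-value changes above mean SQL NULLs in float columns surface as `nan` rather than `None`, so the Python-side value is always a float. A sketch of the read-side behavior (the helper name `float_from_db` is hypothetical):

```python
import math

def float_from_db(value):
    # Mirrors the Float read conversion above: NULL (None) becomes nan,
    # real numbers pass through unchanged.
    return float("nan") if value is None else value

assert math.isnan(float_from_db(None))
assert float_from_db(1.5) == 1.5
```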
49 changes: 12 additions & 37 deletions tests/func/test_dataset_query.py
@@ -2573,35 +2573,14 @@ def generate_subobjects(self, path):
     q = DatasetQuery(name="dogs_with_rows_and_signals", catalog=catalog)
     result = q.to_db_records()

-    col_values = [
-        (
-            r["path"],
-            r["int_col"],
-            r["int_col_32"],
-            r["int_col_64"],
-            r["bool_col"],
-            r["float_col"],
-            r["float_col_32"],
-            r["float_col_64"],
-            r["json_col"],
-            r["datetime_col"],
-            r["binary_col"],
-            r["array_col"],
-            r["array_col_nested"],
-            r["array_col_32"],
-            r["array_col_64"],
-        )
-        for r in result
-    ]
-
-    col_values.sort(key=lambda x: x[0])
-
-    assert col_values == [
-        ("dogs/dog1/subobject", *new_col_values_empty),
-        ("dogs/dog2/subobject", *new_col_values_empty),
-        ("dogs/dog3/subobject", *new_col_values_empty),
-        ("dogs/others/dog4/subobject", *new_col_values_empty),
-    ]
+    for row in result:
+        for i, col in enumerate(new_columns):
+            val = row[col]
+            expected = new_col_values_empty[i]
+            if isinstance(expected, float) and math.isnan(expected):
+                assert math.isnan(val)
+            else:
+                assert val == expected


 @pytest.mark.parametrize(
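Note: the per-column loop replaces the old whole-row comparison because `nan` defeats plain equality. A quick illustration:

```python
import math

# Two distinct nan values never compare equal, so tuples built from
# database rows cannot be checked with == against expected nan defaults.
assert float("nan") != float("nan")
assert (1, float("nan")) != (1, float("nan"))

# math.isnan is the reliable check the rewritten loop uses instead.
assert math.isnan(float("nan"))
```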
@@ -3087,7 +3066,6 @@ def test_group_by(cloud_test_catalog, cloud_type, dogs_dataset):
 @pytest.mark.parametrize("tree", [WEBFORMAT_TREE], indirect=True)
 def test_json_loader(cloud_test_catalog):
     catalog = cloud_test_catalog.catalog
-    dialect = catalog.warehouse.db.dialect

     @udf(
         params=(C.name,),
@@ -3116,16 +3094,12 @@ def attach_json(rows):
             if json_data and ext != ".json":
                 signals.append([json_data.get(k) for k in json_output])
             else:
-                signals.append(None)
+                signals.append([None, None])

         return signals

-    col_default_values = tuple(t.default_value(dialect) for t in json_output.values())
-
     expected = [
-        ("f1.json", col_default_values[0], col_default_values[1]),
         ("f1.raw", 0.001, "deadbeef"),
-        ("f2.json", col_default_values[0], col_default_values[1]),
         ("f2.raw", 0.005, "foobar"),
     ]
@@ -3134,12 +3108,13 @@ def attach_json(rows):
         .mutate(name=pathfunc.name(C.path))
         .add_signals(split_name)
         .add_signals(attach_json, partition_by=C.basename)
+        .filter(C.glob(C.name, "*.raw"))
         .select(C.name, C.similarity, C.md5)
         .order_by(C.name)
     )
-    assert q.count() == 4
+    assert q.count() == 2
     res = q.db_results()
-    assert len(res) == 4
+    assert len(res) == 2
     assert [r[0] for r in res] == [r[0] for r in expected]
     assert [r[1] for r in res] == pytest.approx([r[1] for r in expected])
     assert [r[2] for r in res] == [r[2] for r in expected]
56 changes: 55 additions & 1 deletion tests/unit/lib/test_datachain.py
@@ -5,6 +5,8 @@

 import numpy as np
 import pandas as pd
+import pyarrow as pa
+import pyarrow.parquet as pq
 import pytest
 from pydantic import BaseModel

@@ -984,12 +986,19 @@ def test_from_csv_null_collect(tmp_dir, test_session):
     # See https://github.com/xzkostyan/clickhouse-sqlalchemy/issues/189.
     df = pd.DataFrame(DF_DATA)
     height = [70, 65, None, 72, 68]
+    gender = ["f", "m", None, "m", "f"]
     df["height"] = height
+    df["gender"] = gender
     path = tmp_dir / "test.csv"
     df.to_csv(path, index=False)
     dc = DataChain.from_csv(path.as_uri(), object_name="csv", session=test_session)
     for i, row in enumerate(dc.collect()):
-        assert row[1].height == height[i]
+        # None value in numeric column will get converted to nan.
+        if not height[i]:
+            assert math.isnan(row[1].height)
+        else:
+            assert row[1].height == height[i]
+        assert row[1].gender == gender[i]


 def test_from_csv_nrows(tmp_dir, test_session):
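Note: on the CSV side, both a missing float and a missing string serialize to an empty field, which is why the reader has to reconstruct the distinction (float column: nan, string column: None). A sketch of the serialization half with pandas:

```python
import io

import pandas as pd

# Both the numeric nan and the string None land in the file as empty fields.
df = pd.DataFrame({"height": [70, None], "gender": ["f", None]})
buf = io.StringIO()
df.to_csv(buf, index=False)
print(buf.getvalue())
# height,gender
# 70.0,f
# ,
```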
@@ -1533,3 +1542,48 @@ def test_mutate_with_expression_without_type(catalog):
     assert str(excinfo.value) == (
         "Error for column new: Cannot infer type with expression id - :id_1"
     )


+def test_from_values_nan_inf(tmp_dir, catalog):
+    vals = [float("nan"), float("inf"), float("-inf")]

Comment (Member): that's impressive!

+    dc = DataChain.from_values(vals=vals)
+    res = list(dc.collect("vals"))
+    assert np.isnan(res[0])
+    assert np.isposinf(res[1])
+    assert np.isneginf(res[2])
+
+
+def test_from_pandas_nan_inf(tmp_dir, catalog):
+    vals = [float("nan"), float("inf"), float("-inf")]
+    df = pd.DataFrame({"vals": vals})
+    dc = DataChain.from_pandas(df)
+    res = list(dc.collect("vals"))
+    assert np.isnan(res[0])
+    assert np.isposinf(res[1])
+    assert np.isneginf(res[2])
+
+
+def test_from_parquet_nan_inf(tmp_dir, catalog):
+    vals = [float("nan"), float("inf"), float("-inf")]
+    tbl = pa.table({"vals": vals})
+    path = tmp_dir / "test.parquet"
+    pq.write_table(tbl, path)
+    dc = DataChain.from_parquet(path.as_uri())
+
+    res = list(dc.collect("vals"))
+    assert np.isnan(res[0])
+    assert np.isposinf(res[1])
+    assert np.isneginf(res[2])
+
+
+def test_from_csv_nan_inf(tmp_dir, catalog):
+    vals = [float("nan"), float("inf"), float("-inf")]
+    df = pd.DataFrame({"vals": vals})
+    path = tmp_dir / "test.csv"
+    df.to_csv(path, index=False)
+    dc = DataChain.from_csv(path.as_uri())
+
+    res = list(dc.collect("vals"))
+    assert np.isnan(res[0])
+    assert np.isposinf(res[1])
+    assert np.isneginf(res[2])
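Note: for the CSV round trip above, it helps to see what pandas actually writes: nan becomes an empty field, while the infinities are written as literal `inf`/`-inf` tokens. A sketch:

```python
import io

import pandas as pd

df = pd.DataFrame({"vals": [float("nan"), float("inf"), float("-inf")]})
buf = io.StringIO()
df.to_csv(buf, index=False)
print(buf.getvalue())
# vals
#         <- nan serializes as an empty field
# inf
# -inf
```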
15 changes: 8 additions & 7 deletions tests/unit/lib/test_datachain_merge.py
@@ -1,12 +1,14 @@
 import math
 from typing import Optional

+import pandas as pd
 import pytest
 from pydantic import BaseModel

 from datachain.lib.dc import DataChain, DatasetMergeError
 from datachain.lib.signal_schema import SignalResolvingError
-from datachain.sql.types import Float, String
+from datachain.sql.types import String
+from tests.utils import skip_if_not_sqlite


 class User(BaseModel):
@@ -45,12 +47,12 @@ class TeamMember(BaseModel):


 def test_merge_objects(test_session):
+    skip_if_not_sqlite()
     ch1 = DataChain.from_values(emp=employees, session=test_session)
     ch2 = DataChain.from_values(team=team, session=test_session)
     ch = ch1.merge(ch2, "emp.person.name", "team.player")

     str_default = String.default_value(test_session.catalog.warehouse.db.dialect)
-    float_default = Float.default_value(test_session.catalog.warehouse.db.dialect)

     i = 0
     j = 0
@@ -72,8 +74,8 @@ def test_merge_objects(test_session):
         else:
             assert player.player == str_default
             assert player.sport == str_default
-            assert player.weight == float_default
-            assert player.height == float_default
+            assert pd.isnull(player.weight)
+            assert pd.isnull(player.height)

     assert i == len(employees)
     assert j == len(team)
@@ -103,14 +105,13 @@ def test_merge_similar_objects(test_session):


 def test_merge_values(test_session):
+    skip_if_not_sqlite()
     order_ids = [11, 22, 33, 44]
     order_descr = ["water", "water", "paper", "water"]

     delivery_ids = [11, 44]
     delivery_time = [24.0, 16.5]

-    float_default = Float.default_value(test_session.catalog.warehouse.db.dialect)
-
     ch1 = DataChain.from_values(id=order_ids, descr=order_descr, session=test_session)
     ch2 = DataChain.from_values(
         id=delivery_ids, time=delivery_time, session=test_session
@@ -137,7 +138,7 @@ def test_merge_values(test_session):
         assert name == order_descr[i]
         i += 1

-        if time != float_default:
+        if pd.notnull(time):
             assert id == delivery_ids[j]
             assert time == delivery_time[j]
             j += 1
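Note: `pd.isnull`/`pd.notnull` treat both `None` and `nan` as missing, which is what lets these assertions hold whether the warehouse returns NULL or the nan float default for unmatched rows. For example:

```python
import pandas as pd

# Both flavors of "missing" are recognized.
assert pd.isnull(None)
assert pd.isnull(float("nan"))
assert pd.notnull(24.0)
assert pd.notnull("water")
```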