-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
handle nan and inf float values #249
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,8 @@ | |
|
||
import numpy as np | ||
import pandas as pd | ||
import pyarrow as pa | ||
import pyarrow.parquet as pq | ||
import pytest | ||
from pydantic import BaseModel | ||
|
||
|
@@ -984,12 +986,19 @@ def test_from_csv_null_collect(tmp_dir, test_session): | |
# See https://github.com/xzkostyan/clickhouse-sqlalchemy/issues/189. | ||
df = pd.DataFrame(DF_DATA) | ||
height = [70, 65, None, 72, 68] | ||
gender = ["f", "m", None, "m", "f"] | ||
df["height"] = height | ||
df["gender"] = gender | ||
path = tmp_dir / "test.csv" | ||
df.to_csv(path, index=False) | ||
dc = DataChain.from_csv(path.as_uri(), object_name="csv", session=test_session) | ||
for i, row in enumerate(dc.collect()): | ||
assert row[1].height == height[i] | ||
# None value in numeric column will get converted to nan. | ||
if not height[i]: | ||
assert math.isnan(row[1].height) | ||
else: | ||
assert row[1].height == height[i] | ||
assert row[1].gender == gender[i] | ||
|
||
|
||
def test_from_csv_nrows(tmp_dir, test_session): | ||
|
@@ -1533,3 +1542,48 @@ def test_mutate_with_expression_without_type(catalog): | |
assert str(excinfo.value) == ( | ||
"Error for column new: Cannot infer type with expression id - :id_1" | ||
) | ||
|
||
|
||
def test_from_values_nan_inf(tmp_dir, catalog): | ||
vals = [float("nan"), float("inf"), float("-inf")] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that's impressive! |
||
dc = DataChain.from_values(vals=vals) | ||
res = list(dc.collect("vals")) | ||
assert np.isnan(res[0]) | ||
assert np.isposinf(res[1]) | ||
assert np.isneginf(res[2]) | ||
|
||
|
||
def test_from_pandas_nan_inf(tmp_dir, catalog): | ||
vals = [float("nan"), float("inf"), float("-inf")] | ||
df = pd.DataFrame({"vals": vals}) | ||
dc = DataChain.from_pandas(df) | ||
res = list(dc.collect("vals")) | ||
assert np.isnan(res[0]) | ||
assert np.isposinf(res[1]) | ||
assert np.isneginf(res[2]) | ||
|
||
|
||
def test_from_parquet_nan_inf(tmp_dir, catalog): | ||
vals = [float("nan"), float("inf"), float("-inf")] | ||
tbl = pa.table({"vals": vals}) | ||
path = tmp_dir / "test.parquet" | ||
pq.write_table(tbl, path) | ||
dc = DataChain.from_parquet(path.as_uri()) | ||
|
||
res = list(dc.collect("vals")) | ||
assert np.isnan(res[0]) | ||
assert np.isposinf(res[1]) | ||
assert np.isneginf(res[2]) | ||
|
||
|
||
def test_from_csv_nan_inf(tmp_dir, catalog): | ||
vals = [float("nan"), float("inf"), float("-inf")] | ||
df = pd.DataFrame({"vals": vals}) | ||
path = tmp_dir / "test.csv" | ||
df.to_csv(path, index=False) | ||
dc = DataChain.from_csv(path.as_uri()) | ||
|
||
res = list(dc.collect("vals")) | ||
assert np.isnan(res[0]) | ||
assert np.isposinf(res[1]) | ||
assert np.isneginf(res[2]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,14 @@ | ||
import math | ||
from typing import Optional | ||
|
||
import pandas as pd | ||
import pytest | ||
from pydantic import BaseModel | ||
|
||
from datachain.lib.dc import DataChain, DatasetMergeError | ||
from datachain.lib.signal_schema import SignalResolvingError | ||
from datachain.sql.types import Float, String | ||
from datachain.sql.types import String | ||
from tests.utils import skip_if_not_sqlite | ||
|
||
|
||
class User(BaseModel): | ||
|
@@ -45,12 +47,12 @@ class TeamMember(BaseModel): | |
|
||
|
||
def test_merge_objects(test_session): | ||
skip_if_not_sqlite() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See the note about joins in https://github.com/iterative/studio/pull/10429 |
||
ch1 = DataChain.from_values(emp=employees, session=test_session) | ||
ch2 = DataChain.from_values(team=team, session=test_session) | ||
ch = ch1.merge(ch2, "emp.person.name", "team.player") | ||
|
||
str_default = String.default_value(test_session.catalog.warehouse.db.dialect) | ||
float_default = Float.default_value(test_session.catalog.warehouse.db.dialect) | ||
|
||
i = 0 | ||
j = 0 | ||
|
@@ -72,8 +74,8 @@ def test_merge_objects(test_session): | |
else: | ||
assert player.player == str_default | ||
assert player.sport == str_default | ||
assert player.weight == float_default | ||
assert player.height == float_default | ||
assert pd.isnull(player.weight) | ||
assert pd.isnull(player.height) | ||
|
||
assert i == len(employees) | ||
assert j == len(team) | ||
|
@@ -103,14 +105,13 @@ def test_merge_similar_objects(test_session): | |
|
||
|
||
def test_merge_values(test_session): | ||
skip_if_not_sqlite() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See the note about joins in https://github.com/iterative/studio/pull/10429 |
||
order_ids = [11, 22, 33, 44] | ||
order_descr = ["water", "water", "paper", "water"] | ||
|
||
delivery_ids = [11, 44] | ||
delivery_time = [24.0, 16.5] | ||
|
||
float_default = Float.default_value(test_session.catalog.warehouse.db.dialect) | ||
|
||
ch1 = DataChain.from_values(id=order_ids, descr=order_descr, session=test_session) | ||
ch2 = DataChain.from_values( | ||
id=delivery_ids, time=delivery_time, session=test_session | ||
|
@@ -137,7 +138,7 @@ def test_merge_values(test_session): | |
assert name == order_descr[i] | ||
i += 1 | ||
|
||
if time != float_default: | ||
if pd.notnull(time): | ||
assert id == delivery_ids[j] | ||
assert time == delivery_time[j] | ||
j += 1 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This provides pandas-like null handling for csv files