Skip to content

Commit

Permalink
Implement database default values (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
dreadatour committed Aug 15, 2024
1 parent ba4d848 commit 422d5e8
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 22 deletions.
40 changes: 19 additions & 21 deletions src/datachain/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from types import MappingProxyType
from typing import Any, Union

import sqlalchemy as sa
from sqlalchemy import TypeDecorator, types

_registry: dict[str, "TypeConverter"] = {}
Expand Down Expand Up @@ -395,7 +396,11 @@ def uint64(self, value):
return value

def float(self, value):
return float("nan") if value is None else value
if value is None:
return float("nan")
if isinstance(value, str) and value.lower() == "nan":
return float("nan")
return value

def float32(self, value):
return self.float(value)
Expand All @@ -414,10 +419,9 @@ def json(self, value):
def datetime(self, value):
return value

def uuid(self, value):
return value

def binary(self, value):
if isinstance(value, str):
return value.encode()
return value


Expand Down Expand Up @@ -499,34 +503,31 @@ def json(self):
def datetime(self):
return None

def uuid(self):
return None

def binary(self):
return None


class DBDefaults:
def string(self):
return None
return sa.text("''")

def boolean(self):
return None
return sa.text("False")

def int(self):
return None
return sa.text("0")

def int32(self):
return None
return self.int()

def int64(self):
return None
return self.int()

def uint64(self):
return None
return self.int()

def float(self):
return float("nan")
return sa.text("NaN")

def float32(self):
return self.float()
Expand All @@ -535,19 +536,16 @@ def float64(self):
return self.float()

def array(self):
return None
return sa.text("'[]'")

def json(self):
return None
return sa.text("'{}'")

def datetime(self):
return None

def uuid(self):
return None
return sa.text("'1970-01-01 00:00:00'")

def binary(self):
return None
return sa.text("''")


TYPES = [
Expand Down
51 changes: 50 additions & 1 deletion tests/unit/test_data_storage.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
from datetime import datetime
import math
from datetime import datetime, timezone
from typing import Any

import pytest
import sqlalchemy

from datachain.sql.types import (
JSON,
Array,
Binary,
Boolean,
DateTime,
Float,
Float32,
Float64,
Int,
Int32,
Int64,
String,
UInt64,
)
from tests.utils import DEFAULT_TREE, TARRED_TREE, create_tar_dataset

Expand Down Expand Up @@ -165,3 +171,46 @@ def run_convert_type(value, sql_type):
# error, float to int in list
with pytest.raises(ValueError):
run_convert_type([1.5, 1], Array(Int))


@pytest.mark.parametrize(
"col_type,default_value",
[
[String(), ""],
[Boolean(), False],
[Int(), 0],
[Int32(), 0],
[Int64(), 0],
[UInt64(), 0],
[Float(), lambda val: math.isnan(val)],
[Float32(), lambda val: math.isnan(val)],
[Float64(), lambda val: math.isnan(val)],
[Array(Int), []],
[JSON(), "{}"],
[DateTime(), datetime(1970, 1, 1, 0, 0, tzinfo=timezone.utc)],
[Binary(), b""],
],
)
def test_db_defaults(col_type, default_value, catalog):
warehouse = catalog.warehouse

table_col = sqlalchemy.Column(
"val",
col_type,
nullable=False,
server_default=col_type.db_default_value(warehouse.db.dialect),
)
table = warehouse.create_udf_table([table_col])
warehouse.insert_rows(table, [{"sys__id": 1}])
warehouse.insert_rows_done(table)

query = sqlalchemy.Select(table_col).select_from(table)

values = [row[0] for row in warehouse.dataset_rows_select(query)]
assert len(values) == 1
if callable(default_value):
assert default_value(values[0])
else:
assert values[0] == default_value

warehouse.db.drop_table(table)

0 comments on commit 422d5e8

Please sign in to comment.