Adding In-Memory DataChain Option (#283)
dtulga authored Aug 15, 2024
1 parent 98fb076 · commit bbcff85
Showing 10 changed files with 411 additions and 104 deletions.
2 changes: 2 additions & 0 deletions src/datachain/catalog/catalog.py
@@ -577,6 +577,7 @@ def __init__(
warehouse_ready_callback: Optional[
Callable[["AbstractWarehouse"], None]
] = None,
in_memory: bool = False,
):
datachain_dir = DataChainDir(cache=cache_dir, tmp=tmp_dir)
datachain_dir.init()
@@ -590,6 +591,7 @@ def __init__(
"tmp_dir": tmp_dir,
}
self._warehouse_ready_callback = warehouse_ready_callback
self.in_memory = in_memory

@cached_property
def warehouse(self) -> "AbstractWarehouse":
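Note: the `Catalog` constructor now accepts and stores an `in_memory` flag. A minimal sketch of wiring it up by hand, mirroring what `get_catalog()` (in the loader diff below) does; the import locations follow the example paths in the loader comments, and the omitted optional constructor arguments are assumed to have defaults:

```python
from datachain.catalog import Catalog
from datachain.data_storage import (
    SQLiteIDGenerator,
    SQLiteMetastore,
    SQLiteWarehouse,
)

# Hypothetical manual wiring; get_catalog(in_memory=True) does this for you.
id_generator = SQLiteIDGenerator(in_memory=True)
catalog = Catalog(
    id_generator=id_generator,
    metastore=SQLiteMetastore(id_generator, in_memory=True),
    warehouse=SQLiteWarehouse(id_generator, in_memory=True),
    in_memory=True,
)
assert catalog.in_memory
```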
125 changes: 75 additions & 50 deletions src/datachain/catalog/loader.py
@@ -28,8 +28,10 @@
DISTRIBUTED_IMPORT_PATH = "DATACHAIN_DISTRIBUTED"
DISTRIBUTED_ARG_PREFIX = "DATACHAIN_DISTRIBUTED_ARG_"

IN_MEMORY_ERROR_MESSAGE = "In-memory is only supported on SQLite"

def get_id_generator() -> "AbstractIDGenerator":

def get_id_generator(in_memory: bool = False) -> "AbstractIDGenerator":
id_generator_serialized = os.environ.get(ID_GENERATOR_SERIALIZED)
if id_generator_serialized:
id_generator_obj = deserialize(id_generator_serialized)
@@ -43,25 +45,31 @@ def get_id_generator() -> "AbstractIDGenerator":
id_generator_import_path = os.environ.get(ID_GENERATOR_IMPORT_PATH)
id_generator_arg_envs = get_envs_by_prefix(ID_GENERATOR_ARG_PREFIX)
# Convert env variable names to keyword argument names by lowercasing them
id_generator_args = {k.lower(): v for k, v in id_generator_arg_envs.items()}

if id_generator_import_path:
# ID generator paths are specified as (for example):
# datachain.data_storage.SQLiteIDGenerator
if "." not in id_generator_import_path:
raise RuntimeError(
f"Invalid {ID_GENERATOR_IMPORT_PATH} import path:"
f"{id_generator_import_path}"
)
module_name, _, class_name = id_generator_import_path.rpartition(".")
id_generator = import_module(module_name)
id_generator_class = getattr(id_generator, class_name)
else:
id_generator_class = SQLiteIDGenerator
id_generator_args: dict[str, Any] = {
k.lower(): v for k, v in id_generator_arg_envs.items()
}

if not id_generator_import_path:
id_generator_args["in_memory"] = in_memory
return SQLiteIDGenerator(**id_generator_args)
if in_memory:
raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
# ID generator paths are specified as (for example):
# datachain.data_storage.SQLiteIDGenerator
if "." not in id_generator_import_path:
raise RuntimeError(
f"Invalid {ID_GENERATOR_IMPORT_PATH} import path:"
f"{id_generator_import_path}"
)
module_name, _, class_name = id_generator_import_path.rpartition(".")
id_generator = import_module(module_name)
id_generator_class = getattr(id_generator, class_name)
return id_generator_class(**id_generator_args)


def get_metastore(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractMetastore":
def get_metastore(
id_generator: Optional["AbstractIDGenerator"], in_memory: bool = False
) -> "AbstractMetastore":
metastore_serialized = os.environ.get(METASTORE_SERIALIZED)
if metastore_serialized:
metastore_obj = deserialize(metastore_serialized)
@@ -78,24 +86,32 @@ def get_metastore(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractMetastore":
metastore_import_path = os.environ.get(METASTORE_IMPORT_PATH)
metastore_arg_envs = get_envs_by_prefix(METASTORE_ARG_PREFIX)
# Convert env variable names to keyword argument names by lowercasing them
metastore_args = {k.lower(): v for k, v in metastore_arg_envs.items()}

if metastore_import_path:
# Metastore paths are specified as (for example):
# datachain.data_storage.SQLiteMetastore
if "." not in metastore_import_path:
raise RuntimeError(
f"Invalid {METASTORE_IMPORT_PATH} import path: {metastore_import_path}"
)
module_name, _, class_name = metastore_import_path.rpartition(".")
metastore = import_module(module_name)
metastore_class = getattr(metastore, class_name)
else:
metastore_class = SQLiteMetastore
metastore_args: dict[str, Any] = {
k.lower(): v for k, v in metastore_arg_envs.items()
}

if not metastore_import_path:
if not isinstance(id_generator, SQLiteIDGenerator):
raise ValueError("SQLiteMetastore can only be used with SQLiteIDGenerator")
metastore_args["in_memory"] = in_memory
return SQLiteMetastore(id_generator, **metastore_args)
if in_memory:
raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
# Metastore paths are specified as (for example):
# datachain.data_storage.SQLiteMetastore
if "." not in metastore_import_path:
raise RuntimeError(
f"Invalid {METASTORE_IMPORT_PATH} import path: {metastore_import_path}"
)
module_name, _, class_name = metastore_import_path.rpartition(".")
metastore = import_module(module_name)
metastore_class = getattr(metastore, class_name)
return metastore_class(id_generator, **metastore_args)


def get_warehouse(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractWarehouse":
def get_warehouse(
id_generator: Optional["AbstractIDGenerator"], in_memory: bool = False
) -> "AbstractWarehouse":
warehouse_serialized = os.environ.get(WAREHOUSE_SERIALIZED)
if warehouse_serialized:
warehouse_obj = deserialize(warehouse_serialized)
@@ -112,20 +128,26 @@ def get_warehouse(id_generator: Optional["AbstractIDGenerator"]) -> "AbstractWarehouse":
warehouse_import_path = os.environ.get(WAREHOUSE_IMPORT_PATH)
warehouse_arg_envs = get_envs_by_prefix(WAREHOUSE_ARG_PREFIX)
# Convert env variable names to keyword argument names by lowercasing them
warehouse_args = {k.lower(): v for k, v in warehouse_arg_envs.items()}

if warehouse_import_path:
# Warehouse paths are specified as (for example):
# datachain.data_storage.SQLiteWarehouse
if "." not in warehouse_import_path:
raise RuntimeError(
f"Invalid {WAREHOUSE_IMPORT_PATH} import path: {warehouse_import_path}"
)
module_name, _, class_name = warehouse_import_path.rpartition(".")
warehouse = import_module(module_name)
warehouse_class = getattr(warehouse, class_name)
else:
warehouse_class = SQLiteWarehouse
warehouse_args: dict[str, Any] = {
k.lower(): v for k, v in warehouse_arg_envs.items()
}

if not warehouse_import_path:
if not isinstance(id_generator, SQLiteIDGenerator):
raise ValueError("SQLiteWarehouse can only be used with SQLiteIDGenerator")
warehouse_args["in_memory"] = in_memory
return SQLiteWarehouse(id_generator, **warehouse_args)
if in_memory:
raise RuntimeError(IN_MEMORY_ERROR_MESSAGE)
# Warehouse paths are specified as (for example):
# datachain.data_storage.SQLiteWarehouse
if "." not in warehouse_import_path:
raise RuntimeError(
f"Invalid {WAREHOUSE_IMPORT_PATH} import path: {warehouse_import_path}"
)
module_name, _, class_name = warehouse_import_path.rpartition(".")
warehouse = import_module(module_name)
warehouse_class = getattr(warehouse, class_name)
return warehouse_class(id_generator, **warehouse_args)


@@ -152,7 +174,9 @@ def get_distributed_class(**kwargs):
return distributed_class(**distributed_args | kwargs)


def get_catalog(client_config: Optional[dict[str, Any]] = None) -> Catalog:
def get_catalog(
client_config: Optional[dict[str, Any]] = None, in_memory: bool = False
) -> Catalog:
"""
Function that creates Catalog instance with appropriate metastore
and warehouse classes. Metastore class can be provided with env variable
@@ -164,10 +188,11 @@ def get_catalog(client_config: Optional[dict[str, Any]] = None) -> Catalog:
and name of variable after, e.g. if it accepts team_id as kwargs
we can provide DATACHAIN_METASTORE_ARG_TEAM_ID=12345 env variable.
"""
id_generator = get_id_generator()
id_generator = get_id_generator(in_memory=in_memory)
return Catalog(
id_generator=id_generator,
metastore=get_metastore(id_generator),
warehouse=get_warehouse(id_generator),
metastore=get_metastore(id_generator, in_memory=in_memory),
warehouse=get_warehouse(id_generator, in_memory=in_memory),
client_config=client_config,
in_memory=in_memory,
)
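Note: with these changes, an entirely in-memory catalog is one call away. A short sketch; the import-path env-var name below follows the docstring's `DATACHAIN_METASTORE_ARG_*` pattern and is an assumption here:

```python
from datachain.catalog.loader import get_catalog

# SQLite-backed and fully in-memory; nothing is written to disk.
catalog = get_catalog(in_memory=True)

# If a non-SQLite backend is configured via an import-path env var,
# e.g. DATACHAIN_METASTORE=some.module.CustomMetastore (assumed name),
# the same call raises RuntimeError("In-memory is only supported on SQLite").
```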
20 changes: 20 additions & 0 deletions src/datachain/data_storage/sqlite.py
@@ -81,6 +81,17 @@ def wrapper(*args, **kwargs):
return wrapper


def get_db_file_in_memory(
db_file: Optional[str] = None, in_memory: bool = False
) -> Optional[str]:
"""Get in-memory db_file and check that conflicting arguments are not provided."""
if in_memory:
if db_file and db_file != ":memory:":
raise RuntimeError("A db_file cannot be specified if in_memory is True")
db_file = ":memory:"
return db_file


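Note: the new `get_db_file_in_memory` helper centralizes the argument check that each SQLite class below repeats. Its behavior, traced directly from the code above:

```python
get_db_file_in_memory(None, in_memory=True)        # -> ":memory:"
get_db_file_in_memory(":memory:", in_memory=True)  # -> ":memory:" (explicit form allowed)
get_db_file_in_memory("data.db", in_memory=True)   # RuntimeError: conflicting arguments
get_db_file_in_memory("data.db", in_memory=False)  # -> "data.db" (passed through)
get_db_file_in_memory(None, in_memory=False)       # -> None (engine picks its default)
```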
class SQLiteDatabaseEngine(DatabaseEngine):
dialect = sqlite_dialect

@@ -264,7 +275,10 @@ def __init__(
table_prefix: Optional[str] = None,
skip_db_init: bool = False,
db_file: Optional[str] = None,
in_memory: bool = False,
):
db_file = get_db_file_in_memory(db_file, in_memory)

db = db or SQLiteDatabaseEngine.from_db_file(db_file)

super().__init__(db, table_prefix, skip_db_init)
@@ -382,6 +396,7 @@ def __init__(
partial_id: Optional[int] = None,
db: Optional["SQLiteDatabaseEngine"] = None,
db_file: Optional[str] = None,
in_memory: bool = False,
):
self.schema: DefaultSchema = DefaultSchema()
super().__init__(id_generator, uri, partial_id)
@@ -390,6 +405,8 @@
# foreign keys
self.default_table_names: list[str] = []

db_file = get_db_file_in_memory(db_file, in_memory)

self.db = db or SQLiteDatabaseEngine.from_db_file(db_file)

self._init_tables()
@@ -554,10 +571,13 @@ def __init__(
id_generator: "SQLiteIDGenerator",
db: Optional["SQLiteDatabaseEngine"] = None,
db_file: Optional[str] = None,
in_memory: bool = False,
):
self.schema: DefaultSchema = DefaultSchema()
super().__init__(id_generator)

db_file = get_db_file_in_memory(db_file, in_memory)

self.db = db or SQLiteDatabaseEngine.from_db_file(db_file)

def __exit__(self, exc_type, exc_value, traceback) -> None:
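Note: each of the three SQLite classes normalizes the flag the same way before handing off to `SQLiteDatabaseEngine.from_db_file`, so the two spellings below are equivalent, and the conflicting combination is rejected. A sketch using the ID generator, with the remaining constructor defaults assumed:

```python
from datachain.data_storage import SQLiteIDGenerator

gen_a = SQLiteIDGenerator(in_memory=True)      # normalized to db_file=":memory:"
gen_b = SQLiteIDGenerator(db_file=":memory:")  # equivalent spelling

SQLiteIDGenerator(db_file="ids.db", in_memory=True)  # RuntimeError
```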
31 changes: 25 additions & 6 deletions src/datachain/lib/dc.py
@@ -309,6 +309,7 @@ def from_storage(
*,
type: Literal["binary", "text", "image"] = "binary",
session: Optional[Session] = None,
in_memory: bool = False,
recursive: Optional[bool] = True,
object_name: str = "file",
update: bool = False,
@@ -332,7 +333,14 @@
"""
func = get_file(type)
return (
cls(path, session=session, recursive=recursive, update=update, **kwargs)
cls(
path,
session=session,
recursive=recursive,
update=update,
in_memory=in_memory,
**kwargs,
)
.map(**{object_name: func})
.select(object_name)
)
@@ -479,7 +487,10 @@ def jmespath_to_name(s: str):

@classmethod
def datasets(
cls, session: Optional[Session] = None, object_name: str = "dataset"
cls,
session: Optional[Session] = None,
in_memory: bool = False,
object_name: str = "dataset",
) -> "DataChain":
"""Generate chain with list of registered datasets.
@@ -492,7 +503,7 @@ def datasets(
print(f"{ds.name}@v{ds.version}")
```
"""
session = Session.get(session)
session = Session.get(session, in_memory=in_memory)
catalog = session.catalog

datasets = [
@@ -502,6 +513,7 @@ def datasets(

return cls.from_values(
session=session,
in_memory=in_memory,
output={object_name: DatasetInfo},
**{object_name: datasets}, # type: ignore[arg-type]
)
@@ -1142,6 +1154,7 @@ def from_values(
cls,
ds_name: str = "",
session: Optional[Session] = None,
in_memory: bool = False,
output: OutputType = None,
object_name: str = "",
**fr_map,
@@ -1158,7 +1171,9 @@ def from_values(
def _func_fr() -> Iterator[tuple_type]: # type: ignore[valid-type]
yield from tuples

chain = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD, session=session)
chain = DataChain.from_records(
DataChain.DEFAULT_FILE_RECORD, session=session, in_memory=in_memory
)
if object_name:
output = {object_name: DataChain._dict_to_data_model(object_name, output)} # type: ignore[arg-type]
return chain.gen(_func_fr, output=output)
@@ -1169,6 +1184,7 @@ def from_pandas( # type: ignore[override]
df: "pd.DataFrame",
name: str = "",
session: Optional[Session] = None,
in_memory: bool = False,
object_name: str = "",
) -> "DataChain":
"""Generate chain from pandas data-frame.
@@ -1196,7 +1212,9 @@ def from_pandas( # type: ignore[override]
f"import from pandas error - '{column}' cannot be a column name",
)

return cls.from_values(name, session, object_name=object_name, **fr_map)
return cls.from_values(
name, session, object_name=object_name, in_memory=in_memory, **fr_map
)

def to_pandas(self, flatten=False) -> "pd.DataFrame":
"""Return a pandas DataFrame from the chain.
@@ -1505,6 +1523,7 @@ def from_records(
cls,
to_insert: Optional[Union[dict, list[dict]]],
session: Optional[Session] = None,
in_memory: bool = False,
) -> "DataChain":
"""Create a DataChain from the provided records. This method can be used for
programmatically generating a chain in contrast of reading data from storages
@@ -1520,7 +1539,7 @@ def from_records(
single_record = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD)
```
"""
session = Session.get(session)
session = Session.get(session, in_memory=in_memory)
catalog = session.catalog

name = session.generate_temp_dataset_name()
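Note: on the user-facing side, the flag threads through every `DataChain` constructor and ultimately into `Session.get`. A quick sketch of the API after this change (values are arbitrary):

```python
import pandas as pd

from datachain.lib.dc import DataChain

# A chain built from literal values, backed by an in-memory database.
chain = DataChain.from_values(fib=[1, 1, 2, 3, 5, 8], in_memory=True)

# The same flag is accepted by the other constructors:
DataChain.from_pandas(pd.DataFrame({"x": [1, 2, 3]}), in_memory=True)
DataChain.from_records(DataChain.DEFAULT_FILE_RECORD, in_memory=True)
DataChain.datasets(in_memory=True)
```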
13 changes: 12 additions & 1 deletion src/datachain/query/dataset.py
@@ -469,6 +469,12 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:

try:
if workers:
if self.catalog.in_memory:
raise RuntimeError(
"In-memory databases cannot be used with "
"distributed processing."
)

from datachain.catalog.loader import get_distributed_class

distributor = get_distributed_class(min_task_size=self.min_task_size)
@@ -486,6 +492,10 @@ def populate_udf_table(self, udf_table: "Table", query: Select) -> None:
)
elif processes:
# Parallel processing (faster for more CPU-heavy UDFs)
if self.catalog.in_memory:
raise RuntimeError(
"In-memory databases cannot be used with parallel processing."
)
udf_info = {
"udf_data": filtered_cloudpickle_dumps(self.udf),
"catalog_init": self.catalog.get_init_params(),
@@ -1053,6 +1063,7 @@ def __init__(
indexing_feature_schema: Optional[dict] = None,
indexing_column_types: Optional[dict[str, Any]] = None,
update: Optional[bool] = False,
in_memory: bool = False,
):
if client_config is None:
client_config = {}
@@ -1061,7 +1072,7 @@
client_config["anon"] = True

self.session = Session.get(
session, catalog=catalog, client_config=client_config
session, catalog=catalog, client_config=client_config, in_memory=in_memory
)
self.catalog = catalog or self.session.catalog
self.steps: list[Step] = []
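Note: an in-memory SQLite database lives in a single process, so the two new guards in `populate_udf_table` fail fast rather than handing worker processes a database they cannot see. A hedged sketch of the failure mode; `settings(parallel=...)` and the `map` signature come from the wider DataChain API, not this diff:

```python
chain = DataChain.from_values(num=[1, 2, 3], in_memory=True)

# Single-process UDF execution works as usual:
chain.map(double=lambda num: num * 2, output=int)

# But combining the in-memory flag with parallel or distributed
# processing raises once the query actually runs:
chain.settings(parallel=4).map(double=lambda num: num * 2, output=int)
# RuntimeError: In-memory databases cannot be used with parallel processing.
```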