diff --git a/src/datacustomcode/common_config.py b/src/datacustomcode/common_config.py new file mode 100644 index 0000000..a3bdbee --- /dev/null +++ b/src/datacustomcode/common_config.py @@ -0,0 +1,63 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +import os +from typing import Any + +from pydantic import ( + BaseModel, + ConfigDict, + Field, +) +import yaml + +DEFAULT_CONFIG_NAME = "config.yaml" + + +def default_config_file() -> str: + return os.path.join(os.path.dirname(__file__), DEFAULT_CONFIG_NAME) + + +class ForceableConfig(BaseModel): + force: bool = Field( + default=False, + description="If True, this takes precedence over parameters passed to the " + "initializer of the client", + ) + + +class BaseObjectConfig(ForceableConfig): + model_config = ConfigDict(validate_default=True, extra="forbid") + type_config_name: str = Field( + description="The config name of the object to create", + ) + options: dict[str, Any] = Field( + default_factory=dict, + description="Options passed to the constructor.", + ) + + +class BaseConfig(ABC, BaseModel): + @abstractmethod + def update(self, other: Any) -> "BaseConfig": ... 
+ + def load(self, config_path: str) -> "BaseConfig": + """Load configuration from a YAML file and merge with existing config""" + with open(config_path, "r") as f: + config_data = yaml.safe_load(f) + + loaded_config = self.__class__.model_validate(config_data) + self.update(loaded_config) + return self diff --git a/src/datacustomcode/config.py b/src/datacustomcode/config.py index d8b22fb..820c512 100644 --- a/src/datacustomcode/config.py +++ b/src/datacustomcode/config.py @@ -14,7 +14,6 @@ # limitations under the License. from __future__ import annotations -import os from typing import ( TYPE_CHECKING, Any, @@ -26,12 +25,14 @@ cast, ) -from pydantic import ( - BaseModel, - ConfigDict, - Field, +from pydantic import Field + +from datacustomcode.common_config import ( + BaseConfig, + BaseObjectConfig, + ForceableConfig, + default_config_file, ) -import yaml # This lets all readers and writers to be findable via config from datacustomcode.io import * # noqa: F403 @@ -42,36 +43,15 @@ from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH002 from datacustomcode.spark.base import BaseSparkSessionProvider -DEFAULT_CONFIG_NAME = "config.yaml" - - if TYPE_CHECKING: from pyspark.sql import SparkSession -class ForceableConfig(BaseModel): - force: bool = Field( - default=False, - description="If True, this takes precedence over parameters passed to the " - "initializer of the client.", - ) - - _T = TypeVar("_T", bound="BaseDataAccessLayer") -class AccessLayerObjectConfig(ForceableConfig, Generic[_T]): - model_config = ConfigDict(validate_default=True, extra="forbid") +class AccessLayerObjectConfig(BaseObjectConfig, Generic[_T]): type_base: ClassVar[Type[BaseDataAccessLayer]] = BaseDataAccessLayer - type_config_name: str = Field( - description="The config name of the object to create. " - "For metrics, this would might be 'ipmnormal'. 
For custom classes, you can " - "assign a name to a class variable `CONFIG_NAME` and reference it here.", - ) - options: dict[str, Any] = Field( - default_factory=dict, - description="Options passed to the constructor.", - ) def to_object(self, spark: SparkSession) -> _T: type_ = self.type_base.subclass_from_config_name(self.type_config_name) @@ -97,35 +77,25 @@ class SparkConfig(ForceableConfig): _PX = TypeVar("_PX", bound=BaseProxyAccessLayer) -class ProxyAccessLayerObjectConfig(ForceableConfig, Generic[_PX]): +class ProxyAccessLayerObjectConfig(BaseObjectConfig, Generic[_PX]): """Config for proxy clients that take no constructor args (e.g. no spark).""" - model_config = ConfigDict(validate_default=True, extra="forbid") type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer - type_config_name: str = Field( - description="CONFIG_NAME of the proxy client (e.g. 'LocalProxyClient').", - ) - options: dict[str, Any] = Field(default_factory=dict) def to_object(self) -> _PX: type_ = self.type_base.subclass_from_config_name(self.type_config_name) return cast(_PX, type_(**self.options)) -class SparkProviderConfig(ForceableConfig, Generic[_P]): - model_config = ConfigDict(validate_default=True, extra="forbid") +class SparkProviderConfig(BaseObjectConfig, Generic[_P]): type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider - type_config_name: str = Field( - description="CONFIG_NAME of the Spark session provider." 
- ) - options: dict[str, Any] = Field(default_factory=dict) def to_object(self) -> _P: type_ = self.type_base.subclass_from_config_name(self.type_config_name) return cast(_P, type_(**self.options)) -class ClientConfig(BaseModel): +class ClientConfig(BaseConfig): reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None @@ -163,31 +133,10 @@ def merge( ) return self - def load(self, config_path: str) -> ClientConfig: - """Load a config from a file and update this config with it. - Args: - config_path: The path to the config file - - Returns: - Self, with updated values from the loaded config. - """ - with open(config_path, "r") as f: - config_data = yaml.safe_load(f) - loaded_config = ClientConfig.model_validate(config_data) - - return self.update(loaded_config) - - -config = ClientConfig() """Global config object. This is the object that makes config accessible globally and globally mutable. """ - - -def _defaults() -> str: - return os.path.join(os.path.dirname(__file__), DEFAULT_CONFIG_NAME) - - -config.load(_defaults()) +config = ClientConfig() +config.load(default_config_file()) diff --git a/src/datacustomcode/config.yaml b/src/datacustomcode/config.yaml index bf21209..d58bc7f 100644 --- a/src/datacustomcode/config.yaml +++ b/src/datacustomcode/config.yaml @@ -23,3 +23,7 @@ proxy_config: type_config_name: LocalProxyClientProvider options: credentials_profile: default + +einstein_predictions_config: + type_config_name: DefaultEinsteinPredictions + options: {} diff --git a/src/datacustomcode/einstein_predictions/__init__.py b/src/datacustomcode/einstein_predictions/__init__.py new file mode 100644 index 0000000..4a4b388 --- /dev/null +++ b/src/datacustomcode/einstein_predictions/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2025, Salesforce, Inc. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datacustomcode.einstein_predictions.base import EinsteinPredictions +from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions + +__all__ = [ + "EinsteinPredictions", + "DefaultEinsteinPredictions", +] diff --git a/src/datacustomcode/einstein_predictions/base.py b/src/datacustomcode/einstein_predictions/base.py new file mode 100644 index 0000000..f00e2c7 --- /dev/null +++ b/src/datacustomcode/einstein_predictions/base.py @@ -0,0 +1,32 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABC, abstractmethod + +from datacustomcode.einstein_predictions.types import ( + PredictionRequest, + PredictionResponse, +) +from datacustomcode.mixin import UserExtendableNamedConfigMixin + + +class EinsteinPredictions(ABC, UserExtendableNamedConfigMixin): + CONFIG_NAME: str + + def __init__(self, **kwargs): + pass + + @abstractmethod + def predict(self, request: PredictionRequest) -> PredictionResponse: ... diff --git a/src/datacustomcode/einstein_predictions/impl/default.py b/src/datacustomcode/einstein_predictions/impl/default.py new file mode 100644 index 0000000..ce8d65a --- /dev/null +++ b/src/datacustomcode/einstein_predictions/impl/default.py @@ -0,0 +1,35 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from datacustomcode.einstein_predictions.base import EinsteinPredictions +from datacustomcode.einstein_predictions.types import ( + PredictionRequest, + PredictionResponse, +) + + +class DefaultEinsteinPredictions(EinsteinPredictions): + CONFIG_NAME = "DefaultEinsteinPredictions" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def predict(self, request: PredictionRequest) -> PredictionResponse: + return PredictionResponse( + version="v1", + prediction_type=request.prediction_type, + status_code=200, + data={"results": [{"prediction": {"predictedValue": "1"}}]}, + ) diff --git a/src/datacustomcode/einstein_predictions/types.py b/src/datacustomcode/einstein_predictions/types.py new file mode 100644 index 0000000..92a7bdd --- /dev/null +++ b/src/datacustomcode/einstein_predictions/types.py @@ -0,0 +1,184 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from enum import Enum, unique +from typing import ( + Any, + Dict, + Literal, + Optional, +) + +from pydantic import ( + BaseModel, + Field, + model_validator, +) + + +@unique +class PredictionType(Enum): + REGRESSION = 1 + CLUSTERING = 2 + CLASSIFICATION = 3 + MULTI_OUTCOME = 4 + BINARY_CLASSIFICATION = 5 + + +class PredictionColumn(BaseModel): + column_name: str = Field(min_length=1, description="Column name") + string_values: Optional[list[str]] = Field( + default=None, min_length=1, description="Column string values" + ) + double_values: Optional[list[float]] = Field( + default=None, min_length=1, description="Column double values" + ) + boolean_values: Optional[list[bool]] = Field( + default=None, min_length=1, description="Column boolean values" + ) + date_values: Optional[list[str]] = Field( + default=None, min_length=1, description="Column date values" + ) + datetime_values: Optional[list[str]] = Field( + default=None, min_length=1, description="Column datetime values" + ) + + @model_validator(mode="after") + def validate_exactly_one_value_type(self): + set_count = sum( + [ + self.string_values is not None, + self.double_values is not None, + self.boolean_values is not None, + self.date_values is not None, + self.datetime_values is not None, + ] + ) + + if set_count != 1: + raise ValueError("Exactly one value type must be set") + + return self + + +class PredictionColumBuilder: + def __init__(self) -> None: + self._column_name: Optional[str] = None + self._string_values: Optional[list[str]] = None + self._double_values: Optional[list[float]] = None + self._boolean_values: Optional[list[bool]] = None + self._date_values: Optional[list[str]] = None + self._datetime_values: Optional[list[str]] = None + + def set_column_name(self, column_name: str) -> "PredictionColumBuilder": + self._column_name = column_name + return self + + def set_string_values(self, string_values: list[str]) -> "PredictionColumBuilder": + self._string_values = string_values + return 
self + + def set_double_values(self, double_values: list[float]) -> "PredictionColumBuilder": + self._double_values = double_values + return self + + def set_boolean_values( + self, boolean_values: list[bool] + ) -> "PredictionColumBuilder": + self._boolean_values = boolean_values + return self + + def set_date_values(self, date_values: list[str]) -> "PredictionColumBuilder": + self._date_values = date_values + return self + + def set_datetime_values( + self, datetime_values: list[str] + ) -> "PredictionColumBuilder": + self._datetime_values = datetime_values + return self + + def build(self) -> PredictionColumn: + return PredictionColumn( + column_name=self._column_name, + string_values=self._string_values, + double_values=self._double_values, + boolean_values=self._boolean_values, + date_values=self._date_values, + datetime_values=self._datetime_values, + ) + + +class PredictionRequest(BaseModel): + version: Literal["v1"] = Field( + default="v1", description="API version, must be 'v1'" + ) + prediction_type: PredictionType = Field(description="Prediction type") + model_api_name: str = Field( + min_length=1, description="API name of the model to use" + ) + prediction_columns: list[PredictionColumn] = Field( + min_length=1, description="List of prediction columns" + ) + settings: Optional[Dict[str, Any]] = Field( + default=None, description="Settings for the prediction request" + ) + + +class PredictionRequestBuilder: + def __init__(self) -> None: + self._prediction_type: Optional[PredictionType] = None + self._model_api_name: Optional[str] = None + self._prediction_columns: list[PredictionColumn] = [] + self._settings: Optional[Dict[str, Any]] = None + + def set_prediction_type( + self, prediction_type: PredictionType + ) -> "PredictionRequestBuilder": + self._prediction_type = prediction_type + return self + + def set_model(self, model_api_name: str) -> "PredictionRequestBuilder": + self._model_api_name = model_api_name + return self + + def 
set_prediction_columns( + self, prediction_columns: list[PredictionColumn] + ) -> "PredictionRequestBuilder": + self._prediction_columns = prediction_columns + return self + + def set_settings(self, settings: Dict[str, Any]): + self._settings = settings + return self + + def build(self) -> PredictionRequest: + return PredictionRequest( + prediction_type=self._prediction_type, + model_api_name=self._model_api_name, + prediction_columns=self._prediction_columns, + settings=self._settings, + ) + + +class PredictionResponse(BaseModel): + version: Literal["v1"] = Field(default="v1", description="API version") + prediction_type: PredictionType = Field(description="Prediction type") + status_code: int = Field(description="HTTP status code") + data: Optional[Dict[str, Any]] = Field(default=None, description="Response data") + + @property + def is_success(self) -> bool: + return self.status_code == 200 diff --git a/src/datacustomcode/einstein_predictions_config.py b/src/datacustomcode/einstein_predictions_config.py new file mode 100644 index 0000000..4e83164 --- /dev/null +++ b/src/datacustomcode/einstein_predictions_config.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import ( + ClassVar, + Generic, + Type, + TypeVar, + Union, + cast, +) + +from datacustomcode.common_config import ( + BaseConfig, + BaseObjectConfig, + default_config_file, +) +from datacustomcode.einstein_predictions.base import EinsteinPredictions + +_E = TypeVar("_E", bound=EinsteinPredictions) + + +class EinsteinPredictionsObjectConfig(BaseObjectConfig, Generic[_E]): + type_base: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions # type: ignore[type-abstract] + + def to_object(self) -> _E: + type_ = self.type_base.subclass_from_config_name(self.type_config_name) + return cast(_E, type_(**self.options)) + + +class EinsteinPredictionsConfig(BaseConfig): + einstein_predictions_config: Union[ + EinsteinPredictionsObjectConfig[EinsteinPredictions], None + ] = None + + def update(self, other: "EinsteinPredictionsConfig") -> "EinsteinPredictionsConfig": + def merge( + config_a: Union[EinsteinPredictionsObjectConfig, None], + config_b: Union[EinsteinPredictionsObjectConfig, None], + ) -> Union[EinsteinPredictionsObjectConfig, None]: + if config_a is not None and config_a.force: + return config_a + if config_b: + return config_b + return config_a + + self.einstein_predictions_config = merge( + self.einstein_predictions_config, other.einstein_predictions_config + ) + return self + + +# Global Einstein Predictions config instance +einstein_predictions_config = EinsteinPredictionsConfig() +einstein_predictions_config.load(default_config_file()) diff --git a/src/datacustomcode/function/runtime.py b/src/datacustomcode/function/runtime.py index 09c0433..ffda532 100644 --- a/src/datacustomcode/function/runtime.py +++ b/src/datacustomcode/function/runtime.py @@ -17,6 +17,8 @@ import threading from typing import Optional +from datacustomcode.einstein_predictions.base import EinsteinPredictions +from datacustomcode.einstein_predictions_config import einstein_predictions_config from datacustomcode.file.path.default import DefaultFindFilePath from 
datacustomcode.function.base import BaseRuntime from datacustomcode.llm_gateway.default import DefaultLLMGateway @@ -65,6 +67,7 @@ def __init__(self) -> None: # Initialize resources self._llm_gateway = DefaultLLMGateway() self._file = DefaultFindFilePath() + self._einstein_predictions: Optional[EinsteinPredictions] = None @property def llm_gateway(self) -> DefaultLLMGateway: @@ -75,3 +78,16 @@ def llm_gateway(self) -> DefaultLLMGateway: def file(self) -> DefaultFindFilePath: """Access file operations.""" return self._file + + @property + def einstein_predictions(self) -> EinsteinPredictions: + if self._einstein_predictions is None: + if einstein_predictions_config.einstein_predictions_config is None: + raise RuntimeError( + "Einstein Predictions is not configured. Add " + "'einstein_predictions_config' section to config.yaml" + ) + self._einstein_predictions = ( + einstein_predictions_config.einstein_predictions_config.to_object() + ) + return self._einstein_predictions diff --git a/tests/test_config.py b/tests/test_config.py index 463dc95..0adbbfa 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,8 +11,8 @@ AccessLayerObjectConfig, ClientConfig, SparkConfig, - _defaults, config, + default_config_file, ) from datacustomcode.io.base import BaseDataAccessLayer from datacustomcode.io.reader.base import BaseDataCloudReader @@ -298,8 +298,8 @@ def test_load_config_from_file(self): os.unlink(temp_path) def test_defaults(self): - # Just verify that _defaults function exists and returns a string path - result = _defaults() + # Just verify that default_config_file function exists and returns a string path + result = default_config_file() assert isinstance(result, str) assert result.endswith("config.yaml") diff --git a/tests/test_einstein_predictions.py b/tests/test_einstein_predictions.py new file mode 100644 index 0000000..e760abc --- /dev/null +++ b/tests/test_einstein_predictions.py @@ -0,0 +1,181 @@ +from pydantic import ValidationError +import pytest 
+ +from datacustomcode.einstein_predictions.types import ( + PredictionColumBuilder, + PredictionColumn, + PredictionRequest, + PredictionRequestBuilder, + PredictionResponse, + PredictionType, +) + + +class TestPredictionColumnValidation: + def test_string_values_only(self): + column = PredictionColumn(column_name="test_col", string_values=["a", "b", "c"]) + assert column.column_name == "test_col" + assert column.string_values == ["a", "b", "c"] + assert column.double_values is None + assert column.boolean_values is None + assert column.date_values is None + assert column.datetime_values is None + + def test_double_values_only(self): + column = PredictionColumn(column_name="test_col", double_values=[1.0, 2.5, 3.7]) + assert column.double_values == [1.0, 2.5, 3.7] + assert column.string_values is None + assert column.boolean_values is None + assert column.date_values is None + assert column.datetime_values is None + + def test_boolean_values_only(self): + column = PredictionColumn( + column_name="test_col", boolean_values=[True, False, True] + ) + assert column.boolean_values == [True, False, True] + assert column.string_values is None + assert column.double_values is None + assert column.date_values is None + assert column.datetime_values is None + + def test_date_values_only(self): + column = PredictionColumn( + column_name="test_col", date_values=["2024-01-01", "2024-01-02"] + ) + assert column.date_values == ["2024-01-01", "2024-01-02"] + assert column.string_values is None + assert column.double_values is None + assert column.boolean_values is None + assert column.datetime_values is None + + def test_datetime_values_only(self): + column = PredictionColumn( + column_name="test_col", + datetime_values=["2024-01-01T12:00:00", "2024-01-02T13:00:00"], + ) + assert column.datetime_values == ["2024-01-01T12:00:00", "2024-01-02T13:00:00"] + assert column.string_values is None + assert column.double_values is None + assert column.boolean_values is None + assert 
column.date_values is None + + def test_no_column_name_raises_error(self): + with pytest.raises(ValidationError) as exc_info: + PredictionColumn( + column_name="", string_values=["a", "b"], double_values=[1.0, 2.0] + ) + + assert str(exc_info.value) is not None + + def test_no_values_raises_error(self): + with pytest.raises(ValidationError) as exc_info: + PredictionColumn(column_name="test_col") + + assert str(exc_info.value) is not None + + def test_string_and_double_raises_error(self): + with pytest.raises(ValidationError) as exc_info: + PredictionColumn( + column_name="test_col", + string_values=["a", "b"], + double_values=[1.0, 2.0], + ) + + assert str(exc_info.value) is not None + + def test_empty_values_raises_error(self): + with pytest.raises(ValidationError) as exc_info: + PredictionColumn(column_name="test_col", string_values=[]) + + assert str(exc_info.value) is not None + + +class TestPredictionColumnBuilder: + def test_builder_with_string_values(self): + column = ( + PredictionColumBuilder() + .set_column_name("test_col") + .set_string_values(["a", "b"]) + .build() + ) + + assert column.column_name == "test_col" + assert column.string_values == ["a", "b"] + + +class TestPredictionRequest: + def test_request_with_multiple_columns(self): + request = PredictionRequest( + prediction_type=PredictionType.CLASSIFICATION, + model_api_name="classifier", + prediction_columns=[ + PredictionColumn(column_name="col1", string_values=["a"]), + PredictionColumn(column_name="col2", double_values=[1.0]), + PredictionColumn(column_name="col3", boolean_values=[True]), + ], + ) + + assert len(request.prediction_columns) == 3 + + def test_request_requires_model_api_name(self): + with pytest.raises(ValidationError): + PredictionRequest( + prediction_type=PredictionType.REGRESSION, + model_api_name="", + prediction_columns=[ + PredictionColumn(column_name="col1", double_values=[1.0]) + ], + ) + + def test_request_requires_prediction_columns(self): + with 
pytest.raises(ValidationError): + PredictionRequest( + prediction_type=PredictionType.REGRESSION, + model_api_name="model", + prediction_columns=[], + ) + + +class TestPredictionRequestBuilder: + def test_builder_creates_valid_request(self): + request = ( + PredictionRequestBuilder() + .set_prediction_type(PredictionType.CLUSTERING) + .set_model("cluster_model") + .set_prediction_columns( + [PredictionColumn(column_name="test_col", double_values=[1.0])] + ) + .set_settings({"maxTopContributors": 20}) + .build() + ) + + assert request.prediction_type == PredictionType.CLUSTERING + assert request.model_api_name == "cluster_model" + assert len(request.prediction_columns) == 1 + assert request.settings == {"maxTopContributors": 20} + + +class TestPredictionResponse: + def test_successful_response(self): + response = PredictionResponse( + version="v1", + prediction_type=PredictionType.REGRESSION, + status_code=200, + data={"results": [{"prediction": {"value": 42.5}}]}, + ) + + assert response.is_success + assert response.status_code == 200 + assert response.data is not None + + def test_failed_response(self): + response = PredictionResponse( + version="v1", + prediction_type=PredictionType.REGRESSION, + status_code=500, + data={"error": "Internal server error"}, + ) + + assert not response.is_success + assert response.status_code == 500 diff --git a/tests/test_einstein_predictions_config_update.py b/tests/test_einstein_predictions_config_update.py new file mode 100644 index 0000000..12bdc62 --- /dev/null +++ b/tests/test_einstein_predictions_config_update.py @@ -0,0 +1,99 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile + +import yaml + +from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions +from datacustomcode.einstein_predictions_config import ( + EinsteinPredictionsConfig, + EinsteinPredictionsObjectConfig, +) + + +class TestEinsteinPredictionsConfigUpdate: + def test_update_replaces_config_without_force(self): + config1 = EinsteinPredictionsConfig( + einstein_predictions_config=EinsteinPredictionsObjectConfig( + type_config_name="OldImplementation", options={"old": True} + ) + ) + + config2 = EinsteinPredictionsConfig( + einstein_predictions_config=EinsteinPredictionsObjectConfig( + type_config_name="NewImplementation", options={"new": True} + ) + ) + + config1.update(config2) + + assert ( + config1.einstein_predictions_config.type_config_name == "NewImplementation" + ) + assert config1.einstein_predictions_config.options == {"new": True} + + def test_update_respects_force_flag(self): + config1 = EinsteinPredictionsConfig( + einstein_predictions_config=EinsteinPredictionsObjectConfig( + type_config_name="ForcedImplementation", + options={"forced": True}, + force=True, + ) + ) + + config2 = EinsteinPredictionsConfig( + einstein_predictions_config=EinsteinPredictionsObjectConfig( + type_config_name="NewImplementation", options={"new": True} + ) + ) + + config1.update(config2) + + assert ( + config1.einstein_predictions_config.type_config_name + == "ForcedImplementation" + ) + assert config1.einstein_predictions_config.options == {"forced": True} + assert config1.einstein_predictions_config.force is True + + 
+class TestEinsteinPredictionsConfigLoad: + def test_load_from_yaml_file(self): + config_data = { + "einstein_predictions_config": { + "type_config_name": "DefaultEinsteinPredictions" + } + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml.dump(config_data, f) + temp_file = f.name + + try: + config = EinsteinPredictionsConfig() + config.load(temp_file) + + assert config.einstein_predictions_config is not None + assert ( + config.einstein_predictions_config.type_config_name + == "DefaultEinsteinPredictions" + ) + einstein_predictions = config.einstein_predictions_config.to_object() + assert einstein_predictions is not None + assert isinstance(einstein_predictions, DefaultEinsteinPredictions) + finally: + os.unlink(temp_file) diff --git a/tests/test_runtime.py b/tests/test_runtime.py index e888d11..a8e1879 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -1,9 +1,8 @@ -from __future__ import annotations - import threading import pytest +from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions from datacustomcode.file.path.default import DefaultFindFilePath from datacustomcode.function.runtime import Runtime from datacustomcode.llm_gateway.default import DefaultLLMGateway @@ -72,6 +71,8 @@ def test_runtime_has_llm_gateway(self): """Test Runtime has llm_gateway property.""" assert hasattr(self.runtime, "llm_gateway") assert isinstance(self.runtime.llm_gateway, DefaultLLMGateway) + assert hasattr(self.runtime, "einstein_predictions") + assert isinstance(self.runtime.einstein_predictions, DefaultEinsteinPredictions) def test_runtime_has_file(self): """Test Runtime has file property.""" diff --git a/tests/test_runtime_einstein_predictions.py b/tests/test_runtime_einstein_predictions.py new file mode 100644 index 0000000..cd05d03 --- /dev/null +++ b/tests/test_runtime_einstein_predictions.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025, Salesforce, Inc. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datacustomcode.einstein_predictions.base import EinsteinPredictions +from datacustomcode.einstein_predictions.types import ( + PredictionColumn, + PredictionRequest, + PredictionResponse, + PredictionType, +) +from datacustomcode.einstein_predictions_config import EinsteinPredictionsObjectConfig + + +class TestCustomEinsteinPredictionsImplementation: + """Test that other implementations are supported""" + + def test_custom_implementation_is_discoverable(self): + class CustomEinsteinPredictions(EinsteinPredictions): + CONFIG_NAME = "CustomEinsteinPredictions" + + def __init__(self, custom_param: str = "default", **kwargs): + super().__init__(**kwargs) + self.custom_param = custom_param + + def predict(self, request: PredictionRequest) -> PredictionResponse: + return PredictionResponse( + version="v1", + prediction_type=request.prediction_type, + status_code=200, + data={"results": [{"predictedValue": 1}]}, + ) + + available_names = EinsteinPredictions.available_config_names() + assert "CustomEinsteinPredictions" in available_names + + cls = EinsteinPredictions.subclass_from_config_name("CustomEinsteinPredictions") + assert cls == CustomEinsteinPredictions + + # Verify we can create via config + ep_config = EinsteinPredictionsObjectConfig( + type_config_name="CustomEinsteinPredictions", + options={"custom_param": "my_value"}, + ) + instance = ep_config.to_object() + assert 
isinstance(instance, CustomEinsteinPredictions) + assert instance.custom_param == "my_value" + + request = PredictionRequest( + prediction_type=PredictionType.REGRESSION, + model_api_name="test", + prediction_columns=[ + PredictionColumn(column_name="col1", double_values=[1.0]) + ], + ) + response = instance.predict(request) + assert response.is_success is True + assert response.data["results"] == [{"predictedValue": 1}]