From 44f1f7257b4815aba79d07011257910505409e3e Mon Sep 17 00:00:00 2001 From: Alex Tulikumwenayo Date: Wed, 22 Apr 2026 11:01:59 -0400 Subject: [PATCH 1/4] add support for einstein predict --- .../einstein_predictions/base.py | 22 ++ .../einstein_predictions/impl/default.py | 27 +++ .../einstein_predictions/types.py | 156 +++++++++++++++ tests/test_einstein_predictions.py | 189 ++++++++++++++++++ 4 files changed, 394 insertions(+) create mode 100644 src/datacustomcode/einstein_predictions/base.py create mode 100644 src/datacustomcode/einstein_predictions/impl/default.py create mode 100644 src/datacustomcode/einstein_predictions/types.py create mode 100644 tests/test_einstein_predictions.py diff --git a/src/datacustomcode/einstein_predictions/base.py b/src/datacustomcode/einstein_predictions/base.py new file mode 100644 index 0000000..c58d71d --- /dev/null +++ b/src/datacustomcode/einstein_predictions/base.py @@ -0,0 +1,22 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + +from datacustomcode.einstein_predictions.types import (PredictionRequest, PredictionResponse) + +class EinsteinPredictions(ABC): + @abstractmethod + def predict(self, request: PredictionRequest) -> PredictionResponse: ... 
diff --git a/src/datacustomcode/einstein_predictions/impl/default.py b/src/datacustomcode/einstein_predictions/impl/default.py new file mode 100644 index 0000000..5b4c6fb --- /dev/null +++ b/src/datacustomcode/einstein_predictions/impl/default.py @@ -0,0 +1,27 @@ + +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datacustomcode.einstein_predictions.types import ( + PredictionRequest, + PredictionResponse +) + +class DefaultEinsteinPredictions: + def __init__(self, base_url: str, access_token: str) -> None: + pass + + def predict(self, request: PredictionRequest) -> PredictionResponse: + pass \ No newline at end of file diff --git a/src/datacustomcode/einstein_predictions/types.py b/src/datacustomcode/einstein_predictions/types.py new file mode 100644 index 0000000..5faaee2 --- /dev/null +++ b/src/datacustomcode/einstein_predictions/types.py @@ -0,0 +1,156 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum, unique +from pydantic import ( + BaseModel, + Field, + model_validator, +) + +from typing import ( + Literal, + Optional, + Dict, + Any +) + +@unique +class PredictionType(Enum): + REGRESSION = 1 + CLUSTERING = 2 + CLASSIFICATION = 3 + MULTI_OUTCOME = 4 + BINARY_CLASSIFICATION = 5 + +class PredictionColumn(BaseModel): + column_name: str = Field(min_length=1, description="Column name") + string_values: Optional[list[str]] = Field(default = None, min_length=1, description="Column string values") + double_values: Optional[list[float]] = Field(default = None, min_length=1, description="Column double values") + boolean_values: Optional[list[bool]] = Field(default = None, min_length=1, description="Column boolean values") + date_values: Optional[list[str]] = Field(default = None, min_length=1, description="Column date values") + datetime_values: Optional[list[str]] = Field(default = None, min_length=1, description="Column datetime values") + + @model_validator(mode='after') + def validate_exactly_one_value_type(self): + set_count = sum([ + self.string_values is not None, + self.double_values is not None, + self.boolean_values is not None, + self.date_values is not None, + self.datetime_values is not None + ]) + + if set_count != 1: + raise ValueError("Exactly one value type must be set") + + return self + +class PredictionColumBuilder: + def __init__(self) -> None: + self._column_name: str = None + self._string_values: list[str] = None + self._double_values: list[float] = None + self._boolean_values: list[bool] = None + self._date_values: list[str] = None + self._datetime_values: list[str] = None + + def set_column_name(self, column_name: str) -> "PredictionColumBuilder": + self._column_name = column_name + return self + + def set_string_values(self, string_values: list[str]) -> "PredictionColumBuilder": + self._string_values = string_values 
+ return self + + def set_double_values(self, double_values: list[float]) -> "PredictionColumBuilder": + self._double_values = double_values + return self + + def set_boolean_values(self, boolean_values: list[bool]) -> "PredictionColumBuilder": + self._boolean_values = boolean_values + return self + + def set_date_values(self, date_values: list[str]) -> "PredictionColumBuilder": + self._date_values = date_values + return self + + def set_datetime_values(self, datetime_values: list[str]) -> "PredictionColumBuilder": + self._datetime_values = datetime_values + return self + + def build(self) -> PredictionColumn: + return PredictionColumn( + column_name = self._column_name, + string_values = self._string_values, + double_values = self._double_values, + boolean_values = self._boolean_values, + date_values = self._date_values, + datetime_values = self._datetime_values + ) + +class PredictionRequest(BaseModel): + version: Literal["v1"] = Field( + default="v1", description="API version, must be 'v1'" + ) + prediction_type: PredictionType = Field(description="Prediction type") + model_api_name: str = Field(min_length=1, description="API name of the model to use") + prediction_columns: list[PredictionColumn] = Field(min_length=1, description="List of prediction columns") + settings: Optional[Dict[str, Any]] = Field(default=None, description="Settings for the prediction request") + +class PredictionRequestBuilder: + def __init__(self) -> None: + self._prediction_type: PredictionType = None + self._model_api_name: str = None + self._prediction_columns: list[PredictionColumn] = [] + self._settings: Dict[str, Any] = None + + def set_prediction_type(self, prediction_type: PredictionType) -> "PredictionRequestBuilder": + self._prediction_type = prediction_type + return self + + def set_model(self, model_api_name: str) -> "PredictionRequestBuilder": + self._model_api_name = model_api_name + return self + + def set_prediction_columns( + self, + prediction_columns: 
list[PredictionColumn] + ) -> "PredictionRequestBuilder": + self._prediction_columns = prediction_columns + return self + + def set_settings(self, settings: Dict[str, Any]): + self._settings = settings + return self + + def build(self) -> PredictionRequest: + return PredictionRequest( + prediction_type=self._prediction_type, + model_api_name=self._model_api_name, + prediction_columns=self._prediction_columns, + settings=self._settings + ) + +class PredictionResponse(BaseModel): + version: Literal["v1"] = Field(default="v1", description="API version") + prediction_type: PredictionType = Field(description="Prediction type") + status_code: int = Field(description="HTTP status code") + data: Optional[Dict[str, Any]] = Field(default=None, description="Response data") + + @property + def is_success(self) -> bool: + return self.status_code == 200 + diff --git a/tests/test_einstein_predictions.py b/tests/test_einstein_predictions.py new file mode 100644 index 0000000..359353d --- /dev/null +++ b/tests/test_einstein_predictions.py @@ -0,0 +1,189 @@ +import pytest +from pydantic import ValidationError + +from datacustomcode.einstein_predictions.types import ( + PredictionColumn, + PredictionRequest, + PredictionResponse, + PredictionType, + PredictionColumBuilder, + PredictionRequestBuilder, +) + +class TestPredictionColumnValidation: + def test_string_values_only(self): + column = PredictionColumn( + column_name="test_col", + string_values=["a", "b", "c"] + ) + assert column.column_name == "test_col" + assert column.string_values == ["a", "b", "c"] + assert column.double_values is None + assert column.boolean_values is None + assert column.date_values is None + assert column.datetime_values is None + + def test_double_values_only(self): + column = PredictionColumn( + column_name="test_col", + double_values=[1.0, 2.5, 3.7] + ) + assert column.double_values == [1.0, 2.5, 3.7] + assert column.string_values is None + assert column.boolean_values is None + assert 
column.date_values is None + assert column.datetime_values is None + + def test_boolean_values_only(self): + column = PredictionColumn( + column_name="test_col", + boolean_values=[True, False, True] + ) + assert column.boolean_values == [True, False, True] + assert column.string_values is None + assert column.double_values is None + assert column.date_values is None + assert column.datetime_values is None + + + def test_date_values_only(self): + column = PredictionColumn( + column_name="test_col", + date_values=["2024-01-01", "2024-01-02"] + ) + assert column.date_values == ["2024-01-01", "2024-01-02"] + assert column.string_values is None + assert column.double_values is None + assert column.boolean_values is None + assert column.datetime_values is None + + def test_datetime_values_only(self): + column = PredictionColumn( + column_name="test_col", + datetime_values=["2024-01-01T12:00:00", "2024-01-02T13:00:00"] + ) + assert column.datetime_values == ["2024-01-01T12:00:00", "2024-01-02T13:00:00"] + assert column.string_values is None + assert column.double_values is None + assert column.boolean_values is None + assert column.date_values is None + + def test_no_column_name_raises_error(self): + with pytest.raises(ValidationError) as exc_info: + PredictionColumn( + column_name="", + string_values=["a", "b"], + double_values=[1.0, 2.0] + ) + + assert str(exc_info.value) is not None + + def test_no_values_raises_error(self): + with pytest.raises(ValidationError) as exc_info: + PredictionColumn(column_name="test_col") + + assert str(exc_info.value) is not None + + def test_string_and_double_raises_error(self): + with pytest.raises(ValidationError) as exc_info: + PredictionColumn( + column_name="test_col", + string_values=["a", "b"], + double_values=[1.0, 2.0] + ) + + assert str(exc_info.value) is not None + + def test_empty_values_raises_error(self): + with pytest.raises(ValidationError) as exc_info: + PredictionColumn( + column_name="test_col", + string_values=[] + ) + 
+ assert str(exc_info.value) is not None + +class TestPredictionColumnBuilder: + def test_builder_with_string_values(self): + column = (PredictionColumBuilder() + .set_column_name("test_col") + .set_string_values(["a", "b"]) + .build()) + + assert column.column_name == "test_col" + assert column.string_values == ["a", "b"] + +class TestPredictionRequest: + def test_request_with_multiple_columns(self): + request = PredictionRequest( + prediction_type=PredictionType.CLASSIFICATION, + model_api_name="classifier", + prediction_columns=[ + PredictionColumn(column_name="col1", string_values=["a"]), + PredictionColumn(column_name="col2", double_values=[1.0]), + PredictionColumn(column_name="col3", boolean_values=[True]) + ] + ) + + assert len(request.prediction_columns) == 3 + + def test_request_requires_model_api_name(self): + with pytest.raises(ValidationError): + PredictionRequest( + prediction_type=PredictionType.REGRESSION, + model_api_name="", + prediction_columns=[ + PredictionColumn(column_name="col1", double_values=[1.0]) + ] + ) + + def test_request_requires_prediction_columns(self): + with pytest.raises(ValidationError): + PredictionRequest( + prediction_type=PredictionType.REGRESSION, + model_api_name="model", + prediction_columns=[] + ) + +class TestPredictionRequestBuilder: + def test_builder_creates_valid_request(self): + request = (PredictionRequestBuilder() + .set_prediction_type(PredictionType.CLUSTERING) + .set_model("cluster_model") + .set_prediction_columns([ + PredictionColumn(column_name="test_col", double_values=[1.0]) + ]) + .set_settings({ + 'maxTopContributors': 20 + }) + .build()) + + assert request.prediction_type == PredictionType.CLUSTERING + assert request.model_api_name == "cluster_model" + assert len(request.prediction_columns) == 1 + assert request.settings == { 'maxTopContributors': 20 } + + +class TestPredictionResponse: + def test_successful_response(self): + response = PredictionResponse( + version="v1", + 
prediction_type=PredictionType.REGRESSION, + status_code=200, + data={"results": [{"prediction": {"value": 42.5}}]} + ) + + assert response.is_success + assert response.status_code == 200 + assert response.data is not None + + def test_failed_response(self): + response = PredictionResponse( + version="v1", + prediction_type=PredictionType.REGRESSION, + status_code=500, + data={"error": "Internal server error"} + ) + + assert not response.is_success + assert response.status_code == 500 \ No newline at end of file From 54d81a69a36c3717c66267c3af90989ba650cbed Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Wed, 22 Apr 2026 17:18:15 -0400 Subject: [PATCH 2/4] add support for einstein predict --- src/datacustomcode/common_config.py | 44 +++++++++ src/datacustomcode/config.py | 73 ++------------- src/datacustomcode/config.yaml | 4 + .../einstein_predictions/__init__.py | 22 +++++ .../einstein_predictions/base.py | 8 +- .../einstein_predictions/impl/default.py | 16 +++- .../einstein_predictions_config.py | 57 ++++++++++++ src/datacustomcode/function/runtime.py | 13 +++ tests/test_config.py | 6 +- ...test_einstein_predictions_config_update.py | 91 +++++++++++++++++++ tests/test_runtime.py | 5 +- tests/test_runtime_einstein_predictions.py | 71 +++++++++++++++ 12 files changed, 334 insertions(+), 76 deletions(-) create mode 100644 src/datacustomcode/common_config.py create mode 100644 src/datacustomcode/einstein_predictions/__init__.py create mode 100644 src/datacustomcode/einstein_predictions_config.py create mode 100644 tests/test_einstein_predictions_config_update.py create mode 100644 tests/test_runtime_einstein_predictions.py diff --git a/src/datacustomcode/common_config.py b/src/datacustomcode/common_config.py new file mode 100644 index 0000000..5da15bd --- /dev/null +++ b/src/datacustomcode/common_config.py @@ -0,0 +1,44 @@ +import os +import yaml +from pydantic import ( + BaseModel, + ConfigDict, + Field, +) + +from typing import Any + +DEFAULT_CONFIG_NAME = 
"config.yaml" + + +def default_config_file() -> str: + return os.path.join(os.path.dirname(__file__), DEFAULT_CONFIG_NAME) + + +class ForceableConfig(BaseModel): + force: bool = Field( + default=False, + description="If True, this takes precedence over parameters passed to the initializer of the client", + ) + + +class BaseObjectConfig(ForceableConfig): + model_config = ConfigDict(validate_default=True, extra="forbid") + type_config_name: str = Field( + description="The config name of the object to create", + ) + options: dict[str, Any] = Field( + default_factory=dict, + description="Options passed to the constructor.", + ) + + +class BaseConfig(BaseModel): + def load(self, config_path: str) -> "BaseConfig": + """Load configuration from a YAML file and merge with existing config""" + with open(config_path, "r") as f: + config_data = yaml.safe_load(f) + + loaded_config = self.__class__.model_validate(config_data) + self.update(loaded_config) + return self diff --git a/src/datacustomcode/config.py b/src/datacustomcode/config.py index d8b22fb..acd1fed 100644 --- a/src/datacustomcode/config.py +++ b/src/datacustomcode/config.py @@ -14,7 +14,6 @@ # limitations under the License. 
from __future__ import annotations -import os from typing import ( TYPE_CHECKING, Any, @@ -27,11 +26,8 @@ ) from pydantic import ( - BaseModel, - ConfigDict, Field, ) -import yaml # This lets all readers and writers to be findable via config from datacustomcode.io import * # noqa: F403 @@ -41,38 +37,18 @@ from datacustomcode.proxy.base import BaseProxyAccessLayer from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH002 from datacustomcode.spark.base import BaseSparkSessionProvider - -DEFAULT_CONFIG_NAME = "config.yaml" +from datacustomcode.common_config import ForceableConfig, BaseObjectConfig, BaseConfig, default_config_file if TYPE_CHECKING: from pyspark.sql import SparkSession -class ForceableConfig(BaseModel): - force: bool = Field( - default=False, - description="If True, this takes precedence over parameters passed to the " - "initializer of the client.", - ) - - _T = TypeVar("_T", bound="BaseDataAccessLayer") -class AccessLayerObjectConfig(ForceableConfig, Generic[_T]): - model_config = ConfigDict(validate_default=True, extra="forbid") +class AccessLayerObjectConfig(BaseObjectConfig, Generic[_T]): type_base: ClassVar[Type[BaseDataAccessLayer]] = BaseDataAccessLayer - type_config_name: str = Field( - description="The config name of the object to create. " - "For metrics, this would might be 'ipmnormal'. 
For custom classes, you can " - "assign a name to a class variable `CONFIG_NAME` and reference it here.", - ) - options: dict[str, Any] = Field( - default_factory=dict, - description="Options passed to the constructor.", - ) - def to_object(self, spark: SparkSession) -> _T: type_ = self.type_base.subclass_from_config_name(self.type_config_name) return cast(_T, type_(spark=spark, **self.options)) @@ -97,35 +73,22 @@ class SparkConfig(ForceableConfig): _PX = TypeVar("_PX", bound=BaseProxyAccessLayer) -class ProxyAccessLayerObjectConfig(ForceableConfig, Generic[_PX]): +class ProxyAccessLayerObjectConfig(BaseObjectConfig, Generic[_PX]): """Config for proxy clients that take no constructor args (e.g. no spark).""" - - model_config = ConfigDict(validate_default=True, extra="forbid") type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer - type_config_name: str = Field( - description="CONFIG_NAME of the proxy client (e.g. 'LocalProxyClient').", - ) - options: dict[str, Any] = Field(default_factory=dict) - def to_object(self) -> _PX: type_ = self.type_base.subclass_from_config_name(self.type_config_name) return cast(_PX, type_(**self.options)) -class SparkProviderConfig(ForceableConfig, Generic[_P]): - model_config = ConfigDict(validate_default=True, extra="forbid") +class SparkProviderConfig(BaseObjectConfig, Generic[_P]): type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider - type_config_name: str = Field( - description="CONFIG_NAME of the Spark session provider." 
- ) - options: dict[str, Any] = Field(default_factory=dict) - def to_object(self) -> _P: type_ = self.type_base.subclass_from_config_name(self.type_config_name) return cast(_P, type_(**self.options)) -class ClientConfig(BaseModel): +class ClientConfig(BaseConfig): reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None @@ -163,31 +126,9 @@ def merge( ) return self - def load(self, config_path: str) -> ClientConfig: - """Load a config from a file and update this config with it. - - Args: - config_path: The path to the config file - - Returns: - Self, with updated values from the loaded config. - """ - with open(config_path, "r") as f: - config_data = yaml.safe_load(f) - loaded_config = ClientConfig.model_validate(config_data) - - return self.update(loaded_config) - - -config = ClientConfig() """Global config object. This is the object that makes config accessible globally and globally mutable. """ - - -def _defaults() -> str: - return os.path.join(os.path.dirname(__file__), DEFAULT_CONFIG_NAME) - - -config.load(_defaults()) +config = ClientConfig() +config.load(default_config_file()) diff --git a/src/datacustomcode/config.yaml b/src/datacustomcode/config.yaml index bf21209..d58bc7f 100644 --- a/src/datacustomcode/config.yaml +++ b/src/datacustomcode/config.yaml @@ -23,3 +23,7 @@ proxy_config: type_config_name: LocalProxyClientProvider options: credentials_profile: default + +einstein_predictions_config: + type_config_name: DefaultEinsteinPredictions + options: {} diff --git a/src/datacustomcode/einstein_predictions/__init__.py b/src/datacustomcode/einstein_predictions/__init__.py new file mode 100644 index 0000000..a9d02c9 --- /dev/null +++ b/src/datacustomcode/einstein_predictions/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2025, Salesforce, Inc. 
+# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datacustomcode.einstein_predictions.base import EinsteinPredictions +from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions + +__all__ = [ + "EinsteinPredictions", + "DefaultEinsteinPredictions", +] \ No newline at end of file diff --git a/src/datacustomcode/einstein_predictions/base.py b/src/datacustomcode/einstein_predictions/base.py index c58d71d..f0ccee0 100644 --- a/src/datacustomcode/einstein_predictions/base.py +++ b/src/datacustomcode/einstein_predictions/base.py @@ -16,7 +16,13 @@ from abc import ABC, abstractmethod from datacustomcode.einstein_predictions.types import (PredictionRequest, PredictionResponse) +from datacustomcode.mixin import UserExtendableNamedConfigMixin + +class EinsteinPredictions(ABC, UserExtendableNamedConfigMixin): + CONFIG_NAME: str + + def __init__(self, **kwargs): + pass -class EinsteinPredictions(ABC): @abstractmethod def predict(self, request: PredictionRequest) -> PredictionResponse: ... diff --git a/src/datacustomcode/einstein_predictions/impl/default.py b/src/datacustomcode/einstein_predictions/impl/default.py index 5b4c6fb..da6a8bd 100644 --- a/src/datacustomcode/einstein_predictions/impl/default.py +++ b/src/datacustomcode/einstein_predictions/impl/default.py @@ -14,14 +14,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from datacustomcode.einstein_predictions.base import EinsteinPredictions from datacustomcode.einstein_predictions.types import ( PredictionRequest, PredictionResponse ) -class DefaultEinsteinPredictions: - def __init__(self, base_url: str, access_token: str) -> None: - pass +class DefaultEinsteinPredictions(EinsteinPredictions): + CONFIG_NAME = "DefaultEinsteinPredictions" + + def __init__(self, **kwargs): + super().__init__(**kwargs) def predict(self, request: PredictionRequest) -> PredictionResponse: - pass \ No newline at end of file + return PredictionResponse( + version="v1", + prediction_type=request.prediction_type, + status_code=200, + data={"results": [{"prediction": {"predictedValue": "1"}}]} + ) \ No newline at end of file diff --git a/src/datacustomcode/einstein_predictions_config.py b/src/datacustomcode/einstein_predictions_config.py new file mode 100644 index 0000000..3a713fe --- /dev/null +++ b/src/datacustomcode/einstein_predictions_config.py @@ -0,0 +1,57 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import ( + ClassVar, + Generic, + Type, + TypeVar, + Union, + cast, +) + +from datacustomcode.einstein_predictions.base import EinsteinPredictions +from datacustomcode.common_config import BaseObjectConfig, BaseConfig, default_config_file + +_E = TypeVar("_E", bound=EinsteinPredictions) + +class EinsteinPredictionsObjectConfig(BaseObjectConfig, Generic[_E]): + type_base: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions + def to_object(self) -> _E: + type_ = self.type_base.subclass_from_config_name(self.type_config_name) + return cast(_E, type_(**self.options)) + + +class EinsteinPredictionsConfig(BaseConfig): + einstein_predictions_config: Union[EinsteinPredictionsObjectConfig[EinsteinPredictions], None] = None + def update(self, other: "EinsteinPredictionsConfig") -> "EinsteinPredictionsConfig": + def merge( + config_a: Union[EinsteinPredictionsObjectConfig, None], + config_b: Union[EinsteinPredictionsObjectConfig, None] + ) -> Union[EinsteinPredictionsObjectConfig, None]: + if config_a is not None and config_a.force: + return config_a + if config_b: + return config_b + return config_a + + self.einstein_predictions_config = merge( + self.einstein_predictions_config, other.einstein_predictions_config + ) + return self + +# Global Einstein Predictions config instance +einstein_predictions_config = EinsteinPredictionsConfig() +einstein_predictions_config.load(default_config_file()) diff --git a/src/datacustomcode/function/runtime.py b/src/datacustomcode/function/runtime.py index 09c0433..f1999e0 100644 --- a/src/datacustomcode/function/runtime.py +++ b/src/datacustomcode/function/runtime.py @@ -17,6 +17,8 @@ import threading from typing import Optional +from datacustomcode.einstein_predictions.base import EinsteinPredictions +from datacustomcode.einstein_predictions_config import einstein_predictions_config from datacustomcode.file.path.default import DefaultFindFilePath from datacustomcode.function.base import BaseRuntime from 
datacustomcode.llm_gateway.default import DefaultLLMGateway @@ -65,6 +67,7 @@ def __init__(self) -> None: # Initialize resources self._llm_gateway = DefaultLLMGateway() self._file = DefaultFindFilePath() + self._einstein_predictions: Optional[EinsteinPredictions] = None @property def llm_gateway(self) -> DefaultLLMGateway: @@ -75,3 +78,13 @@ def llm_gateway(self) -> DefaultLLMGateway: def file(self) -> DefaultFindFilePath: """Access file operations.""" return self._file + + @property + def einstein_predictions(self) -> EinsteinPredictions: + if self._einstein_predictions is None: + if einstein_predictions_config.einstein_predictions_config is None: + raise RuntimeError( + "Einstein Predictions is not configured. Add 'einstein_predictions_config' section to config.yaml" + ) + self._einstein_predictions = einstein_predictions_config.einstein_predictions_config.to_object() + return self._einstein_predictions diff --git a/tests/test_config.py b/tests/test_config.py index 463dc95..95482e0 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,7 +11,7 @@ AccessLayerObjectConfig, ClientConfig, SparkConfig, - _defaults, + default_config_file, config, ) from datacustomcode.io.base import BaseDataAccessLayer @@ -298,8 +298,8 @@ def test_load_config_from_file(self): os.unlink(temp_path) def test_defaults(self): - # Just verify that _defaults function exists and returns a string path - result = _defaults() + # Just verify that default_config_file function exists and returns a string path + result = default_config_file() assert isinstance(result, str) assert result.endswith("config.yaml") diff --git a/tests/test_einstein_predictions_config_update.py b/tests/test_einstein_predictions_config_update.py new file mode 100644 index 0000000..82602e9 --- /dev/null +++ b/tests/test_einstein_predictions_config_update.py @@ -0,0 +1,91 @@ +# Copyright (c) 2025, Salesforce, Inc. 
+# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import os +import yaml + +from datacustomcode.einstein_predictions_config import EinsteinPredictionsConfig, EinsteinPredictionsObjectConfig +from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions + + +class TestEinsteinPredictionsConfigUpdate: + def test_update_replaces_config_without_force(self): + config1 = EinsteinPredictionsConfig( + einstein_predictions_config=EinsteinPredictionsObjectConfig( + type_config_name="OldImplementation", + options={"old": True} + ) + ) + + config2 = EinsteinPredictionsConfig( + einstein_predictions_config=EinsteinPredictionsObjectConfig( + type_config_name="NewImplementation", + options={"new": True} + ) + ) + + config1.update(config2) + + assert config1.einstein_predictions_config.type_config_name == "NewImplementation" + assert config1.einstein_predictions_config.options == {"new": True} + + def test_update_respects_force_flag(self): + config1 = EinsteinPredictionsConfig( + einstein_predictions_config=EinsteinPredictionsObjectConfig( + type_config_name="ForcedImplementation", + options={"forced": True}, + force=True + ) + ) + + config2 = EinsteinPredictionsConfig( + einstein_predictions_config=EinsteinPredictionsObjectConfig( + type_config_name="NewImplementation", + options={"new": True} + ) + ) + + config1.update(config2) + + assert config1.einstein_predictions_config.type_config_name == 
"ForcedImplementation" + assert config1.einstein_predictions_config.options == {"forced": True} + assert config1.einstein_predictions_config.force is True + + +class TestEinsteinPredictionsConfigLoad: + def test_load_from_yaml_file(self): + config_data = { + "einstein_predictions_config": { + "type_config_name": "DefaultEinsteinPredictions" + } + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config_data, f) + temp_file = f.name + + try: + config = EinsteinPredictionsConfig() + config.load(temp_file) + + assert config.einstein_predictions_config is not None + assert config.einstein_predictions_config.type_config_name == "DefaultEinsteinPredictions" + einstein_predictions = config.einstein_predictions_config.to_object() + assert einstein_predictions is not None + assert isinstance(einstein_predictions, DefaultEinsteinPredictions) + finally: + os.unlink(temp_file) + diff --git a/tests/test_runtime.py b/tests/test_runtime.py index e888d11..6506588 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import threading import pytest @@ -7,6 +5,7 @@ from datacustomcode.file.path.default import DefaultFindFilePath from datacustomcode.function.runtime import Runtime from datacustomcode.llm_gateway.default import DefaultLLMGateway +from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions class TestRuntimeSingleton: @@ -72,6 +71,8 @@ def test_runtime_has_llm_gateway(self): """Test Runtime has llm_gateway property.""" assert hasattr(self.runtime, "llm_gateway") assert isinstance(self.runtime.llm_gateway, DefaultLLMGateway) + assert hasattr(self.runtime, "einstein_predictions") + assert isinstance(self.runtime.einstein_predictions, DefaultEinsteinPredictions) def test_runtime_has_file(self): """Test Runtime has file property.""" diff --git a/tests/test_runtime_einstein_predictions.py b/tests/test_runtime_einstein_predictions.py new 
file mode 100644 index 0000000..78760e2 --- /dev/null +++ b/tests/test_runtime_einstein_predictions.py @@ -0,0 +1,71 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datacustomcode.einstein_predictions_config import EinsteinPredictionsObjectConfig +from datacustomcode.einstein_predictions.base import EinsteinPredictions +from datacustomcode.einstein_predictions.types import ( + PredictionRequest, + PredictionResponse, + PredictionType, + PredictionColumn, +) +class TestCustomEinsteinPredictionsImplementation: + """Test that other implemenations are supported""" + + def test_custom_implementation_is_discoverable(self): + class CustomEinsteinPredictions(EinsteinPredictions): + CONFIG_NAME = "CustomEinsteinPredictions" + + def __init__(self, custom_param: str = "default", **kwargs): + super().__init__(**kwargs) + self.custom_param = custom_param + + def predict(self, request: PredictionRequest) -> PredictionResponse: + return PredictionResponse( + version="v1", + prediction_type=request.prediction_type, + status_code=200, + data={"results": [{ + "predictedValue": 1 + }]} + ) + + available_names = EinsteinPredictions.available_config_names() + assert "CustomEinsteinPredictions" in available_names + + cls = EinsteinPredictions.subclass_from_config_name("CustomEinsteinPredictions") + assert cls == CustomEinsteinPredictions + + # Verify we can create via config + ep_config = 
EinsteinPredictionsObjectConfig( + type_config_name="CustomEinsteinPredictions", + options={"custom_param": "my_value"} + ) + instance = ep_config.to_object() + assert isinstance(instance, CustomEinsteinPredictions) + assert instance.custom_param == "my_value" + + request = PredictionRequest( + prediction_type=PredictionType.REGRESSION, + model_api_name="test", + prediction_columns=[ + PredictionColumn(column_name="col1", double_values=[1.0]) + ] + ) + response = instance.predict(request) + assert response.is_success is True + assert response.data["results"] == [{ + "predictedValue": 1 + }] From 732a998d59bbcf8980cd4be43fb7310e309c49ed Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Wed, 22 Apr 2026 17:43:10 -0400 Subject: [PATCH 3/4] fix lint errors --- src/datacustomcode/common_config.py | 26 +++- src/datacustomcode/config.py | 16 ++- .../einstein_predictions/__init__.py | 2 +- .../einstein_predictions/base.py | 6 +- .../einstein_predictions/impl/default.py | 14 +-- .../einstein_predictions/types.py | 118 +++++++++++------- .../einstein_predictions_config.py | 18 ++- src/datacustomcode/function/runtime.py | 7 +- tests/test_config.py | 2 +- tests/test_einstein_predictions.py | 86 ++++++------- ...test_einstein_predictions_config_update.py | 38 +++--- tests/test_runtime.py | 2 +- tests/test_runtime_einstein_predictions.py | 20 ++- 13 files changed, 212 insertions(+), 143 deletions(-) diff --git a/src/datacustomcode/common_config.py b/src/datacustomcode/common_config.py index 5da15bd..78941d0 100644 --- a/src/datacustomcode/common_config.py +++ b/src/datacustomcode/common_config.py @@ -1,12 +1,26 @@ +# Copyright (c) 2025, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import os -import yaml +from typing import Any + from pydantic import ( BaseModel, ConfigDict, Field, ) - -from typing import Any +import yaml DEFAULT_CONFIG_NAME = "config.yaml" @@ -18,7 +32,8 @@ def default_config_file() -> str: class ForceableConfig(BaseModel): force: bool = Field( default=False, - description="If True, this takes precedence over parameters passed to the initializer of the client", + description="If True, this takes precedence over parameters passed to the " + "initializer of the client", ) @@ -34,6 +49,9 @@ class BaseObjectConfig(ForceableConfig): class BaseConfig(BaseModel): + def update(self, other: Any) -> "BaseConfig": + raise NotImplementedError("Subclasses must implement update method") + def load(self, config_path: str) -> "BaseConfig": """Load configuration from a YAML file and merge with existing config""" with open(config_path, "r") as f: diff --git a/src/datacustomcode/config.py b/src/datacustomcode/config.py index acd1fed..820c512 100644 --- a/src/datacustomcode/config.py +++ b/src/datacustomcode/config.py @@ -25,8 +25,13 @@ cast, ) -from pydantic import ( - Field, +from pydantic import Field + +from datacustomcode.common_config import ( + BaseConfig, + BaseObjectConfig, + ForceableConfig, + default_config_file, ) # This lets all readers and writers to be findable via config @@ -37,8 +42,6 @@ from datacustomcode.proxy.base import BaseProxyAccessLayer from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH002 from datacustomcode.spark.base import BaseSparkSessionProvider -from datacustomcode.common_config 
import ForceableConfig, BaseObjectConfig, BaseConfig, default_config_file - if TYPE_CHECKING: from pyspark.sql import SparkSession @@ -49,6 +52,7 @@ class AccessLayerObjectConfig(BaseObjectConfig, Generic[_T]): type_base: ClassVar[Type[BaseDataAccessLayer]] = BaseDataAccessLayer + def to_object(self, spark: SparkSession) -> _T: type_ = self.type_base.subclass_from_config_name(self.type_config_name) return cast(_T, type_(spark=spark, **self.options)) @@ -75,7 +79,9 @@ class SparkConfig(ForceableConfig): class ProxyAccessLayerObjectConfig(BaseObjectConfig, Generic[_PX]): """Config for proxy clients that take no constructor args (e.g. no spark).""" + type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer + def to_object(self) -> _PX: type_ = self.type_base.subclass_from_config_name(self.type_config_name) return cast(_PX, type_(**self.options)) @@ -83,6 +89,7 @@ def to_object(self) -> _PX: class SparkProviderConfig(BaseObjectConfig, Generic[_P]): type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider + def to_object(self) -> _P: type_ = self.type_base.subclass_from_config_name(self.type_config_name) return cast(_P, type_(**self.options)) @@ -126,6 +133,7 @@ def merge( ) return self + """Global config object. This is the object that makes config accessible globally and globally mutable. 
diff --git a/src/datacustomcode/einstein_predictions/__init__.py b/src/datacustomcode/einstein_predictions/__init__.py index a9d02c9..4a4b388 100644 --- a/src/datacustomcode/einstein_predictions/__init__.py +++ b/src/datacustomcode/einstein_predictions/__init__.py @@ -19,4 +19,4 @@ __all__ = [ "EinsteinPredictions", "DefaultEinsteinPredictions", -] \ No newline at end of file +] diff --git a/src/datacustomcode/einstein_predictions/base.py b/src/datacustomcode/einstein_predictions/base.py index f0ccee0..f00e2c7 100644 --- a/src/datacustomcode/einstein_predictions/base.py +++ b/src/datacustomcode/einstein_predictions/base.py @@ -15,9 +15,13 @@ from abc import ABC, abstractmethod -from datacustomcode.einstein_predictions.types import (PredictionRequest, PredictionResponse) +from datacustomcode.einstein_predictions.types import ( + PredictionRequest, + PredictionResponse, +) from datacustomcode.mixin import UserExtendableNamedConfigMixin + class EinsteinPredictions(ABC, UserExtendableNamedConfigMixin): CONFIG_NAME: str diff --git a/src/datacustomcode/einstein_predictions/impl/default.py b/src/datacustomcode/einstein_predictions/impl/default.py index da6a8bd..ce8d65a 100644 --- a/src/datacustomcode/einstein_predictions/impl/default.py +++ b/src/datacustomcode/einstein_predictions/impl/default.py @@ -1,4 +1,3 @@ - # Copyright (c) 2025, Salesforce, Inc. 
# SPDX-License-Identifier: Apache-2 # @@ -17,9 +16,10 @@ from datacustomcode.einstein_predictions.base import EinsteinPredictions from datacustomcode.einstein_predictions.types import ( PredictionRequest, - PredictionResponse + PredictionResponse, ) + class DefaultEinsteinPredictions(EinsteinPredictions): CONFIG_NAME = "DefaultEinsteinPredictions" @@ -28,8 +28,8 @@ def __init__(self, **kwargs): def predict(self, request: PredictionRequest) -> PredictionResponse: return PredictionResponse( - version="v1", - prediction_type=request.prediction_type, - status_code=200, - data={"results": [{"prediction": {"predictedValue": "1"}}]} - ) \ No newline at end of file + version="v1", + prediction_type=request.prediction_type, + status_code=200, + data={"results": [{"prediction": {"predictedValue": "1"}}]}, + ) diff --git a/src/datacustomcode/einstein_predictions/types.py b/src/datacustomcode/einstein_predictions/types.py index 5faaee2..92a7bdd 100644 --- a/src/datacustomcode/einstein_predictions/types.py +++ b/src/datacustomcode/einstein_predictions/types.py @@ -14,18 +14,19 @@ # limitations under the License. 
from enum import Enum, unique +from typing import ( + Any, + Dict, + Literal, + Optional, +) + from pydantic import ( BaseModel, Field, model_validator, ) -from typing import ( - Literal, - Optional, - Dict, - Any -) @unique class PredictionType(Enum): @@ -35,37 +36,51 @@ class PredictionType(Enum): MULTI_OUTCOME = 4 BINARY_CLASSIFICATION = 5 + class PredictionColumn(BaseModel): column_name: str = Field(min_length=1, description="Column name") - string_values: Optional[list[str]] = Field(default = None, min_length=1, description="Column string values") - double_values: Optional[list[float]] = Field(default = None, min_length=1, description="Column double values") - boolean_values: Optional[list[bool]] = Field(default = None, min_length=1, description="Column boolean values") - date_values: Optional[list[str]] = Field(default = None, min_length=1, description="Column date values") - datetime_values: Optional[list[str]] = Field(default = None, min_length=1, description="Column datetime values") + string_values: Optional[list[str]] = Field( + default=None, min_length=1, description="Column string values" + ) + double_values: Optional[list[float]] = Field( + default=None, min_length=1, description="Column double values" + ) + boolean_values: Optional[list[bool]] = Field( + default=None, min_length=1, description="Column boolean values" + ) + date_values: Optional[list[str]] = Field( + default=None, min_length=1, description="Column date values" + ) + datetime_values: Optional[list[str]] = Field( + default=None, min_length=1, description="Column datetime values" + ) - @model_validator(mode='after') + @model_validator(mode="after") def validate_exactly_one_value_type(self): - set_count = sum([ - self.string_values is not None, - self.double_values is not None, - self.boolean_values is not None, - self.date_values is not None, - self.datetime_values is not None - ]) + set_count = sum( + [ + self.string_values is not None, + self.double_values is not None, + 
self.boolean_values is not None, + self.date_values is not None, + self.datetime_values is not None, + ] + ) if set_count != 1: raise ValueError("Exactly one value type must be set") return self + class PredictionColumBuilder: def __init__(self) -> None: - self._column_name: str = None - self._string_values: list[str] = None - self._double_values: list[float] = None - self._boolean_values: list[bool] = None - self._date_values: list[str] = None - self._datetime_values: list[str] = None + self._column_name: Optional[str] = None + self._string_values: Optional[list[str]] = None + self._double_values: Optional[list[float]] = None + self._boolean_values: Optional[list[bool]] = None + self._date_values: Optional[list[str]] = None + self._datetime_values: Optional[list[str]] = None def set_column_name(self, column_name: str) -> "PredictionColumBuilder": self._column_name = column_name @@ -79,7 +94,9 @@ def set_double_values(self, double_values: list[float]) -> "PredictionColumBuild self._double_values = double_values return self - def set_boolean_values(self, boolean_values: list[bool]) -> "PredictionColumBuilder": + def set_boolean_values( + self, boolean_values: list[bool] + ) -> "PredictionColumBuilder": self._boolean_values = boolean_values return self @@ -87,37 +104,49 @@ def set_date_values(self, date_values: list[str]) -> "PredictionColumBuilder": self._date_values = date_values return self - def set_datetime_values(self, datetime_values: list[str]) -> "PredictionColumBuilder": + def set_datetime_values( + self, datetime_values: list[str] + ) -> "PredictionColumBuilder": self._datetime_values = datetime_values return self def build(self) -> PredictionColumn: return PredictionColumn( - column_name = self._column_name, - string_values = self._string_values, - double_values = self._double_values, - boolean_values = self._boolean_values, - date_values = self._date_values, - datetime_values = self._datetime_values + column_name=self._column_name, + 
string_values=self._string_values, + double_values=self._double_values, + boolean_values=self._boolean_values, + date_values=self._date_values, + datetime_values=self._datetime_values, ) + class PredictionRequest(BaseModel): version: Literal["v1"] = Field( default="v1", description="API version, must be 'v1'" ) prediction_type: PredictionType = Field(description="Prediction type") - model_api_name: str = Field(min_length=1, description="API name of the model to use") - prediction_columns: list[PredictionColumn] = Field(min_length=1, description="List of prediction columns") - settings: Optional[Dict[str, Any]] = Field(default=None, description="Settings for the prediction request") + model_api_name: str = Field( + min_length=1, description="API name of the model to use" + ) + prediction_columns: list[PredictionColumn] = Field( + min_length=1, description="List of prediction columns" + ) + settings: Optional[Dict[str, Any]] = Field( + default=None, description="Settings for the prediction request" + ) + class PredictionRequestBuilder: def __init__(self) -> None: - self._prediction_type: PredictionType = None - self._model_api_name: str = None + self._prediction_type: Optional[PredictionType] = None + self._model_api_name: Optional[str] = None self._prediction_columns: list[PredictionColumn] = [] - self._settings: Dict[str, Any] = None + self._settings: Optional[Dict[str, Any]] = None - def set_prediction_type(self, prediction_type: PredictionType) -> "PredictionRequestBuilder": + def set_prediction_type( + self, prediction_type: PredictionType + ) -> "PredictionRequestBuilder": self._prediction_type = prediction_type return self @@ -126,8 +155,7 @@ def set_model(self, model_api_name: str) -> "PredictionRequestBuilder": return self def set_prediction_columns( - self, - prediction_columns: list[PredictionColumn] + self, prediction_columns: list[PredictionColumn] ) -> "PredictionRequestBuilder": self._prediction_columns = prediction_columns return self @@ -141,9 
+169,10 @@ def build(self) -> PredictionRequest: prediction_type=self._prediction_type, model_api_name=self._model_api_name, prediction_columns=self._prediction_columns, - settings=self._settings + settings=self._settings, ) - + + class PredictionResponse(BaseModel): version: Literal["v1"] = Field(default="v1", description="API version") prediction_type: PredictionType = Field(description="Prediction type") @@ -153,4 +182,3 @@ class PredictionResponse(BaseModel): @property def is_success(self) -> bool: return self.status_code == 200 - diff --git a/src/datacustomcode/einstein_predictions_config.py b/src/datacustomcode/einstein_predictions_config.py index 3a713fe..4e83164 100644 --- a/src/datacustomcode/einstein_predictions_config.py +++ b/src/datacustomcode/einstein_predictions_config.py @@ -22,24 +22,33 @@ cast, ) +from datacustomcode.common_config import ( + BaseConfig, + BaseObjectConfig, + default_config_file, +) from datacustomcode.einstein_predictions.base import EinsteinPredictions -from datacustomcode.common_config import BaseObjectConfig, BaseConfig, default_config_file _E = TypeVar("_E", bound=EinsteinPredictions) + class EinsteinPredictionsObjectConfig(BaseObjectConfig, Generic[_E]): - type_base: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions + type_base: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions # type: ignore[type-abstract] + def to_object(self) -> _E: type_ = self.type_base.subclass_from_config_name(self.type_config_name) return cast(_E, type_(**self.options)) class EinsteinPredictionsConfig(BaseConfig): - einstein_predictions_config: Union[EinsteinPredictionsObjectConfig[EinsteinPredictions], None] = None + einstein_predictions_config: Union[ + EinsteinPredictionsObjectConfig[EinsteinPredictions], None + ] = None + def update(self, other: "EinsteinPredictionsConfig") -> "EinsteinPredictionsConfig": def merge( config_a: Union[EinsteinPredictionsObjectConfig, None], - config_b: Union[EinsteinPredictionsObjectConfig, None] + 
config_b: Union[EinsteinPredictionsObjectConfig, None], ) -> Union[EinsteinPredictionsObjectConfig, None]: if config_a is not None and config_a.force: return config_a @@ -52,6 +61,7 @@ def merge( ) return self + # Global Einstein Predictions config instance einstein_predictions_config = EinsteinPredictionsConfig() einstein_predictions_config.load(default_config_file()) diff --git a/src/datacustomcode/function/runtime.py b/src/datacustomcode/function/runtime.py index f1999e0..ffda532 100644 --- a/src/datacustomcode/function/runtime.py +++ b/src/datacustomcode/function/runtime.py @@ -84,7 +84,10 @@ def einstein_predictions(self) -> EinsteinPredictions: if self._einstein_predictions is None: if einstein_predictions_config.einstein_predictions_config is None: raise RuntimeError( - "Einstein Predictions is not configured.Add 'einstein_predictions_config' section to config.yaml" + "Einstein Predictions is not configured. Add " + "'einstein_predictions_config' section to config.yaml" ) - self._einstein_predictions = einstein_predictions_config.einstein_predictions_config.to_object() + self._einstein_predictions = ( + einstein_predictions_config.einstein_predictions_config.to_object() + ) return self._einstein_predictions diff --git a/tests/test_config.py b/tests/test_config.py index 95482e0..0adbbfa 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -11,8 +11,8 @@ AccessLayerObjectConfig, ClientConfig, SparkConfig, - default_config_file, config, + default_config_file, ) from datacustomcode.io.base import BaseDataAccessLayer from datacustomcode.io.reader.base import BaseDataCloudReader diff --git a/tests/test_einstein_predictions.py b/tests/test_einstein_predictions.py index 359353d..e760abc 100644 --- a/tests/test_einstein_predictions.py +++ b/tests/test_einstein_predictions.py @@ -1,21 +1,19 @@ -import pytest from pydantic import ValidationError +import pytest from datacustomcode.einstein_predictions.types import ( + PredictionColumBuilder, PredictionColumn, 
PredictionRequest, + PredictionRequestBuilder, PredictionResponse, PredictionType, - PredictionColumBuilder, - PredictionRequestBuilder, ) + class TestPredictionColumnValidation: def test_string_values_only(self): - column = PredictionColumn( - column_name="test_col", - string_values=["a", "b", "c"] - ) + column = PredictionColumn(column_name="test_col", string_values=["a", "b", "c"]) assert column.column_name == "test_col" assert column.string_values == ["a", "b", "c"] assert column.double_values is None @@ -24,10 +22,7 @@ def test_string_values_only(self): assert column.datetime_values is None def test_double_values_only(self): - column = PredictionColumn( - column_name="test_col", - double_values=[1.0, 2.5, 3.7] - ) + column = PredictionColumn(column_name="test_col", double_values=[1.0, 2.5, 3.7]) assert column.double_values == [1.0, 2.5, 3.7] assert column.string_values is None assert column.boolean_values is None @@ -36,8 +31,7 @@ def test_double_values_only(self): def test_boolean_values_only(self): column = PredictionColumn( - column_name="test_col", - boolean_values=[True, False, True] + column_name="test_col", boolean_values=[True, False, True] ) assert column.boolean_values == [True, False, True] assert column.string_values is None @@ -45,11 +39,9 @@ def test_boolean_values_only(self): assert column.date_values is None assert column.datetime_values is None - def test_date_values_only(self): column = PredictionColumn( - column_name="test_col", - date_values=["2024-01-01", "2024-01-02"] + column_name="test_col", date_values=["2024-01-01", "2024-01-02"] ) assert column.date_values == ["2024-01-01", "2024-01-02"] assert column.string_values is None @@ -60,7 +52,7 @@ def test_date_values_only(self): def test_datetime_values_only(self): column = PredictionColumn( column_name="test_col", - datetime_values=["2024-01-01T12:00:00", "2024-01-02T13:00:00"] + datetime_values=["2024-01-01T12:00:00", "2024-01-02T13:00:00"], ) assert column.datetime_values == 
["2024-01-01T12:00:00", "2024-01-02T13:00:00"] assert column.string_values is None @@ -71,9 +63,7 @@ def test_datetime_values_only(self): def test_no_column_name_raises_error(self): with pytest.raises(ValidationError) as exc_info: PredictionColumn( - column_name="", - string_values=["a", "b"], - double_values=[1.0, 2.0] + column_name="", string_values=["a", "b"], double_values=[1.0, 2.0] ) assert str(exc_info.value) is not None @@ -89,30 +79,31 @@ def test_string_and_double_raises_error(self): PredictionColumn( column_name="test_col", string_values=["a", "b"], - double_values=[1.0, 2.0] + double_values=[1.0, 2.0], ) assert str(exc_info.value) is not None def test_empty_values_raises_error(self): with pytest.raises(ValidationError) as exc_info: - PredictionColumn( - column_name="test_col", - string_values=[] - ) + PredictionColumn(column_name="test_col", string_values=[]) assert str(exc_info.value) is not None + class TestPredictionColumnBuilder: def test_builder_with_string_values(self): - column = (PredictionColumBuilder() - .set_column_name("test_col") - .set_string_values(["a", "b"]) - .build()) + column = ( + PredictionColumBuilder() + .set_column_name("test_col") + .set_string_values(["a", "b"]) + .build() + ) assert column.column_name == "test_col" assert column.string_values == ["a", "b"] + class TestPredictionRequest: def test_request_with_multiple_columns(self): request = PredictionRequest( @@ -121,8 +112,8 @@ def test_request_with_multiple_columns(self): prediction_columns=[ PredictionColumn(column_name="col1", string_values=["a"]), PredictionColumn(column_name="col2", double_values=[1.0]), - PredictionColumn(column_name="col3", boolean_values=[True]) - ] + PredictionColumn(column_name="col3", boolean_values=[True]), + ], ) assert len(request.prediction_columns) == 3 @@ -134,7 +125,7 @@ def test_request_requires_model_api_name(self): model_api_name="", prediction_columns=[ PredictionColumn(column_name="col1", double_values=[1.0]) - ] + ], ) def 
test_request_requires_prediction_columns(self): @@ -142,26 +133,27 @@ def test_request_requires_prediction_columns(self): PredictionRequest( prediction_type=PredictionType.REGRESSION, model_api_name="model", - prediction_columns=[] + prediction_columns=[], ) + class TestPredictionRequestBuilder: def test_builder_creates_valid_request(self): - request = (PredictionRequestBuilder() - .set_prediction_type(PredictionType.CLUSTERING) - .set_model("cluster_model") - .set_prediction_columns([ - PredictionColumn(column_name="test_col", double_values=[1.0]) - ]) - .set_settings({ - 'maxTopContributors': 20 - }) - .build()) + request = ( + PredictionRequestBuilder() + .set_prediction_type(PredictionType.CLUSTERING) + .set_model("cluster_model") + .set_prediction_columns( + [PredictionColumn(column_name="test_col", double_values=[1.0])] + ) + .set_settings({"maxTopContributors": 20}) + .build() + ) assert request.prediction_type == PredictionType.CLUSTERING assert request.model_api_name == "cluster_model" assert len(request.prediction_columns) == 1 - assert request.settings == { 'maxTopContributors': 20 } + assert request.settings == {"maxTopContributors": 20} class TestPredictionResponse: @@ -170,7 +162,7 @@ def test_successful_response(self): version="v1", prediction_type=PredictionType.REGRESSION, status_code=200, - data={"results": [{"prediction": {"value": 42.5}}]} + data={"results": [{"prediction": {"value": 42.5}}]}, ) assert response.is_success @@ -182,8 +174,8 @@ def test_failed_response(self): version="v1", prediction_type=PredictionType.REGRESSION, status_code=500, - data={"error": "Internal server error"} + data={"error": "Internal server error"}, ) assert not response.is_success - assert response.status_code == 500 \ No newline at end of file + assert response.status_code == 500 diff --git a/tests/test_einstein_predictions_config_update.py b/tests/test_einstein_predictions_config_update.py index 82602e9..12bdc62 100644 --- 
a/tests/test_einstein_predictions_config_update.py +++ b/tests/test_einstein_predictions_config_update.py @@ -13,33 +13,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tempfile import os +import tempfile + import yaml -from datacustomcode.einstein_predictions_config import EinsteinPredictionsConfig, EinsteinPredictionsObjectConfig from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions +from datacustomcode.einstein_predictions_config import ( + EinsteinPredictionsConfig, + EinsteinPredictionsObjectConfig, +) class TestEinsteinPredictionsConfigUpdate: def test_update_replaces_config_without_force(self): config1 = EinsteinPredictionsConfig( einstein_predictions_config=EinsteinPredictionsObjectConfig( - type_config_name="OldImplementation", - options={"old": True} + type_config_name="OldImplementation", options={"old": True} ) ) config2 = EinsteinPredictionsConfig( einstein_predictions_config=EinsteinPredictionsObjectConfig( - type_config_name="NewImplementation", - options={"new": True} + type_config_name="NewImplementation", options={"new": True} ) ) config1.update(config2) - assert config1.einstein_predictions_config.type_config_name == "NewImplementation" + assert ( + config1.einstein_predictions_config.type_config_name == "NewImplementation" + ) assert config1.einstein_predictions_config.options == {"new": True} def test_update_respects_force_flag(self): @@ -47,20 +51,22 @@ def test_update_respects_force_flag(self): einstein_predictions_config=EinsteinPredictionsObjectConfig( type_config_name="ForcedImplementation", options={"forced": True}, - force=True + force=True, ) ) config2 = EinsteinPredictionsConfig( einstein_predictions_config=EinsteinPredictionsObjectConfig( - type_config_name="NewImplementation", - options={"new": True} + type_config_name="NewImplementation", options={"new": True} ) ) config1.update(config2) - assert 
config1.einstein_predictions_config.type_config_name == "ForcedImplementation" + assert ( + config1.einstein_predictions_config.type_config_name + == "ForcedImplementation" + ) assert config1.einstein_predictions_config.options == {"forced": True} assert config1.einstein_predictions_config.force is True @@ -72,8 +78,8 @@ def test_load_from_yaml_file(self): "type_config_name": "DefaultEinsteinPredictions" } } - - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: yaml.dump(config_data, f) temp_file = f.name @@ -82,10 +88,12 @@ def test_load_from_yaml_file(self): config.load(temp_file) assert config.einstein_predictions_config is not None - assert config.einstein_predictions_config.type_config_name == "DefaultEinsteinPredictions" + assert ( + config.einstein_predictions_config.type_config_name + == "DefaultEinsteinPredictions" + ) einstein_predictions = config.einstein_predictions_config.to_object() assert einstein_predictions is not None assert isinstance(einstein_predictions, DefaultEinsteinPredictions) finally: os.unlink(temp_file) - diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 6506588..a8e1879 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -2,10 +2,10 @@ import pytest +from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions from datacustomcode.file.path.default import DefaultFindFilePath from datacustomcode.function.runtime import Runtime from datacustomcode.llm_gateway.default import DefaultLLMGateway -from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions class TestRuntimeSingleton: diff --git a/tests/test_runtime_einstein_predictions.py b/tests/test_runtime_einstein_predictions.py index 78760e2..cd05d03 100644 --- a/tests/test_runtime_einstein_predictions.py +++ b/tests/test_runtime_einstein_predictions.py @@ -13,16 +13,18 @@ # See the License for 
the specific language governing permissions and # limitations under the License. -from datacustomcode.einstein_predictions_config import EinsteinPredictionsObjectConfig from datacustomcode.einstein_predictions.base import EinsteinPredictions from datacustomcode.einstein_predictions.types import ( + PredictionColumn, PredictionRequest, PredictionResponse, PredictionType, - PredictionColumn, ) +from datacustomcode.einstein_predictions_config import EinsteinPredictionsObjectConfig + + class TestCustomEinsteinPredictionsImplementation: - """Test that other implemenations are supported""" + """Test that other implementations are supported""" def test_custom_implementation_is_discoverable(self): class CustomEinsteinPredictions(EinsteinPredictions): @@ -37,9 +39,7 @@ def predict(self, request: PredictionRequest) -> PredictionResponse: version="v1", prediction_type=request.prediction_type, status_code=200, - data={"results": [{ - "predictedValue": 1 - }]} + data={"results": [{"predictedValue": 1}]}, ) available_names = EinsteinPredictions.available_config_names() @@ -51,7 +51,7 @@ def predict(self, request: PredictionRequest) -> PredictionResponse: # Verify we can create via config ep_config = EinsteinPredictionsObjectConfig( type_config_name="CustomEinsteinPredictions", - options={"custom_param": "my_value"} + options={"custom_param": "my_value"}, ) instance = ep_config.to_object() assert isinstance(instance, CustomEinsteinPredictions) @@ -62,10 +62,8 @@ def predict(self, request: PredictionRequest) -> PredictionResponse: model_api_name="test", prediction_columns=[ PredictionColumn(column_name="col1", double_values=[1.0]) - ] + ], ) response = instance.predict(request) assert response.is_success is True - assert response.data["results"] == [{ - "predictedValue": 1 - }] + assert response.data["results"] == [{"predictedValue": 1}] From ef5fe0cbea4d81d0248ce63434a6f69d1e5a25c1 Mon Sep 17 00:00:00 2001 From: atulikumwenayo Date: Wed, 22 Apr 2026 19:35:37 -0400 Subject: [PATCH 
4/4] make the class appropriately abstract --- src/datacustomcode/common_config.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/datacustomcode/common_config.py b/src/datacustomcode/common_config.py index 78941d0..a3bdbee 100644 --- a/src/datacustomcode/common_config.py +++ b/src/datacustomcode/common_config.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod import os from typing import Any @@ -48,9 +49,9 @@ class BaseObjectConfig(ForceableConfig): ) -class BaseConfig(BaseModel): - def update(self, other: Any) -> "BaseConfig": - raise NotImplementedError("Subclasses must implement update method") +class BaseConfig(ABC, BaseModel): + @abstractmethod + def update(self, other: Any) -> "BaseConfig": ... def load(self, config_path: str) -> "BaseConfig": """Load configuration from a YAML file and merge with existing config"""