Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions src/datacustomcode/common_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
import os
from typing import Any

from pydantic import (
BaseModel,
ConfigDict,
Field,
)
import yaml

# File name of the default configuration shipped alongside this module.
DEFAULT_CONFIG_NAME = "config.yaml"


def default_config_file() -> str:
    """Return the path of the default config file next to this module.

    Returns:
        The directory containing this module joined with
        ``DEFAULT_CONFIG_NAME``.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, DEFAULT_CONFIG_NAME)


class ForceableConfig(BaseModel):
    """Base model for config sections that carry a ``force`` flag."""

    force: bool = Field(
        description=(
            "If True, this takes precedence over parameters passed to the "
            "initializer of the client"
        ),
        default=False,
    )


class BaseObjectConfig(ForceableConfig):
    """Config section naming an object type and the options for building it.

    Defaults are validated and unknown keys are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid", validate_default=True)

    type_config_name: str = Field(
        description="The config name of the object to create",
    )
    options: dict[str, Any] = Field(
        description="Options passed to the constructor.",
        default_factory=dict,
    )


class BaseConfig(ABC, BaseModel):
    """Abstract base for config models that can be loaded from a YAML file.

    Subclasses implement :meth:`update` to define how another config is
    merged into this one; :meth:`load` relies on it.
    """

    @abstractmethod
    def update(self, other: Any) -> "BaseConfig":
        """Merge ``other`` into this config.

        Args:
            other: The config to merge in. :meth:`load` passes an instance
                of this config's own class.

        Returns:
            A config object; the merge semantics are defined by the subclass.
        """
        ...

    def load(self, config_path: str) -> "BaseConfig":
        """Load configuration from a YAML file and merge with existing config.

        Args:
            config_path: Path to the YAML config file.

        Returns:
            Self, after :meth:`update` has been called with the loaded config.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
            pydantic.ValidationError: If the content does not match the model.
        """
        # Pin the encoding so parsing does not depend on the platform locale.
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = yaml.safe_load(f)

        # An empty or comments-only document parses to None; treat it as an
        # empty mapping so it validates like a config with no keys set.
        if config_data is None:
            config_data = {}

        loaded_config = self.__class__.model_validate(config_data)
        self.update(loaded_config)
        return self
77 changes: 13 additions & 64 deletions src/datacustomcode/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.
from __future__ import annotations

import os
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -26,12 +25,14 @@
cast,
)

from pydantic import (
BaseModel,
ConfigDict,
Field,
from pydantic import Field

from datacustomcode.common_config import (
BaseConfig,
BaseObjectConfig,
ForceableConfig,
default_config_file,
)
import yaml

# This lets all readers and writers to be findable via config
from datacustomcode.io import * # noqa: F403
Expand All @@ -42,36 +43,15 @@
from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH002
from datacustomcode.spark.base import BaseSparkSessionProvider

DEFAULT_CONFIG_NAME = "config.yaml"


if TYPE_CHECKING:
from pyspark.sql import SparkSession


class ForceableConfig(BaseModel):
force: bool = Field(
default=False,
description="If True, this takes precedence over parameters passed to the "
"initializer of the client.",
)


_T = TypeVar("_T", bound="BaseDataAccessLayer")


class AccessLayerObjectConfig(ForceableConfig, Generic[_T]):
model_config = ConfigDict(validate_default=True, extra="forbid")
class AccessLayerObjectConfig(BaseObjectConfig, Generic[_T]):
type_base: ClassVar[Type[BaseDataAccessLayer]] = BaseDataAccessLayer
type_config_name: str = Field(
description="The config name of the object to create. "
"For metrics, this would might be 'ipmnormal'. For custom classes, you can "
"assign a name to a class variable `CONFIG_NAME` and reference it here.",
)
options: dict[str, Any] = Field(
default_factory=dict,
description="Options passed to the constructor.",
)

def to_object(self, spark: SparkSession) -> _T:
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
Expand All @@ -97,35 +77,25 @@ class SparkConfig(ForceableConfig):
_PX = TypeVar("_PX", bound=BaseProxyAccessLayer)


class ProxyAccessLayerObjectConfig(ForceableConfig, Generic[_PX]):
class ProxyAccessLayerObjectConfig(BaseObjectConfig, Generic[_PX]):
"""Config for proxy clients that take no constructor args (e.g. no spark)."""

model_config = ConfigDict(validate_default=True, extra="forbid")
type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer
type_config_name: str = Field(
description="CONFIG_NAME of the proxy client (e.g. 'LocalProxyClient').",
)
options: dict[str, Any] = Field(default_factory=dict)

def to_object(self) -> _PX:
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
return cast(_PX, type_(**self.options))


class SparkProviderConfig(ForceableConfig, Generic[_P]):
model_config = ConfigDict(validate_default=True, extra="forbid")
class SparkProviderConfig(BaseObjectConfig, Generic[_P]):
type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider
type_config_name: str = Field(
description="CONFIG_NAME of the Spark session provider."
)
options: dict[str, Any] = Field(default_factory=dict)

def to_object(self) -> _P:
type_ = self.type_base.subclass_from_config_name(self.type_config_name)
return cast(_P, type_(**self.options))


class ClientConfig(BaseModel):
class ClientConfig(BaseConfig):
reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None
Expand Down Expand Up @@ -163,31 +133,10 @@ def merge(
)
return self

def load(self, config_path: str) -> ClientConfig:
"""Load a config from a file and update this config with it.

Args:
config_path: The path to the config file

Returns:
Self, with updated values from the loaded config.
"""
with open(config_path, "r") as f:
config_data = yaml.safe_load(f)
loaded_config = ClientConfig.model_validate(config_data)

return self.update(loaded_config)


config = ClientConfig()
"""Global config object.

This is the object that makes config accessible globally and globally mutable.
"""


def _defaults() -> str:
return os.path.join(os.path.dirname(__file__), DEFAULT_CONFIG_NAME)


config.load(_defaults())
config = ClientConfig()
config.load(default_config_file())
4 changes: 4 additions & 0 deletions src/datacustomcode/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ proxy_config:
type_config_name: LocalProxyClientProvider
options:
credentials_profile: default

einstein_predictions_config:
type_config_name: DefaultEinsteinPredictions
options: {}
22 changes: 22 additions & 0 deletions src/datacustomcode/einstein_predictions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datacustomcode.einstein_predictions.base import EinsteinPredictions
from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions

__all__ = [
"EinsteinPredictions",
"DefaultEinsteinPredictions",
]
32 changes: 32 additions & 0 deletions src/datacustomcode/einstein_predictions/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

from datacustomcode.einstein_predictions.types import (
PredictionRequest,
PredictionResponse,
)
from datacustomcode.mixin import UserExtendableNamedConfigMixin


class EinsteinPredictions(ABC, UserExtendableNamedConfigMixin):
    """Abstract interface for prediction backends selectable by config name.

    Concrete subclasses set ``CONFIG_NAME`` and implement :meth:`predict`.
    """

    # Name under which a concrete implementation is referenced from config.
    CONFIG_NAME: str

    def __init__(self, **kwargs):
        """Accept arbitrary keyword options; the base class ignores them."""

    @abstractmethod
    def predict(self, request: PredictionRequest) -> PredictionResponse:
        """Produce a prediction response for ``request``."""
        ...
35 changes: 35 additions & 0 deletions src/datacustomcode/einstein_predictions/impl/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2025, Salesforce, Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datacustomcode.einstein_predictions.base import EinsteinPredictions
from datacustomcode.einstein_predictions.types import (
PredictionRequest,
PredictionResponse,
)


class DefaultEinsteinPredictions(EinsteinPredictions):
    """Default implementation that answers every request with a fixed payload."""

    CONFIG_NAME = "DefaultEinsteinPredictions"

    def __init__(self, **kwargs):
        """Forward all keyword options unchanged to the base initializer."""
        super().__init__(**kwargs)

    def predict(self, request: PredictionRequest) -> PredictionResponse:
        """Return a canned successful response for ``request``.

        Only ``request.prediction_type`` is read; it is echoed back in the
        response. Version, status code and result data are hard-coded.
        """
        requested_type = request.prediction_type
        canned_results = [{"prediction": {"predictedValue": "1"}}]
        return PredictionResponse(
            version="v1",
            prediction_type=requested_type,
            status_code=200,
            data={"results": canned_results},
        )
Loading
Loading