Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions packages/bigframes/bigframes/bigquery/_operations/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,87 @@ def score(
return series_list[0]._apply_nary_op(operator, series_list[1:])


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def similarity(
    content1: str | series.Series | pd.Series,
    content2: str | series.Series | pd.Series,
    *,
    endpoint: str | None = None,
    model: str | None = None,
    model_params: Mapping[Any, Any] | None = None,
    connection_id: str | None = None,
) -> series.Series:
    """
    Returns a FLOAT64 value that represents the cosine similarity between the two inputs.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> df = bpd.DataFrame({'word': ['happy', 'sad']})
        >>> bbq.ai.similarity(df['word'], 'glad', endpoint='text-embedding-005') # doctest: +SKIP
        0    0.916601
        1    0.660579

    Args:
        content1 (str | Series):
            A string or series that provides the first value to compare. Both a BigFrames Series or a pandas Series are allowed.
        content2 (str | Series):
            A string or series that provides the second value to compare. Both a BigFrames Series or a pandas Series are allowed.
        endpoint (str, optional):
            Specifies the Vertex AI endpoint to use for the text embedding model.
            If you specify the model name, such as `'text-embedding-005'`, rather than a URL, then BigQuery ML automatically identifies the model and uses the model's full endpoint.
        model (str, optional):
            Specifies a built-in text embedding model. The only supported value is the embeddinggemma-300m model.
            If you specify this parameter, you can't specify the `endpoint`, `model_params`, or `connection_id` parameters.
        model_params (Mapping[Any, Any], optional):
            Provides additional parameters to the model. You can use any of the parameters object fields.
            One of these fields, `outputDimensionality`, lets you specify the number of dimensions to use when generating embeddings.
        connection_id (str, optional):
            Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`.

    Returns:
        bigframes.series.Series: A new series of FLOAT64 values representing the cosine similarity.

    Raises:
        ValueError: If `model` is combined with `endpoint`, `model_params`, or
            `connection_id`, or if neither `model` nor `endpoint` is given.
    """
    # `model` is mutually exclusive with every other model-selection argument.
    if model is not None:
        if any(x is not None for x in [endpoint, model_params, connection_id]):
            raise ValueError(
                "If 'model' is specified, you cannot specify 'endpoint', 'model_params', or 'connection_id'."
            )
    elif endpoint is None:
        raise ValueError("You must specify either 'model' or 'endpoint'.")

    operator = ai_ops.AISimilarity(
        endpoint=endpoint,
        model=model,
        # The op carries model_params as a JSON string. NOTE(review): a falsy
        # mapping (e.g. {}) is treated the same as None here.
        model_params=json.dumps(model_params) if model_params else None,
        connection_id=connection_id,
    )

    # Find a unifying session for the subsequent operations.
    bf_session = None
    if isinstance(content1, series.Series):
        bf_session = content1._session
    elif isinstance(content2, series.Series):
        bf_session = content2._session

    if isinstance(content1, str) and isinstance(content2, str):
        # Both inputs are literals: lift content1 into a one-row series so the
        # binary op has a series operand (bf_session is None → default session).
        content1 = series.Series([content1], session=bf_session)
        return content1._apply_binary_op(content2, operator)
    elif isinstance(content1, str):
        # content2 must be a series
        content2 = convert.to_bf_series(
            content2, default_index=None, session=bf_session
        )
        # Operand order is swapped here; cosine similarity is symmetric.
        return content2._apply_binary_op(content1, operator)
    else:
        # content1 must be a series.
        content1 = convert.to_bf_series(
            content1, default_index=None, session=bf_session
        )
        return content1._apply_binary_op(content2, operator)


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def forecast(
df: dataframe.DataFrame | pd.DataFrame,
Expand Down
2 changes: 2 additions & 0 deletions packages/bigframes/bigframes/bigquery/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
generate_text,
if_,
score,
similarity,
)

__all__ = [
Expand All @@ -82,4 +83,5 @@
"generate_text",
"if_",
"score",
"similarity",
]
Original file line number Diff line number Diff line change
Expand Up @@ -1992,6 +1992,20 @@ def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.StructVal
).to_expr()


@scalar_op_compiler.register_binary_op(ops.AISimilarity, pass_op=True)
def ai_similarity(
    content1: ibis_types.Value, content2: ibis_types.Value, op: ops.AISimilarity
) -> ibis_types.Value:
    """Compile the AISimilarity scalar op into its ibis expression."""
    node = ai_ops.AISimilarity(
        content1,  # type: ignore
        content2,  # type: ignore
        op.endpoint,  # type: ignore
        op.model,  # type: ignore
        op.model_params,  # type: ignore
        op.connection_id,  # type: ignore
    )
    return node.to_expr()


def _construct_prompt(
col_refs: tuple[ibis_types.Value], prompt_context: tuple[str | None]
) -> ibis_types.StructValue:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr

register_nary_op = expression_compiler.expression_compiler.register_nary_op
register_binary_op = expression_compiler.expression_compiler.register_binary_op


@register_nary_op(ops.AIGenerate, pass_op=True)
Expand Down Expand Up @@ -76,6 +77,16 @@ def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression:
return sge.func("AI.SCORE", *args)


@register_binary_op(ops.AISimilarity, pass_op=True)
def _(content1: TypedExpr, content2: TypedExpr, op: ops.AISimilarity) -> sge.Expression:
    """Compile ops.AISimilarity to a BigQuery `AI.SIMILARITY` function call."""
    positional = [
        sge.Kwarg(this="content1", expression=content1.expr),
        sge.Kwarg(this="content2", expression=content2.expr),
    ]
    named = _construct_named_args(op)
    return sge.func("AI.SIMILARITY", *positional, *named)


def _construct_prompt(
exprs: tuple[TypedExpr, ...],
prompt_context: tuple[str | None, ...],
Expand All @@ -94,7 +105,7 @@ def _construct_prompt(
return sge.Kwarg(this=param_name, expression=sge.Tuple(expressions=prompt))


def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
args = []

op_args = asdict(op)
Expand Down
2 changes: 2 additions & 0 deletions packages/bigframes/bigframes/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
AIGenerateInt,
AIIf,
AIScore,
AISimilarity,
)
from bigframes.operations.array_ops import (
ArrayIndexOp,
Expand Down Expand Up @@ -436,6 +437,7 @@
"AIGenerateInt",
"AIIf",
"AIScore",
"AISimilarity",
# Numpy ops mapping
"NUMPY_TO_BINOP",
"NUMPY_TO_OP",
Expand Down
13 changes: 13 additions & 0 deletions packages/bigframes/bigframes/operations/ai_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,16 @@ class AIScore(base_ops.NaryOp):

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
return dtypes.FLOAT_DTYPE


@dataclasses.dataclass(frozen=True)
class AISimilarity(base_ops.BinaryOp):
    """Binary scalar op backing BigQuery's AI.SIMILARITY function.

    The two binary operands are the contents to compare; the fields below
    become named arguments of the generated SQL call.
    """

    name: ClassVar[str] = "ai_similarity"

    # Vertex AI endpoint (or short model name) of the text embedding model.
    endpoint: str | None
    # Built-in text embedding model name; the public wrapper enforces that it
    # is mutually exclusive with the other fields.
    model: str | None
    # Model parameters already serialized to a JSON string by the caller.
    model_params: str | None
    # BigQuery connection id, e.g. "myproject.us.myconnection".
    connection_id: str | None

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        # AI.SIMILARITY always yields FLOAT64, regardless of input types.
        return dtypes.FLOAT_DTYPE
48 changes: 48 additions & 0 deletions packages/bigframes/tests/system/small/bigquery/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,5 +370,53 @@ def test_forecast_w_params(time_series_df_default_index: dataframe.DataFrame):
)


def test_ai_similarity(session):
    """A BigFrames series vs. a pandas series yields non-null FLOAT64 scores."""
    left = bpd.Series(["happy", "sad"], session=session)
    right = pd.Series(["glad", "angry"])

    scores = bbq.ai.similarity(left, right, endpoint="text-embedding-005")

    assert scores.dtype == dtypes.FLOAT_DTYPE
    assert _contains_no_nulls(scores)


def test_ai_similarity_one_content_is_string_literal(session):
    """A string literal paired with a series works with the built-in model.

    Fix: stray review-UI text ("Comment thread / ... marked this conversation
    as resolved.") was embedded in the function body, which is not valid
    Python; it has been removed.
    """
    s1 = "happy"
    s2 = bpd.Series(["glad", "angry"], session=session)

    result = bbq.ai.similarity(s1, s2, model="embeddinggemma-300m")

    assert _contains_no_nulls(result)
    assert result.dtype == dtypes.FLOAT_DTYPE


def test_ai_similarity_both_contents_are_string_literals(session):
    """Two plain string literals still produce a non-null FLOAT64 result."""
    word_a, word_b = "happy", "glad"

    result = bbq.ai.similarity(word_a, word_b, endpoint="text-embedding-005")

    assert result.dtype == dtypes.FLOAT_DTYPE
    assert _contains_no_nulls(result)


def test_ai_similarity_no_endpoint_or_model__raises_error(session):
    """Omitting both `model` and `endpoint` raises a ValueError.

    Improvement: pin the error message with `match=` so the test cannot pass
    on an unrelated ValueError raised elsewhere in the call.
    """
    s1 = bpd.Series(["happy", "sad"], session=session)
    s2 = bpd.Series(["glad", "angry"], session=session)

    with pytest.raises(ValueError, match="either 'model' or 'endpoint'"):
        bbq.ai.similarity(s1, s2)


def test_ai_similarity_both_endpoint_and_model__raises_error(session):
    """Specifying both `model` and `endpoint` raises a ValueError.

    Improvement: pin the error message with `match=` so the test cannot pass
    on an unrelated ValueError raised elsewhere in the call.
    """
    s1 = "happy"
    s2 = "glad"

    with pytest.raises(ValueError, match="cannot specify"):
        bbq.ai.similarity(
            s1, s2, endpoint="text-embedding-005", model="embeddinggemma-300m"
        )


def _contains_no_nulls(s: series.Series) -> bool:
    # A series is null-free exactly when its non-null count equals its length.
    return s.count() == len(s)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT
AI.SIMILARITY(content1 => `string_col`, content2 => `string_col`, endpoint => 'text-embedding-005') AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
SELECT
AI.SIMILARITY(
content1 => `string_col`,
content2 => `string_col`,
endpoint => 'text-embedding-005',
connection_id => 'bigframes-dev.us.bigframes-default-connection'
) AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT
AI.SIMILARITY(content1 => `string_col`, content2 => `string_col`, model => 'embeddinggemma-300m') AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
SELECT
AI.SIMILARITY(
content1 => `string_col`,
content2 => `string_col`,
endpoint => 'text-embedding-005',
model_params => JSON '{"outputDimensionality": 256}'
) AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,55 @@ def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot, connection_id)
)

snapshot.assert_match(sql, "out.sql")


@pytest.mark.parametrize("connection_id", [None, CONNECTION_ID])
def test_ai_similarity(scalar_types_df: dataframe.DataFrame, snapshot, connection_id):
    """Snapshot the SQL generated for AI.SIMILARITY with and without a connection id."""
    column = "string_col"
    operator = ops.AISimilarity(
        endpoint="text-embedding-005",
        model=None,
        model_params=None,
        connection_id=connection_id,
    )

    sql = utils._apply_ops_to_sql(
        scalar_types_df, [operator.as_expr(column, column)], ["result"]
    )

    snapshot.assert_match(sql, "out.sql")


def test_ai_similarity_with_model(scalar_types_df: dataframe.DataFrame, snapshot):
    """Snapshot the SQL generated when a built-in model replaces the endpoint.

    Fix: stray review-UI text ("Comment thread / ... marked this conversation
    as resolved.") was embedded in the op construction, which is not valid
    Python; it has been removed.
    """
    col_name = "string_col"

    op = ops.AISimilarity(
        endpoint=None,
        model="embeddinggemma-300m",
        model_params=None,
        connection_id=None,
    )

    sql = utils._apply_ops_to_sql(
        scalar_types_df, [op.as_expr(col_name, col_name)], ["result"]
    )

    snapshot.assert_match(sql, "out.sql")


def test_ai_similarity_with_model_param(scalar_types_df: dataframe.DataFrame, snapshot):
    """Snapshot the SQL generated when JSON model parameters are supplied."""
    column = "string_col"
    params = json.dumps({"outputDimensionality": 256})

    op = ops.AISimilarity(
        endpoint="text-embedding-005",
        model=None,
        model_params=params,
        connection_id=None,
    )

    sql = utils._apply_ops_to_sql(
        scalar_types_df, [op.as_expr(column, column)], ["result"]
    )

    snapshot.assert_match(sql, "out.sql")
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,9 @@ def visit_AIClassify(self, op, **kwargs):
def visit_AIScore(self, op, **kwargs):
return sge.func("AI.SCORE", *self._compile_ai_args(**kwargs))

    def visit_AISimilarity(self, op, **kwargs):
        """Render the AISimilarity ibis op as a BigQuery `AI.SIMILARITY` call."""
        return sge.func("AI.SIMILARITY", *self._compile_ai_args(**kwargs))

def _compile_ai_args(self, **kwargs):
args = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,20 @@ class AIScore(Value):

shape = rlz.shape_like("prompt")


@public
class AISimilarity(Value):
    """Calculate the similarity between two contents."""

    # The two values whose similarity is computed.
    content1: Value
    content2: Value
    # Optional named arguments forwarded to the AI.SIMILARITY SQL function.
    endpoint: Optional[Value[dt.String]]
    model: Optional[Value[dt.String]]
    model_params: Optional[Value[dt.String]]
    connection_id: Optional[Value[dt.String]]

    # Result shape follows the first content operand (scalar vs. column).
    shape = rlz.shape_like("content1")

    @attribute
    # NOTE(review): the annotation previously said `dt.Struct`, but the value
    # returned is the float64 scalar dtype — widened to `dt.DataType` to match
    # the body; confirm against this op's actual SQL return type.
    def dtype(self) -> dt.DataType:
        return dt.float64
Loading