From ee16373192ee80e2ee83df3c5d9692df62fd7e93 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Wed, 22 Apr 2026 23:54:55 +0000 Subject: [PATCH 1/2] feat(bigframes): implement ai.similarity --- .../bigframes/bigquery/_operations/ai.py | 81 +++++++++++++++++++ packages/bigframes/bigframes/bigquery/ai.py | 2 + .../ibis_compiler/scalar_op_registry.py | 14 ++++ .../compile/sqlglot/expressions/ai_ops.py | 13 ++- .../bigframes/operations/__init__.py | 2 + .../bigframes/bigframes/operations/ai_ops.py | 13 +++ .../tests/system/small/bigquery/test_ai.py | 48 +++++++++++ .../test_ai_similarity/None/out.sql | 3 + .../out.sql | 8 ++ .../test_ai_similarity_with_model/out.sql | 3 + .../out.sql | 8 ++ .../sqlglot/expressions/test_ai_ops.py | 52 ++++++++++++ .../sql/compilers/bigquery/__init__.py | 3 + .../ibis/expr/operations/ai_ops.py | 14 ++++ 14 files changed, 263 insertions(+), 1 deletion(-) create mode 100644 packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity/None/out.sql create mode 100644 packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity/bigframes-dev.us.bigframes-default-connection/out.sql create mode 100644 packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity_with_model/out.sql create mode 100644 packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity_with_model_param/out.sql diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index 7a509d4f95ff..03d49dbb15ee 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -869,6 +869,87 @@ def score( return series_list[0]._apply_nary_op(operator, series_list[1:]) +@log_adapter.method_logger(custom_base_name="bigquery_ai") +def similarity( + content1: str | series.Series | pd.Series, + content2: str | series.Series | pd.Series, + *, + endpoint: str | None = None, + model: str | None = None, + model_params: Mapping[Any, Any] | None = None, + connection_id: str | None = None, +) -> series.Series: + """ + Returns a FLOAT64 value that represents the cosine similarity between the two inputs. + + **Examples:** + + >>> import bigframes.pandas as bpd + >>> import bigframes.bigquery as bbq + >>> df = bpd.DataFrame({'word': ['happy', 'sad']}) + >>> bbq.ai.similarity(df['word'], 'glad', endpoint='text-embedding-005') # doctest: +SKIP + 0 0.916601 + 1 0.660579 + + Args: + content1 (str | Series): + A string or series that provides the first value to compare. Both a BigFrames Series or a pandas Series are allowed. + content2 (str | Series): + A string or series that provides the second value to compare. Both a BigFrames Series or a pandas Series are allowed. + endpoint (str, optional): + Specifies the Vertex AI endpoint to use for the text embedding model. + If you specify the model name, such as `'text-embedding-005'`, rather than a URL, then BigQuery ML automatically identifies the model and uses the model's full endpoint. + model (str, optional): + Specifies a built-in text embedding model. The only supported value is the embeddinggemma-300m model. + If you specify this parameter, you can't specify the `endpoint`, `model_params`, or `connection_id` parameters. + model_params (Mapping[Any, Any], optional): + Provides additional parameters to the model. You can use any of the parameters object fields. + One of these fields, `outputDimensionality`, lets you specify the number of dimensions to use when generating embeddings. + connection_id (str, optional): + Specifies the connection to use to communicate with the model. For example, `myproject.us.myconnection`. + + Returns: + bigframes.series.Series: A new series of FLOAT64 values representing the cosine similarity. + """ + if model is not None: + if any(x is not None for x in [endpoint, model_params, connection_id]): + raise ValueError( + "If 'model' is specified, you cannot specify 'endpoint', 'model_params', or 'connection_id'." + ) + elif endpoint is None: + raise ValueError("You must specify either 'model' or 'endpoint'.") + + operator = ai_ops.AISimilarity( + endpoint=endpoint, + model=model, + model_params=json.dumps(model_params) if model_params else None, + connection_id=connection_id, + ) + + # Find a unifying session for the subsequent operations. + bf_session = None + if isinstance(content1, series.Series): + bf_session = content1._session + elif isinstance(content2, series.Series): + bf_session = content2._session + + if isinstance(content1, str) and isinstance(content2, str): + content1 = series.Series([content1], session=bf_session) + return content1._apply_binary_op(content2, operator) + elif isinstance(content1, str): + # content2 must be a series + content2 = convert.to_bf_series( + content2, default_index=None, session=bf_session + ) + return content2._apply_binary_op(content1, operator) + else: + # content1 must be a series. + content1 = convert.to_bf_series( + content1, default_index=None, session=bf_session + ) + return content1._apply_binary_op(content2, operator) + + @log_adapter.method_logger(custom_base_name="bigquery_ai") def forecast( df: dataframe.DataFrame | pd.DataFrame, diff --git a/packages/bigframes/bigframes/bigquery/ai.py b/packages/bigframes/bigframes/bigquery/ai.py index 25a7df778127..675720d27bd2 100644 --- a/packages/bigframes/bigframes/bigquery/ai.py +++ b/packages/bigframes/bigframes/bigquery/ai.py @@ -68,6 +68,7 @@ generate_text, if_, score, + similarity, ) __all__ = [ @@ -82,4 +83,5 @@ "generate_text", "if_", "score", + "similarity", ] diff --git a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 26ba0d8cb4b4..5ccc0e9d1a85 100644 --- a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1992,6 +1992,20 @@ def ai_score(*values: ibis_types.Value, op: ops.AIScore) -> ibis_types.StructVal ).to_expr() +@scalar_op_compiler.register_binary_op(ops.AISimilarity, pass_op=True) +def ai_similarity( + content1: ibis_types.Value, content2: ibis_types.Value, op: ops.AISimilarity +) -> ibis_types.Value: + return ai_ops.AISimilarity( + content1, # type: ignore + content2, # type: ignore + op.endpoint, # type: ignore + op.model, # type: ignore + op.model_params, # type: ignore + op.connection_id, # type: ignore + ).to_expr() + + def _construct_prompt( col_refs: tuple[ibis_types.Value], prompt_context: tuple[str | None] ) -> ibis_types.StructValue: diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py index eaeeb0b3a56d..d2a3de33c0b1 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -23,6 +23,7 @@ from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr register_nary_op = expression_compiler.expression_compiler.register_nary_op +register_binary_op = expression_compiler.expression_compiler.register_binary_op @register_nary_op(ops.AIGenerate, pass_op=True) @@ -76,6 +77,16 @@ def _(*exprs: TypedExpr, op: ops.AIScore) -> sge.Expression: return sge.func("AI.SCORE", *args) +@register_binary_op(ops.AISimilarity, pass_op=True) +def _(content1: TypedExpr, content2: TypedExpr, op: ops.AISimilarity) -> sge.Expression: + args = [ + sge.Kwarg(this="content1", expression=content1.expr), + sge.Kwarg(this="content2", expression=content2.expr), + ] + _construct_named_args(op) + + return sge.func("AI.SIMILARITY", *args) + + def _construct_prompt( exprs: tuple[TypedExpr, ...], prompt_context: tuple[str | None, ...], @@ -94,7 +105,7 @@ def _construct_prompt( return sge.Kwarg(this=param_name, expression=sge.Tuple(expressions=prompt)) -def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]: +def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]: args = [] op_args = asdict(op) diff --git a/packages/bigframes/bigframes/operations/__init__.py b/packages/bigframes/bigframes/operations/__init__.py index f35a07271bd7..58724b400e78 100644 --- a/packages/bigframes/bigframes/operations/__init__.py +++ b/packages/bigframes/bigframes/operations/__init__.py @@ -22,6 +22,7 @@ AIGenerateInt, AIIf, AIScore, + AISimilarity, ) from bigframes.operations.array_ops import ( ArrayIndexOp, @@ -436,6 +437,7 @@ "AIGenerateInt", "AIIf", "AIScore", + "AISimilarity", # Numpy ops mapping "NUMPY_TO_BINOP", "NUMPY_TO_OP", diff --git a/packages/bigframes/bigframes/operations/ai_ops.py b/packages/bigframes/bigframes/operations/ai_ops.py index b20314fe2321..e53751744dc4 100644 --- a/packages/bigframes/bigframes/operations/ai_ops.py +++ b/packages/bigframes/bigframes/operations/ai_ops.py @@ -150,3 +150,16 @@ class AIScore(base_ops.NaryOp): def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: return dtypes.FLOAT_DTYPE + + +@dataclasses.dataclass(frozen=True) +class AISimilarity(base_ops.BinaryOp): + name: ClassVar[str] = "ai_similarity" + + endpoint: str | None + model: str | None + model_params: str | None + connection_id: str | None + + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + return dtypes.FLOAT_DTYPE diff --git a/packages/bigframes/tests/system/small/bigquery/test_ai.py b/packages/bigframes/tests/system/small/bigquery/test_ai.py index 16e9cca9f136..a3939854a090 100644 --- a/packages/bigframes/tests/system/small/bigquery/test_ai.py +++ b/packages/bigframes/tests/system/small/bigquery/test_ai.py @@ -370,5 +370,53 @@ def test_forecast_w_params(time_series_df_default_index: dataframe.DataFrame): ) +def test_ai_similarity(session): + s1 = bpd.Series(["happy", "sad"], session=session) + s2 = pd.Series(["glad", "angry"]) + + result = bbq.ai.similarity(s1, s2, endpoint="text-embedding-005") + + assert _contains_no_nulls(result) + assert result.dtype == dtypes.FLOAT_DTYPE + + +def test_ai_similarity_one_content_is_string_literal(session): + s1 = "happy" + s2 = bpd.Series(["glad", "angry"], session=session) + + result = bbq.ai.similarity(s1, s2, model="embeddinggemma-300m") + + assert _contains_no_nulls(result) + assert result.dtype == dtypes.FLOAT_DTYPE + + +def test_ai_similarity_both_contents_are_string_literals(session): + s1 = "happy" + s2 = "glad" + + result = bbq.ai.similarity(s1, s2, endpoint="text-embedding-005") + + assert _contains_no_nulls(result) + assert result.dtype == dtypes.FLOAT_DTYPE + + +def test_ai_similarity_no_endpoint_or_model__raises_error(session): + s1 = bpd.Series(["happy", "sad"], session=session) + s2 = bpd.Series(["glad", "angry"], session=session) + + with pytest.raises(ValueError): + bbq.ai.similarity(s1, s2) + + +def test_ai_similarity_both_endpoint_and_model__raises_error(session): + s1 = "happy" + s2 = "glad" + + with pytest.raises(ValueError): + bbq.ai.similarity( + s1, s2, endpoint="text-embedding-005", model="embeddinggemma-300m" + ) + + def _contains_no_nulls(s: series.Series) -> bool: return len(s) == s.count() diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity/None/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity/None/out.sql new file mode 100644 index 000000000000..1df70aaf18e3 --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity/None/out.sql @@ -0,0 +1,3 @@ +SELECT + AI.SIMILARITY(content1 => `string_col`, content2 => `string_col`, endpoint => 'text-embedding-005') AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity/bigframes-dev.us.bigframes-default-connection/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity/bigframes-dev.us.bigframes-default-connection/out.sql new file mode 100644 index 000000000000..db57188ffa0a --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity/bigframes-dev.us.bigframes-default-connection/out.sql @@ -0,0 +1,8 @@ +SELECT + AI.SIMILARITY( + content1 => `string_col`, + content2 => `string_col`, + endpoint => 'text-embedding-005', + connection_id => 'bigframes-dev.us.bigframes-default-connection' + ) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity_with_model/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity_with_model/out.sql new file mode 100644 index 000000000000..704f9f944914 --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity_with_model/out.sql @@ -0,0 +1,3 @@ +SELECT + AI.SIMILARITY(content1 => `string_col`, content2 => `string_col`, model => 'embeddinggemma-300m') AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity_with_model_param/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity_with_model_param/out.sql new file mode 100644 index 000000000000..5173ac43bd96 --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_similarity_with_model_param/out.sql @@ -0,0 +1,8 @@ +SELECT + AI.SIMILARITY( + content1 => `string_col`, + content2 => `string_col`, + endpoint => 'text-embedding-005', + model_params => JSON '{"outputDimensionality": 256}' + ) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index 64a5a94c9e7f..8237d175e175 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -326,3 +326,55 @@ def test_ai_score(scalar_types_df: dataframe.DataFrame, snapshot, connection_id) ) snapshot.assert_match(sql, "out.sql") + + +@pytest.mark.parametrize("connection_id", [None, CONNECTION_ID]) +def test_ai_similarity(scalar_types_df: dataframe.DataFrame, snapshot, connection_id): + col_name = "string_col" + + op = ops.AISimilarity( + endpoint="text-embedding-005", + model=None, + model_params=None, + connection_id=connection_id, + ) + + sql = utils._apply_ops_to_sql( + scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] + ) + + snapshot.assert_match(sql, "out.sql") + + +def test_ai_similarity_with_model(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AISimilarity( + endpoint=None, + model="embeddinggemma-300m", + model_params=None, + connection_id=None, + ) + + sql = utils._apply_ops_to_sql( + scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] + ) + + snapshot.assert_match(sql, "out.sql") + + +def test_ai_similarity_with_model_param(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AISimilarity( + endpoint="text-embedding-005", + model=None, + model_params=json.dumps({"outputDimensionality": 256}), + connection_id=None, + ) + + sql = utils._apply_ops_to_sql( + scalar_types_df, [op.as_expr(col_name, col_name)], ["result"] + ) + + snapshot.assert_match(sql, "out.sql") diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index 4879f576f0e1..1f834e0ca769 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -1143,6 +1143,9 @@ def visit_AIClassify(self, op, **kwargs): def visit_AIScore(self, op, **kwargs): return sge.func("AI.SCORE", *self._compile_ai_args(**kwargs)) + def visit_AISimilarity(self, op, **kwargs): + return sge.func("AI.SIMILARITY", *self._compile_ai_args(**kwargs)) + def _compile_ai_args(self, **kwargs): args = [] diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index e50f6fe893ab..8db351d91740 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -146,6 +146,20 @@ class AIScore(Value): shape = rlz.shape_like("prompt") + +@public +class AISimilarity(Value): + """Calculate the similarity between two contents""" + + content1: Value + content2: Value + endpoint: Optional[Value[dt.String]] + model: Optional[Value[dt.String]] + model_params: Optional[Value[dt.String]] + connection_id: Optional[Value[dt.String]] + + shape = rlz.shape_like("content1") + @attribute def dtype(self) -> dt.Struct: return dt.float64 From fa8def80c427c2ca52c9b67247afb22da9b07d0b Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Thu, 23 Apr 2026 00:02:09 +0000 Subject: [PATCH 2/2] fix lint --- .../core/compile/ibis_compiler/scalar_op_registry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 5ccc0e9d1a85..36594b6fb1d3 100644 --- a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1997,8 +1997,8 @@ def ai_similarity( content1: ibis_types.Value, content2: ibis_types.Value, op: ops.AISimilarity ) -> ibis_types.Value: return ai_ops.AISimilarity( - content1, # type: ignore - content2, # type: ignore + content1, # type: ignore + content2, # type: ignore op.endpoint, # type: ignore op.model, # type: ignore op.model_params, # type: ignore