From 541ea8fa57df0a27040716f9971bf1782a19221e Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Tue, 21 Apr 2026 20:44:47 +0000 Subject: [PATCH 1/5] feat(bigframes): implement ai.embed --- .../bigframes/bigquery/_operations/ai.py | 55 +++++++++++++++ packages/bigframes/bigframes/bigquery/ai.py | 2 + .../ibis_compiler/scalar_op_registry.py | 14 ++++ .../compile/sqlglot/expressions/ai_ops.py | 7 ++ .../bigframes/operations/__init__.py | 2 + .../bigframes/bigframes/operations/ai_ops.py | 21 ++++++ .../test_ai_ops/test_ai_embed/out.sql | 3 + .../test_ai_embed_with_connection_id/out.sql | 7 ++ .../test_ai_embed_with_model/out.sql | 3 + .../out.sql | 7 ++ .../sqlglot/expressions/test_ai_ops.py | 70 +++++++++++++++++++ .../sql/compilers/bigquery/__init__.py | 3 + .../ibis/expr/operations/ai_ops.py | 23 ++++++ 13 files changed, 217 insertions(+) create mode 100644 packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed/out.sql create mode 100644 packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_connection_id/out.sql create mode 100644 packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model/out.sql create mode 100644 packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model_param_and_title/out.sql diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index 7a509d4f95ff..eac1e735aef4 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -705,6 +705,61 @@ def generate_table( return session.read_gbq_query(query) +@log_adapter.method_logger(custom_base_name="bigquery_ai") +def embed( + content: str | series.Series, + *, + endpoint: str | None = None, + model: str | None = None, + task_type: ( + Literal[ + "retrieval_query", + "retrieval_document", + "semantic_similarity", + "classification", + "clustering", + "question_answering", + "fact_verification", + "code_retrieval_query", + ] + | None + ) = None, + title: str | None = None, + model_params: Mapping[Any, Any] | None = None, + connection_id: str = None, +) -> series.Series: + """ + Creates embeddings from text or image data in BigQuery. + """ + + if model is not None: + if any([x is not None for x in [endpoint, title, model_params, connection_id]]): + raise ValueError( + "You cannot specify endpoint, title, model_params, or connection_id when the model is set." + ) + elif endpoint is None: + raise ValueError("You must specify exactly one of 'endpoint' or 'model' argument.") + + if title is not None and task_type != "retrieval_document": + raise ValueError("You can only use 'title' parameter if you specify retrieval_document for the task_type value.") + + operator = ai_ops.AIEmbed( + endpoint=endpoint, + model=model, + task_type=task_type, + title=title, + model_params=json.dumps(model_params) if model_params else None, + connection_id=connection_id + ) + + if isinstance(content, str): + return series.Series([content])._apply_unary_op(operator) + elif isinstance(content, series.Series): + return content._apply_unary_op(operator) + else: + raise ValueError(f"Unsupported 'content' parameter type: {type(content)}") + + @log_adapter.method_logger(custom_base_name="bigquery_ai") def if_( prompt: PROMPT_TYPE, diff --git a/packages/bigframes/bigframes/bigquery/ai.py b/packages/bigframes/bigframes/bigquery/ai.py index 25a7df778127..5f161632d98c 100644 --- a/packages/bigframes/bigframes/bigquery/ai.py +++ b/packages/bigframes/bigframes/bigquery/ai.py @@ -58,6 +58,7 @@ from bigframes.bigquery._operations.ai import ( classify, + embed, forecast, generate, generate_bool, @@ -72,6 +73,7 @@ __all__ = [ "classify", + "embed", "forecast", "generate", "generate_bool", diff --git a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 26ba0d8cb4b4..6e24d65f4ab1 100644 --- a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1964,6 +1964,20 @@ def ai_generate_double( op.model_params, # type: ignore ).to_expr() +@scalar_op_compiler.register_nary_op(ops.AIEmbed, pass_op=True) +def ai_embed( + value: ibis_types.Value, op: ops.AIEmbed +) -> ibis_types.StructValue: + return ai_ops.AIEmbed( + value, # type: ignore + connection_id=op.connection_id, # type: ignore + endpoint=op.endpoint, # type: ignore + model=op.model, # type: ignore + task_type=op.task_type.upper() if op.task_type is not None else None, # type: ignore + title=op.title, # type: ignore + model_params=op.model_params, # type: ignore + ).to_expr() + @scalar_op_compiler.register_nary_op(ops.AIIf, pass_op=True) def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue: diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py index eaeeb0b3a56d..7ed1e36354ba 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -23,6 +23,7 @@ from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr register_nary_op = expression_compiler.expression_compiler.register_nary_op +register_unary_op = expression_compiler.expression_compiler.register_unary_op @register_nary_op(ops.AIGenerate, pass_op=True) @@ -52,6 +53,12 @@ def _(*exprs: TypedExpr, op: ops.AIGenerateDouble) -> sge.Expression: return sge.func("AI.GENERATE_DOUBLE", *args) +@register_unary_op(ops.AIEmbed, pass_op=True) +def _(expr: TypedExpr, op: ops.AIEmbed) -> sge.Expression: + args = [expr.expr] + _construct_named_args(op) + + return sge.func("AI.EMBED", *args) + @register_nary_op(ops.AIIf, pass_op=True) def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression: diff --git a/packages/bigframes/bigframes/operations/__init__.py b/packages/bigframes/bigframes/operations/__init__.py index f35a07271bd7..f4c138a5527a 100644 --- a/packages/bigframes/bigframes/operations/__init__.py +++ b/packages/bigframes/bigframes/operations/__init__.py @@ -16,6 +16,7 @@ from bigframes.operations.ai_ops import ( AIClassify, + AIEmbed, AIGenerate, AIGenerateBool, AIGenerateDouble, @@ -434,6 +435,7 @@ "AIGenerateBool", "AIGenerateDouble", "AIGenerateInt", + "AIEmbed", "AIIf", "AIScore", # Numpy ops mapping diff --git a/packages/bigframes/bigframes/operations/ai_ops.py b/packages/bigframes/bigframes/operations/ai_ops.py index b20314fe2321..d6ed5c49590c 100644 --- a/packages/bigframes/bigframes/operations/ai_ops.py +++ b/packages/bigframes/bigframes/operations/ai_ops.py @@ -117,6 +117,27 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT ) ) +@dataclasses.dataclass(frozen=True) +class AIEmbed(base_ops.UnaryOp): + name: ClassVar[str] = "ai_embed" + + endpoint: str | None + model: str | None + task_type: str | None + title: str | None + model_params: str | None + connection_id: str | None + + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + return pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.list_(pa.float64())), + pa.field("status", pa.string()) + ) + ) + ) + @dataclasses.dataclass(frozen=True) class AIIf(base_ops.NaryOp): diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed/out.sql new file mode 100644 index 000000000000..9c18a7cd532f --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed/out.sql @@ -0,0 +1,3 @@ +SELECT + AI.EMBED(`string_col`, endpoint => 'text-embedding-005') AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_connection_id/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_connection_id/out.sql new file mode 100644 index 000000000000..997b0cbd6a45 --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_connection_id/out.sql @@ -0,0 +1,7 @@ +SELECT + AI.EMBED( + `string_col`, + connection_id => 'bigframes-dev.us.bigframes-default-connection', + endpoint => 'text-embedding-005' + ) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model/out.sql new file mode 100644 index 000000000000..5c36c8f484db --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model/out.sql @@ -0,0 +1,3 @@ +SELECT + AI.EMBED(`string_col`) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model_param_and_title/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model_param_and_title/out.sql new file mode 100644 index 000000000000..844bddbd15e4 --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model_param_and_title/out.sql @@ -0,0 +1,7 @@ +SELECT + AI.EMBED( + `string_col`, + endpoint => 'text-embedding-005', + model_params => JSON '{"outputDimensionality": 256}' + ) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index 64a5a94c9e7f..9a2133eadca0 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -281,6 +281,76 @@ def test_ai_generate_double_with_model_param( snapshot.assert_match(sql, "out.sql") +def test_ai_embed(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AIEmbed( + endpoint="text-embedding-005", + model=None, + task_type=None, + title=None, + model_params=None, + connection_id=None, + ) + + sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) + + snapshot.assert_match(sql, "out.sql") + + +def test_ai_embed_with_connection_id(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AIEmbed( + endpoint="text-embedding-005", + model=None, + task_type=None, + title=None, + model_params=None, + connection_id=CONNECTION_ID, + ) + + sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) + + snapshot.assert_match(sql, "out.sql") + + +def test_ai_embed_with_model(scalar_types_df: dataframe.DataFrame, snapshot): + col_name = "string_col" + + op = ops.AIEmbed( + endpoint=None, + model="project.dataset.my_model", + task_type=None, + title=None, + model_params=None, + connection_id=None, + ) + + sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) + + snapshot.assert_match(sql, "out.sql") + + +def test_ai_embed_with_model_param_and_title( + scalar_types_df: dataframe.DataFrame, snapshot +): + col_name = "string_col" + + op = ops.AIEmbed( + endpoint="text-embedding-005", + model=None, + task_type="retrieval_document", + title="My Document", + model_params=json.dumps({"outputDimensionality": 256}), + connection_id=None, + ) + + sql = utils._apply_ops_to_sql(scalar_types_df, [op.as_expr(col_name)], ["result"]) + + snapshot.assert_match(sql, "out.sql") + + @pytest.mark.parametrize("connection_id", [None, CONNECTION_ID]) def test_ai_if(scalar_types_df: dataframe.DataFrame, snapshot, connection_id): col_name = "string_col" diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index 4879f576f0e1..c3b481dd9016 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -1133,6 +1133,9 @@ def visit_AIGenerateInt(self, op, **kwargs): def visit_AIGenerateDouble(self, op, **kwargs): return sge.func("AI.GENERATE_DOUBLE", *self._compile_ai_args(**kwargs)) + + def visit_AIEmbed(self, op, **kwargs): + return sge.func("AI.EMBED", *self._compile_ai_args(**kwargs)) def visit_AIIf(self, op, **kwargs): return sge.func("AI.IF", *self._compile_ai_args(**kwargs)) diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index e50f6fe893ab..9f655dc250e7 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -107,6 +107,29 @@ def dtype(self) -> dt.Struct: ) ) +@public +class AIEmbed(Value): + """Generate doubles based on the prompt""" + + content: Value + connection_id: Optional[Value[dt.String]] + endpoint: Optional[Value[dt.String]] + model: Optional[Value[dt.String]] + task_type: Optional[Value[dt.String]] + title: Optional[Value[dt.String]] + model_params: Optional[Value[dt.String]] + + shape = rlz.shape_like("content") + + @attribute + def dtype(self) -> dt.Struct: + return dt.Struct.from_tuples( + ( + ("result", dt.Array(dt.float64)), + ("status", dt.string), + ) + ) + @public class AIIf(Value): From 7d0eaae647e3e0bc6f545a2fff3ae9b00487aeb1 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Wed, 22 Apr 2026 18:06:14 +0000 Subject: [PATCH 2/5] add doc and tests --- .../bigframes/bigquery/_operations/ai.py | 50 ++++++++++++++- .../tests/system/small/bigquery/test_ai.py | 63 +++++++++++++++++++ .../test_ai_embed_with_connection_id/out.sql | 4 +- .../test_ai_embed_with_model/out.sql | 2 +- .../out.sql | 2 + .../out.sql | 9 +++ .../sqlglot/expressions/test_ai_ops.py | 4 +- .../ibis/expr/operations/ai_ops.py | 2 +- 8 files changed, 128 insertions(+), 8 deletions(-) create mode 100644 packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_task_type_and_title/out.sql diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index eac1e735aef4..5a347572978f 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -726,14 +726,60 @@ def embed( ) = None, title: str | None = None, model_params: Mapping[Any, Any] | None = None, - connection_id: str = None, + connection_id: str | None = None, ) -> series.Series: """ Creates embeddings from text or image data in BigQuery. + + **Examples:** + + >>> import bigframes.pandas as bpd + >>> import bigframes.bigquery as bbq + >>> bbq.ai.embed("dog", endpoint="text-embedding-005") # doctest: +SKIP + 0 {'result': array([ 1.78243860e-03, -1.10658340... + + >>> s = bpd.Series(['dog']) # doctest: +SKIP + >>> bbq.ai.embed(s, endpoint='text-embedding-005') # doctest: +SKIP + 0 {'result': array([ 1.78243860e-03, -1.10658340... + + Args: + content (str | Series): + A string literal or a Series that provides the text or image to embed. + endpoint (str, optional): + A string value that specifies a supported Vertex AI embedding model endpoint to use. + The endpoint value that you specify must include the model version, for example, + `"text-embedding-005"`. If you specify this parameter, you can't specify the + `model` parameter. + model (str, optional): + A string value that specifies a built-in embedding model. The only supported value is + `"embeddinggemma-300m"`. If you specify this parameter, you can't specify the `endpoint`, + `title`, `model_params`, or `connection_id` parameters. + task_type (str, optional): + A string literal that specifies the intended downstream application to help the model + produce better quality embeddings. Accepts `"retrieval_query"`, `"retrieval_document"`, + `"semantic_similarity"`, `"classification"`, `"clustering"`, `"question_answering"`, + `"fact_verification"`, `"code_retrieval_query"`. + title (str, optional): + A string value that specifies the document title, which the model uses to improve + embedding quality. You can only use this parameter if you specify `"retrieval_document"` + for the `task_type` value. + model_params (Mapping[Any, Any], optional): + A JSON literal that provides additional parameters to the model. For example, + `{"outputDimensionality": 768}` lets you specify the number of dimensions to use when + generating embeddings. + connection_id (str, optional): + A STRING value specifying the connection to use to communicate with the model, in the + format `PROJECT_ID.LOCATION.CONNECTION_ID`. For example, `myproject.us.myconnection`. + If not provided, the query uses your end-user credential. + + Returns: + bigframes.series.Series: A new struct Series with the result data. The struct contains these fields: + * "result": an ARRAY value containing the generated embeddings. + * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful. """ if model is not None: - if any([x is not None for x in [endpoint, title, model_params, connection_id]]): + if any(x is not None for x in [endpoint, title, model_params, connection_id]): raise ValueError( "You cannot specify endpoint, title, model_params, or connection_id when the model is set." ) diff --git a/packages/bigframes/tests/system/small/bigquery/test_ai.py b/packages/bigframes/tests/system/small/bigquery/test_ai.py index 16e9cca9f136..e64acdb64c0d 100644 --- a/packages/bigframes/tests/system/small/bigquery/test_ai.py +++ b/packages/bigframes/tests/system/small/bigquery/test_ai.py @@ -255,6 +255,69 @@ def test_ai_generate_double_multi_model(session): ) +def test_ai_embed_series_content(session): + content = bpd.Series(["dog"], session=session) + + result = bbq.ai.embed(content, endpoint="text-embedding-005") + + assert _contains_no_nulls(result) + assert result.dtype == pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.list_(pa.float64())), + pa.field("status", pa.string()), + ) + ) + ) + + +def test_ai_embed_string_content(session): + with mock.patch( + "bigframes.core.global_session.get_global_session" + ) as mock_get_session: + mock_get_session.return_value = session + + result = bbq.ai.embed("dog", endpoint="text-embedding-005") + + assert _contains_no_nulls(result) + assert result.dtype == pd.ArrowDtype( + pa.struct( + ( + pa.field("result", pa.list_(pa.float64())), + pa.field("status", pa.string()), + ) + ) + ) + + +def test_ai_embed_no_endpoint_or_model_raises_error(session): + content = bpd.Series(["dog"], session=session) + + with pytest.raises(ValueError): + bbq.ai.embed(content) + + +def test_ai_embed_both_model_and_endpoint_are_set_raises_error(session): + content = bpd.Series(["dog"], session=session) + + with pytest.raises(ValueError): + bbq.ai.embed( + content, endpoint="text-embedding-005", model="embeddinggemma-300m model" + ) + + +def test_ai_embed_title_and_task_type_mismatch_raises_error(session): + content = bpd.Series(["dog"], session=session) + + with pytest.raises(ValueError): + bbq.ai.embed( + content, + endpoint="text-embedding-005", + title="my title", + task_type="text_similarity", + ) + + def test_ai_if(session): s1 = bpd.Series(["apple", "bear"], session=session) s2 = bpd.Series(["fruit", "tree"], session=session) diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_connection_id/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_connection_id/out.sql index 997b0cbd6a45..0968a101b22a 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_connection_id/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_connection_id/out.sql @@ -1,7 +1,7 @@ SELECT AI.EMBED( `string_col`, - connection_id => 'bigframes-dev.us.bigframes-default-connection', - endpoint => 'text-embedding-005' + endpoint => 'text-embedding-005', + connection_id => 'bigframes-dev.us.bigframes-default-connection' ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model/out.sql index 5c36c8f484db..4c3c76f87b61 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model/out.sql @@ -1,3 +1,3 @@ SELECT - AI.EMBED(`string_col`) AS `result` + AI.EMBED(`string_col`, model => 'embeddinggemma-300m') AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model_param_and_title/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model_param_and_title/out.sql index 844bddbd15e4..873db838682a 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model_param_and_title/out.sql +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_model_param_and_title/out.sql @@ -2,6 +2,8 @@ SELECT AI.EMBED( `string_col`, endpoint => 'text-embedding-005', + task_type => 'retrieval_document', + title => 'My Document', model_params => JSON '{"outputDimensionality": 256}' ) AS `result` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_task_type_and_title/out.sql b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_task_type_and_title/out.sql new file mode 100644 index 000000000000..873db838682a --- /dev/null +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/snapshots/test_ai_ops/test_ai_embed_with_task_type_and_title/out.sql @@ -0,0 +1,9 @@ +SELECT + AI.EMBED( + `string_col`, + endpoint => 'text-embedding-005', + task_type => 'retrieval_document', + title => 'My Document', + model_params => JSON '{"outputDimensionality": 256}' + ) AS `result` +FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0` \ No newline at end of file diff --git a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py index 9a2133eadca0..a7aacb4799ba 100644 --- a/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py +++ b/packages/bigframes/tests/unit/core/compile/sqlglot/expressions/test_ai_ops.py @@ -320,7 +320,7 @@ def test_ai_embed_with_model(scalar_types_df: dataframe.DataFrame, snapshot): op = ops.AIEmbed( endpoint=None, - model="project.dataset.my_model", + model="embeddinggemma-300m", task_type=None, title=None, model_params=None, @@ -332,7 +332,7 @@ def test_ai_embed_with_model(scalar_types_df: dataframe.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") -def test_ai_embed_with_model_param_and_title( +def test_ai_embed_with_task_type_and_title( scalar_types_df: dataframe.DataFrame, snapshot ): col_name = "string_col" diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index 9f655dc250e7..49cabc01ed36 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -109,7 +109,7 @@ def dtype(self) -> dt.Struct: @public class AIEmbed(Value): - """Generate doubles based on the prompt""" + """Create embeddings from text or image data.""" content: Value connection_id: Optional[Value[dt.String]] From de4c063a7f167145733347a48f0fed11c23b005b Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Wed, 22 Apr 2026 18:21:51 +0000 Subject: [PATCH 3/5] fix lint --- .../bigframes/bigframes/bigquery/_operations/ai.py | 10 +++++++--- .../core/compile/ibis_compiler/scalar_op_registry.py | 11 +++++------ .../core/compile/sqlglot/expressions/ai_ops.py | 1 + packages/bigframes/bigframes/operations/ai_ops.py | 3 ++- .../ibis/backends/sql/compilers/bigquery/__init__.py | 2 +- .../bigframes_vendored/ibis/expr/operations/ai_ops.py | 1 + 6 files changed, 17 insertions(+), 11 deletions(-) diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index 5a347572978f..7162c4b3d881 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -784,10 +784,14 @@ def embed( "You cannot specify endpoint, title, model_params, or connection_id when the model is set." ) elif endpoint is None: - raise ValueError("You must specify exactly one of 'endpoint' or 'model' argument.") + raise ValueError( + "You must specify exactly one of 'endpoint' or 'model' argument." + ) if title is not None and task_type != "retrieval_document": - raise ValueError("You can only use 'title' parameter if you specify retrieval_document for the task_type value.") + raise ValueError( + "You can only use 'title' parameter if you specify retrieval_document for the task_type value." + ) operator = ai_ops.AIEmbed( endpoint=endpoint, @@ -795,7 +799,7 @@ def embed( task_type=task_type, title=title, model_params=json.dumps(model_params) if model_params else None, - connection_id=connection_id + connection_id=connection_id, ) if isinstance(content, str): diff --git a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index 6e24d65f4ab1..e4b032dc49c0 100644 --- a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1964,17 +1964,16 @@ def ai_generate_double( op.model_params, # type: ignore ).to_expr() + @scalar_op_compiler.register_nary_op(ops.AIEmbed, pass_op=True) -def ai_embed( - value: ibis_types.Value, op: ops.AIEmbed -) -> ibis_types.StructValue: +def ai_embed(value: ibis_types.Value, op: ops.AIEmbed) -> ibis_types.StructValue: return ai_ops.AIEmbed( - value, # type: ignore + value, # type: ignore connection_id=op.connection_id, # type: ignore endpoint=op.endpoint, # type: ignore model=op.model, # type: ignore - task_type=op.task_type.upper() if op.task_type is not None else None, # type: ignore - title=op.title, # type: ignore + task_type=op.task_type.upper() if op.task_type is not None else None, # type: ignore + title=op.title, # type: ignore model_params=op.model_params, # type: ignore ).to_expr() diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py index 7ed1e36354ba..360033c12d5e 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -53,6 +53,7 @@ def _(*exprs: TypedExpr, op: ops.AIGenerateDouble) -> sge.Expression: return sge.func("AI.GENERATE_DOUBLE", *args) + @register_unary_op(ops.AIEmbed, pass_op=True) def _(expr: TypedExpr, op: ops.AIEmbed) -> sge.Expression: args = [expr.expr] + _construct_named_args(op) diff --git a/packages/bigframes/bigframes/operations/ai_ops.py b/packages/bigframes/bigframes/operations/ai_ops.py index d6ed5c49590c..2878b6584447 100644 --- a/packages/bigframes/bigframes/operations/ai_ops.py +++ b/packages/bigframes/bigframes/operations/ai_ops.py @@ -117,6 +117,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT ) ) + @dataclasses.dataclass(frozen=True) class AIEmbed(base_ops.UnaryOp): name: ClassVar[str] = "ai_embed" @@ -133,7 +134,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT pa.struct( ( pa.field("result", pa.list_(pa.float64())), - pa.field("status", pa.string()) + pa.field("status", pa.string()), ) ) ) diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py index c3b481dd9016..9a6906f4e5cc 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py @@ -1133,7 +1133,7 @@ def visit_AIGenerateInt(self, op, **kwargs): def visit_AIGenerateDouble(self, op, **kwargs): return sge.func("AI.GENERATE_DOUBLE", *self._compile_ai_args(**kwargs)) - + def visit_AIEmbed(self, op, **kwargs): return sge.func("AI.EMBED", *self._compile_ai_args(**kwargs)) diff --git a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py index 49cabc01ed36..642460da5291 100644 --- a/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py +++ b/packages/bigframes/third_party/bigframes_vendored/ibis/expr/operations/ai_ops.py @@ -107,6 +107,7 @@ def dtype(self) -> dt.Struct: ) ) + @public class AIEmbed(Value): """Create embeddings from text or image data.""" From 0000be15c5ccef2a000b3bacf065d971c7840033 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Wed, 22 Apr 2026 18:56:14 +0000 Subject: [PATCH 4/5] fix mypy --- .../core/compile/ibis_compiler/scalar_op_registry.py | 2 +- .../bigframes/core/compile/sqlglot/expressions/ai_ops.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py index e4b032dc49c0..75bb134f1df3 100644 --- a/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py +++ b/packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_registry.py @@ -1965,7 +1965,7 @@ def ai_generate_double( ).to_expr() -@scalar_op_compiler.register_nary_op(ops.AIEmbed, pass_op=True) +@scalar_op_compiler.register_unary_op(ops.AIEmbed, pass_op=True) def ai_embed(value: ibis_types.Value, op: ops.AIEmbed) -> ibis_types.StructValue: return ai_ops.AIEmbed( value, # type: ignore diff --git a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py index 360033c12d5e..8dbb298a1bab 100644 --- a/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py +++ b/packages/bigframes/bigframes/core/compile/sqlglot/expressions/ai_ops.py @@ -15,6 +15,7 @@ from __future__ import annotations from dataclasses import asdict +from typing import Any import bigframes_vendored.sqlglot.expressions as sge @@ -56,7 +57,7 @@ def _(*exprs: TypedExpr, op: ops.AIGenerateDouble) -> sge.Expression: @register_unary_op(ops.AIEmbed, pass_op=True) def _(expr: TypedExpr, op: ops.AIEmbed) -> sge.Expression: - args = [expr.expr] + _construct_named_args(op) + args: list[Any] = [expr.expr] + _construct_named_args(op) return sge.func("AI.EMBED", *args) @@ -102,7 +103,7 @@ def _construct_prompt( return sge.Kwarg(this=param_name, expression=sge.Tuple(expressions=prompt)) -def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]: +def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]: args = [] op_args = asdict(op) From 39ba275571f18767aff59de00d52278ff1c42ce5 Mon Sep 17 00:00:00 2001 From: Shenyang Cai Date: Wed, 22 Apr 2026 21:20:36 +0000 Subject: [PATCH 5/5] support pandas series --- packages/bigframes/bigframes/bigquery/_operations/ai.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/bigframes/bigframes/bigquery/_operations/ai.py b/packages/bigframes/bigframes/bigquery/_operations/ai.py index 7162c4b3d881..c6115c6d6a5e 100644 --- a/packages/bigframes/bigframes/bigquery/_operations/ai.py +++ b/packages/bigframes/bigframes/bigquery/_operations/ai.py @@ -707,7 +707,7 @@ def generate_table( @log_adapter.method_logger(custom_base_name="bigquery_ai") def embed( - content: str | series.Series, + content: str | series.Series | pd.Series, *, endpoint: str | None = None, model: str | None = None, @@ -744,7 +744,7 @@ def embed( Args: content (str | Series): - A string literal or a Series that provides the text or image to embed. + A string literal or a Series (either BigFrames series or pandas Series) that provides the text or image to embed. endpoint (str, optional): A string value that specifies a supported Vertex AI embedding model endpoint to use. The endpoint value that you specify must include the model version, for example, @@ -804,6 +804,8 @@ def embed( if isinstance(content, str): return series.Series([content])._apply_unary_op(operator) + elif isinstance(content, pd.Series): + return series.Series(content)._apply_unary_op(operator) elif isinstance(content, series.Series): return content._apply_unary_op(operator) else: