Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions packages/bigframes/bigframes/bigquery/_operations/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,113 @@ def generate_table(
return session.read_gbq_query(query)


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def embed(
    content: str | series.Series | pd.Series,
    *,
    endpoint: str | None = None,
    model: str | None = None,
    task_type: (
        Literal[
            "retrieval_query",
            "retrieval_document",
            "semantic_similarity",
            "classification",
            "clustering",
            "question_answering",
            "fact_verification",
            "code_retrieval_query",
        ]
        | None
    ) = None,
    title: str | None = None,
    model_params: Mapping[Any, Any] | None = None,
    connection_id: str | None = None,
) -> series.Series:
    """
    Creates embeddings from text or image data in BigQuery.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> import bigframes.bigquery as bbq
        >>> bbq.ai.embed("dog", endpoint="text-embedding-005") # doctest: +SKIP
        0    {'result': array([ 1.78243860e-03, -1.10658340...

        >>> s = bpd.Series(['dog']) # doctest: +SKIP
        >>> bbq.ai.embed(s, endpoint='text-embedding-005') # doctest: +SKIP
        0    {'result': array([ 1.78243860e-03, -1.10658340...

    Args:
        content (str | Series):
            A string literal or a Series (either BigFrames series or pandas Series) that provides the text or image to embed.
        endpoint (str, optional):
            A string value that specifies a supported Vertex AI embedding model endpoint to use.
            The endpoint value that you specify must include the model version, for example,
            `"text-embedding-005"`. If you specify this parameter, you can't specify the
            `model` parameter.
        model (str, optional):
            A string value that specifies a built-in embedding model. The only supported value is
            `"embeddinggemma-300m"`. If you specify this parameter, you can't specify the `endpoint`,
            `title`, `model_params`, or `connection_id` parameters.
        task_type (str, optional):
            A string literal that specifies the intended downstream application to help the model
            produce better quality embeddings. Accepts `"retrieval_query"`, `"retrieval_document"`,
            `"semantic_similarity"`, `"classification"`, `"clustering"`, `"question_answering"`,
            `"fact_verification"`, `"code_retrieval_query"`.
        title (str, optional):
            A string value that specifies the document title, which the model uses to improve
            embedding quality. You can only use this parameter if you specify `"retrieval_document"`
            for the `task_type` value.
        model_params (Mapping[Any, Any], optional):
            A JSON literal that provides additional parameters to the model. For example,
            `{"outputDimensionality": 768}` lets you specify the number of dimensions to use when
            generating embeddings.
        connection_id (str, optional):
            A STRING value specifying the connection to use to communicate with the model, in the
            format `PROJECT_ID.LOCATION.CONNECTION_ID`. For example, `myproject.us.myconnection`.
            If not provided, the query uses your end-user credential.

    Returns:
        bigframes.series.Series: A new struct Series with the result data. The struct contains these fields:
            * "result": an ARRAY<FLOAT64> value containing the generated embeddings.
            * "status": a STRING value that contains the API response status for the corresponding row. This value is empty if the operation was successful.
    """

    # Exactly one of `model` / `endpoint` must be supplied, and the built-in
    # `model` path is mutually exclusive with the remote-endpoint options.
    if model is None:
        if endpoint is None:
            raise ValueError(
                "You must specify exactly one of 'endpoint' or 'model' argument."
            )
    elif any(
        arg is not None for arg in (endpoint, title, model_params, connection_id)
    ):
        raise ValueError(
            "You cannot specify endpoint, title, model_params, or connection_id when the model is set."
        )

    if title is not None and task_type != "retrieval_document":
        raise ValueError(
            "You can only use 'title' parameter if you specify retrieval_document for the task_type value."
        )

    # Model params are sent to the SQL builder as a serialized JSON literal.
    serialized_model_params = json.dumps(model_params) if model_params else None
    operator = ai_ops.AIEmbed(
        endpoint=endpoint,
        model=model,
        task_type=task_type,
        title=title,
        model_params=serialized_model_params,
        connection_id=connection_id,
    )

    # Normalize the input into a BigFrames Series before applying the op.
    if isinstance(content, str):
        target = series.Series([content])
    elif isinstance(content, pd.Series):
        target = series.Series(content)
    elif isinstance(content, series.Series):
        target = content
    else:
        raise ValueError(f"Unsupported 'content' parameter type: {type(content)}")

    return target._apply_unary_op(operator)


@log_adapter.method_logger(custom_base_name="bigquery_ai")
def if_(
prompt: PROMPT_TYPE,
Expand Down
2 changes: 2 additions & 0 deletions packages/bigframes/bigframes/bigquery/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@

from bigframes.bigquery._operations.ai import (
classify,
embed,
forecast,
generate,
generate_bool,
Expand All @@ -72,6 +73,7 @@

__all__ = [
"classify",
"embed",
"forecast",
"generate",
"generate_bool",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1965,6 +1965,19 @@ def ai_generate_double(
).to_expr()


@scalar_op_compiler.register_unary_op(ops.AIEmbed, pass_op=True)
def ai_embed(value: ibis_types.Value, op: ops.AIEmbed) -> ibis_types.StructValue:
    """Compile the AIEmbed scalar op into the ibis AI.EMBED node."""
    # SQL expects the task-type keyword in upper case; the op stores it lower.
    normalized_task_type = None if op.task_type is None else op.task_type.upper()
    node = ai_ops.AIEmbed(
        value,  # type: ignore
        connection_id=op.connection_id,  # type: ignore
        endpoint=op.endpoint,  # type: ignore
        model=op.model,  # type: ignore
        task_type=normalized_task_type,  # type: ignore
        title=op.title,  # type: ignore
        model_params=op.model_params,  # type: ignore
    )
    return node.to_expr()


@scalar_op_compiler.register_nary_op(ops.AIIf, pass_op=True)
def ai_if(*values: ibis_types.Value, op: ops.AIIf) -> ibis_types.StructValue:
return ai_ops.AIIf(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from dataclasses import asdict
from typing import Any

import bigframes_vendored.sqlglot.expressions as sge

Expand All @@ -23,6 +24,7 @@
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr

register_nary_op = expression_compiler.expression_compiler.register_nary_op
register_unary_op = expression_compiler.expression_compiler.register_unary_op


@register_nary_op(ops.AIGenerate, pass_op=True)
Expand Down Expand Up @@ -53,6 +55,13 @@ def _(*exprs: TypedExpr, op: ops.AIGenerateDouble) -> sge.Expression:
return sge.func("AI.GENERATE_DOUBLE", *args)


@register_unary_op(ops.AIEmbed, pass_op=True)
def _(expr: TypedExpr, op: ops.AIEmbed) -> sge.Expression:
    """Render an AIEmbed op as a SQLGlot `AI.EMBED(...)` function call."""
    # The content expression is the single positional argument; every
    # non-null op attribute follows as a named (kwarg-style) argument.
    named_args = _construct_named_args(op)
    return sge.func("AI.EMBED", expr.expr, *named_args)


@register_nary_op(ops.AIIf, pass_op=True)
def _(*exprs: TypedExpr, op: ops.AIIf) -> sge.Expression:
args = [_construct_prompt(exprs, op.prompt_context)] + _construct_named_args(op)
Expand Down Expand Up @@ -94,7 +103,7 @@ def _construct_prompt(
return sge.Kwarg(this=param_name, expression=sge.Tuple(expressions=prompt))


def _construct_named_args(op: ops.NaryOp) -> list[sge.Kwarg]:
def _construct_named_args(op: ops.ScalarOp) -> list[sge.Kwarg]:
args = []

op_args = asdict(op)
Expand Down
2 changes: 2 additions & 0 deletions packages/bigframes/bigframes/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from bigframes.operations.ai_ops import (
AIClassify,
AIEmbed,
AIGenerate,
AIGenerateBool,
AIGenerateDouble,
Expand Down Expand Up @@ -434,6 +435,7 @@
"AIGenerateBool",
"AIGenerateDouble",
"AIGenerateInt",
"AIEmbed",
"AIIf",
"AIScore",
# Numpy ops mapping
Expand Down
22 changes: 22 additions & 0 deletions packages/bigframes/bigframes/operations/ai_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,28 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
)


@dataclasses.dataclass(frozen=True)
class AIEmbed(base_ops.UnaryOp):
    """Scalar op representing BigQuery's AI.EMBED function."""

    name: ClassVar[str] = "ai_embed"

    # NOTE: the SQL compilers build named arguments from dataclasses.asdict(),
    # which iterates in declaration order — keep this field order stable.
    endpoint: str | None
    model: str | None
    task_type: str | None
    title: str | None
    model_params: str | None  # serialized JSON literal, not a Mapping
    connection_id: str | None

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        """Result column is STRUCT<result ARRAY<FLOAT64>, status STRING>."""
        result_field = pa.field("result", pa.list_(pa.float64()))
        status_field = pa.field("status", pa.string())
        return pd.ArrowDtype(pa.struct((result_field, status_field)))


@dataclasses.dataclass(frozen=True)
class AIIf(base_ops.NaryOp):
name: ClassVar[str] = "ai_if"
Expand Down
63 changes: 63 additions & 0 deletions packages/bigframes/tests/system/small/bigquery/test_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,69 @@ def test_ai_generate_double_multi_model(session):
)


def test_ai_embed_series_content(session):
    """AI.EMBED over a BigFrames Series yields a non-null struct column."""
    content = bpd.Series(["dog"], session=session)

    result = bbq.ai.embed(content, endpoint="text-embedding-005")

    expected_dtype = pd.ArrowDtype(
        pa.struct(
            (
                pa.field("result", pa.list_(pa.float64())),
                pa.field("status", pa.string()),
            )
        )
    )
    assert _contains_no_nulls(result)
    assert result.dtype == expected_dtype


def test_ai_embed_string_content(session):
    """AI.EMBED over a plain string literal yields a non-null struct column."""
    # String-literal input routes through the global session, so pin the
    # global session to the test fixture's session for the duration.
    with mock.patch(
        "bigframes.core.global_session.get_global_session",
        return_value=session,
    ):
        result = bbq.ai.embed("dog", endpoint="text-embedding-005")

        expected_dtype = pd.ArrowDtype(
            pa.struct(
                (
                    pa.field("result", pa.list_(pa.float64())),
                    pa.field("status", pa.string()),
                )
            )
        )
        assert _contains_no_nulls(result)
        assert result.dtype == expected_dtype


def test_ai_embed_no_endpoint_or_model_raises_error(session):
    """Omitting both `endpoint` and `model` must raise a ValueError."""
    content = bpd.Series(["dog"], session=session)

    # Pin the message so an unrelated ValueError cannot pass the test.
    with pytest.raises(ValueError, match="exactly one of 'endpoint' or 'model'"):
        bbq.ai.embed(content)


def test_ai_embed_both_model_and_endpoint_are_set_raises_error(session):
    """`model` is mutually exclusive with `endpoint`; both set must raise."""
    content = bpd.Series(["dog"], session=session)

    # Pin the message so an unrelated ValueError cannot pass the test.
    with pytest.raises(ValueError, match="cannot specify endpoint"):
        bbq.ai.embed(
            content, endpoint="text-embedding-005", model="embeddinggemma-300m model"
        )


def test_ai_embed_title_and_task_type_mismatch_raises_error(session):
    """`title` is only allowed together with task_type='retrieval_document'."""
    content = bpd.Series(["dog"], session=session)

    # "classification" is a valid task_type literal (unlike the previous
    # "text_similarity", which is not an accepted value) but still conflicts
    # with `title`. Pin the message so only the intended error passes.
    with pytest.raises(ValueError, match="only use 'title'"):
        bbq.ai.embed(
            content,
            endpoint="text-embedding-005",
            title="my title",
            task_type="classification",
        )


def test_ai_if(session):
s1 = bpd.Series(["apple", "bear"], session=session)
s2 = bpd.Series(["fruit", "tree"], session=session)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT
AI.EMBED(`string_col`, endpoint => 'text-embedding-005') AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SELECT
AI.EMBED(
`string_col`,
endpoint => 'text-embedding-005',
connection_id => 'bigframes-dev.us.bigframes-default-connection'
) AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT
AI.EMBED(`string_col`, model => 'embeddinggemma-300m') AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
SELECT
AI.EMBED(
`string_col`,
endpoint => 'text-embedding-005',
task_type => 'retrieval_document',
title => 'My Document',
model_params => JSON '{"outputDimensionality": 256}'
) AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
SELECT
AI.EMBED(
`string_col`,
endpoint => 'text-embedding-005',
task_type => 'retrieval_document',
title => 'My Document',
model_params => JSON '{"outputDimensionality": 256}'
) AS `result`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
Loading
Loading