Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit b952428

Browse files
fix: Respect remote function config changes even if logic unchanged (#2512)
1 parent 96597f0 commit b952428

File tree

30 files changed

+978
-752
lines changed

30 files changed

+978
-752
lines changed

bigframes/core/compile/ibis_compiler/ibis_compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def _replace_unsupported_ops(node: nodes.BigFrameNode):
8888
node = nodes.bottom_up(node, rewrites.rewrite_slice)
8989
node = nodes.bottom_up(node, rewrites.rewrite_timedelta_expressions)
9090
node = nodes.bottom_up(node, rewrites.rewrite_range_rolling)
91+
node = nodes.bottom_up(node, rewrites.lower_udfs)
9192
return node
9293

9394

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,8 @@ def timedelta_floor_op_impl(x: ibis_types.NumericValue):
10371037
@scalar_op_compiler.register_unary_op(ops.RemoteFunctionOp, pass_op=True)
10381038
def remote_function_op_impl(x: ibis_types.Value, op: ops.RemoteFunctionOp):
10391039
udf_sig = op.function_def.signature
1040-
ibis_py_sig = (udf_sig.py_input_types, udf_sig.py_output_type)
1040+
assert not udf_sig.is_virtual # should have been devirtualized in lowering pass
1041+
ibis_py_sig = (tuple(arg.py_type for arg in udf_sig.inputs), udf_sig.output.py_type)
10411042

10421043
@ibis_udf.scalar.builtin(
10431044
name=str(op.function_def.routine_ref), signature=ibis_py_sig
@@ -1056,7 +1057,8 @@ def binary_remote_function_op_impl(
10561057
x: ibis_types.Value, y: ibis_types.Value, op: ops.BinaryRemoteFunctionOp
10571058
):
10581059
udf_sig = op.function_def.signature
1059-
ibis_py_sig = (udf_sig.py_input_types, udf_sig.py_output_type)
1060+
assert not udf_sig.is_virtual # should have been devirtualized in lowering pass
1061+
ibis_py_sig = (tuple(arg.py_type for arg in udf_sig.inputs), udf_sig.output.py_type)
10601062

10611063
@ibis_udf.scalar.builtin(
10621064
name=str(op.function_def.routine_ref), signature=ibis_py_sig
@@ -1073,8 +1075,9 @@ def nary_remote_function_op_impl(
10731075
*operands: ibis_types.Value, op: ops.NaryRemoteFunctionOp
10741076
):
10751077
udf_sig = op.function_def.signature
1076-
ibis_py_sig = (udf_sig.py_input_types, udf_sig.py_output_type)
1077-
arg_names = tuple(arg.name for arg in udf_sig.input_types)
1078+
assert not udf_sig.is_virtual # should have been devirtualized in lowering pass
1079+
ibis_py_sig = (tuple(arg.py_type for arg in udf_sig.inputs), udf_sig.output.py_type)
1080+
arg_names = tuple(arg.name for arg in udf_sig.inputs)
10781081

10791082
@ibis_udf.scalar.builtin(
10801083
name=str(op.function_def.routine_ref),
@@ -1153,6 +1156,13 @@ def array_reduce_op_impl(x: ibis_types.Value, op: ops.ArrayReduceOp):
11531156
)
11541157

11551158

1159+
@scalar_op_compiler.register_unary_op(ops.ArrayMapOp, pass_op=True)
1160+
def array_map_op_impl(x: ibis_types.Value, op: ops.ArrayMapOp):
1161+
return typing.cast(ibis_types.ArrayValue, x).map(
1162+
lambda arr_vals: scalar_op_compiler.compile_row_op(op.map_op, (arr_vals,))
1163+
)
1164+
1165+
11561166
# JSON Ops
11571167
@scalar_op_compiler.register_binary_op(ops.JSONSet, pass_op=True)
11581168
def json_set_op_impl(x: ibis_types.Value, y: ibis_types.Value, op: ops.JSONSet):

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,4 +369,5 @@ def compile_aggregate(
369369
def _replace_unsupported_ops(node: nodes.BigFrameNode):
370370
node = nodes.bottom_up(node, rewrite.rewrite_slice)
371371
node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)
372+
node = nodes.bottom_up(node, rewrite.lower_udfs)
372373
return node

bigframes/core/compile/sqlglot/expressions/array_ops.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,28 @@ def _(expr: TypedExpr, op: ops.ArrayReduceOp) -> sge.Expression:
7373
)
7474

7575

76+
@register_unary_op(ops.ArrayMapOp, pass_op=True)
77+
def _(expr: TypedExpr, op: ops.ArrayMapOp) -> sge.Expression:
78+
sub_expr = sg.to_identifier("bf_arr_map_uid")
79+
sub_type = dtypes.get_array_inner_type(expr.dtype)
80+
81+
# TODO: Expression should be provided instead of invoking compiler manually
82+
map_expr = expression_compiler.expression_compiler.compile_row_op(
83+
op.map_op, (TypedExpr(sub_expr, sub_type),)
84+
)
85+
86+
return sge.array(
87+
sge.select(map_expr)
88+
.from_(
89+
sge.Unnest(
90+
expressions=[expr.expr],
91+
alias=sge.TableAlias(columns=[sub_expr]),
92+
)
93+
)
94+
.subquery()
95+
)
96+
97+
7698
@register_unary_op(ops.ArraySliceOp, pass_op=True)
7799
def _(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
78100
if expr.dtype == dtypes.STRING_DTYPE:

bigframes/core/compile/sqlglot/sql/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from bigframes.core.compile.sqlglot.sql.base import (
1717
cast,
18-
escape_chars,
1918
identifier,
2019
is_null_literal,
2120
literal,
@@ -28,7 +27,6 @@
2827
__all__ = [
2928
# From base.py
3029
"cast",
31-
"escape_chars",
3230
"identifier",
3331
"is_null_literal",
3432
"literal",

bigframes/core/compile/sqlglot/sql/base.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -136,29 +136,6 @@ def table(table: bigquery.TableReference) -> sge.Table:
136136
)
137137

138138

139-
def escape_chars(value: str):
140-
"""Escapes all special characters"""
141-
# TODO: Reuse literal's escaping logic instead of re-implementing it here.
142-
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#string_and_bytes_literals
143-
trans_table = str.maketrans(
144-
{
145-
"\a": r"\a",
146-
"\b": r"\b",
147-
"\f": r"\f",
148-
"\n": r"\n",
149-
"\r": r"\r",
150-
"\t": r"\t",
151-
"\v": r"\v",
152-
"\\": r"\\",
153-
"?": r"\?",
154-
'"': r"\"",
155-
"'": r"\'",
156-
"`": r"\`",
157-
}
158-
)
159-
return value.translate(trans_table)
160-
161-
162139
def is_null_literal(expr: sge.Expression) -> bool:
163140
"""Checks if the given expression is a NULL literal."""
164141
if isinstance(expr, sge.Null):

bigframes/core/rewrite/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from bigframes.core.rewrite.select_pullup import defer_selection
2828
from bigframes.core.rewrite.slices import pull_out_limit, pull_up_limits, rewrite_slice
2929
from bigframes.core.rewrite.timedeltas import rewrite_timedelta_expressions
30+
from bigframes.core.rewrite.udfs import lower_udfs
3031
from bigframes.core.rewrite.windows import (
3132
pull_out_window_order,
3233
rewrite_range_rolling,
@@ -53,4 +54,5 @@
5354
"pull_out_window_order",
5455
"defer_selection",
5556
"simplify_complex_windows",
57+
"lower_udfs",
5658
]

bigframes/core/rewrite/udfs.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import dataclasses
17+
18+
from bigframes.core import bigframe_node, expression
19+
from bigframes.core.rewrite import op_lowering
20+
import bigframes.functions.udf_def as udf_def
21+
import bigframes.operations as ops
22+
23+
24+
@dataclasses.dataclass
25+
class LowerRemoteFunctionRule(op_lowering.OpLoweringRule):
26+
@property
27+
def op(self) -> type[ops.ScalarOp]:
28+
return ops.RemoteFunctionOp
29+
30+
def lower(self, expr: expression.OpExpression) -> expression.Expression:
31+
assert isinstance(expr.op, ops.RemoteFunctionOp)
32+
func_def = expr.op.function_def
33+
devirtualized_expr = ops.RemoteFunctionOp(
34+
func_def.with_devirtualize(),
35+
apply_on_null=expr.op.apply_on_null,
36+
).as_expr(*expr.children)
37+
if isinstance(func_def.signature.output, udf_def.VirtualListTypeV1):
38+
return func_def.signature.output.out_expr(devirtualized_expr)
39+
else:
40+
return devirtualized_expr
41+
42+
43+
@dataclasses.dataclass
44+
class LowerBinaryRemoteFunctionRule(op_lowering.OpLoweringRule):
45+
@property
46+
def op(self) -> type[ops.ScalarOp]:
47+
return ops.BinaryRemoteFunctionOp
48+
49+
def lower(self, expr: expression.OpExpression) -> expression.Expression:
50+
assert isinstance(expr.op, ops.BinaryRemoteFunctionOp)
51+
func_def = expr.op.function_def
52+
devirtualized_expr = ops.BinaryRemoteFunctionOp(
53+
func_def.with_devirtualize(),
54+
).as_expr(*expr.children)
55+
if isinstance(func_def.signature.output, udf_def.VirtualListTypeV1):
56+
return func_def.signature.output.out_expr(devirtualized_expr)
57+
else:
58+
return devirtualized_expr
59+
60+
61+
@dataclasses.dataclass
62+
class LowerNaryRemoteFunctionRule(op_lowering.OpLoweringRule):
63+
@property
64+
def op(self) -> type[ops.ScalarOp]:
65+
return ops.NaryRemoteFunctionOp
66+
67+
def lower(self, expr: expression.OpExpression) -> expression.Expression:
68+
assert isinstance(expr.op, ops.NaryRemoteFunctionOp)
69+
func_def = expr.op.function_def
70+
devirtualized_expr = ops.NaryRemoteFunctionOp(
71+
func_def.with_devirtualize(),
72+
).as_expr(*expr.children)
73+
if isinstance(func_def.signature.output, udf_def.VirtualListTypeV1):
74+
return func_def.signature.output.out_expr(devirtualized_expr)
75+
else:
76+
return devirtualized_expr
77+
78+
79+
UDF_LOWERING_RULES = (
80+
LowerRemoteFunctionRule(),
81+
LowerBinaryRemoteFunctionRule(),
82+
LowerNaryRemoteFunctionRule(),
83+
)
84+
85+
86+
def lower_udfs(root: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode:
87+
return op_lowering.lower_ops(root, rules=UDF_LOWERING_RULES)

bigframes/core/sql/__init__.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,35 @@
4848
to_wkt = dumps
4949

5050

51+
def identifier(name: str) -> str:
52+
if len(name) > 256:
53+
raise ValueError("Identifier must be less than 256 characters")
54+
return f"`{escape_chars(name)}`"
55+
56+
57+
def escape_chars(value: str):
58+
"""Escapes all special characters"""
59+
# TODO: Reuse literal's escaping logic instead of re-implementing it here.
60+
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#string_and_bytes_literals
61+
trans_table = str.maketrans(
62+
{
63+
"\a": r"\a",
64+
"\b": r"\b",
65+
"\f": r"\f",
66+
"\n": r"\n",
67+
"\r": r"\r",
68+
"\t": r"\t",
69+
"\v": r"\v",
70+
"\\": r"\\",
71+
"?": r"\?",
72+
'"': r"\"",
73+
"'": r"\'",
74+
"`": r"\`",
75+
}
76+
)
77+
return value.translate(trans_table)
78+
79+
5180
def multi_literal(*values: Any):
5281
literal_strings = [sql.to_sql(sql.literal(i)) for i in values]
5382
return "(" + ", ".join(literal_strings) + ")"

bigframes/dataframe.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4748,7 +4748,9 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
47484748
# compatible with the data types of the input params.
47494749
# 3. The order of the columns in the dataframe must correspond
47504750
# to the order of the input params in the function.
4751-
udf_input_dtypes = func.udf_def.signature.bf_input_types
4751+
udf_input_dtypes = tuple(
4752+
arg.bf_type for arg in func.udf_def.signature.inputs
4753+
)
47524754
if not args and len(udf_input_dtypes) != len(self.columns):
47534755
raise ValueError(
47544756
f"Parameter count mismatch: BigFrames BigQuery function"
@@ -4793,7 +4795,6 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
47934795
)
47944796
result_series.name = None
47954797

4796-
result_series = func._post_process_series(result_series)
47974798
return result_series
47984799

47994800
# At this point column-wise or element-wise bigquery function operation will

0 commit comments

Comments
 (0)