diff --git a/sdk/python/feast/infra/common/retrieval_task.py b/sdk/python/feast/infra/common/retrieval_task.py index a5b5583b3ce..960e0d34c49 100644 --- a/sdk/python/feast/infra/common/retrieval_task.py +++ b/sdk/python/feast/infra/common/retrieval_task.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from datetime import datetime -from typing import Union +from typing import Optional, Union import pandas as pd @@ -15,5 +15,5 @@ class HistoricalRetrievalTask: feature_view: Union[BatchFeatureView, StreamFeatureView] full_feature_name: bool registry: Registry - start_time: datetime - end_time: datetime + start_time: Optional[datetime] = None + end_time: Optional[datetime] = None diff --git a/sdk/python/feast/infra/compute_engines/feature_builder.py b/sdk/python/feast/infra/compute_engines/feature_builder.py index 324f82e7500..ceed3e2d4f3 100644 --- a/sdk/python/feast/infra/compute_engines/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/feature_builder.py @@ -72,10 +72,10 @@ def _should_validate(self): def build(self) -> ExecutionPlan: last_node = self.build_source_node() - # PIT join entities to the feature data, and perform filtering - if isinstance(self.task, HistoricalRetrievalTask): - last_node = self.build_join_node(last_node) + # Join entity_df with source if needed + last_node = self.build_join_node(last_node) + # PIT filter, TTL, and user-defined filter last_node = self.build_filter_node(last_node) if self._should_aggregate(): diff --git a/sdk/python/feast/infra/compute_engines/local/feature_builder.py b/sdk/python/feast/infra/compute_engines/local/feature_builder.py index bf755ed96d0..aee245da21c 100644 --- a/sdk/python/feast/infra/compute_engines/local/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/local/feature_builder.py @@ -2,7 +2,6 @@ from feast.infra.common.materialization_job import MaterializationTask from feast.infra.common.retrieval_task import HistoricalRetrievalTask -from feast.infra.compute_engines.dag.plan import ExecutionPlan from feast.infra.compute_engines.feature_builder import FeatureBuilder from feast.infra.compute_engines.local.backends.base import DataFrameBackend from feast.infra.compute_engines.local.nodes import ( @@ -95,25 +94,3 @@ def build_output_nodes(self, input_node): node = LocalOutputNode("output") node.add_input(input_node) self.nodes.append(node) - - def build(self) -> ExecutionPlan: - last_node = self.build_source_node() - - if isinstance(self.task, HistoricalRetrievalTask): - last_node = self.build_join_node(last_node) - - last_node = self.build_filter_node(last_node) - - if self._should_aggregate(): - last_node = self.build_aggregation_node(last_node) - elif isinstance(self.task, HistoricalRetrievalTask): - last_node = self.build_dedup_node(last_node) - - if self._should_transform(): - last_node = self.build_transformation_node(last_node) - - if self._should_validate(): - last_node = self.build_validation_node(last_node) - - self.build_output_nodes(last_node) - return ExecutionPlan(self.nodes) diff --git a/sdk/python/feast/infra/compute_engines/local/nodes.py b/sdk/python/feast/infra/compute_engines/local/nodes.py index 4e1d2c3362f..aea83921351 100644 --- a/sdk/python/feast/infra/compute_engines/local/nodes.py +++ b/sdk/python/feast/infra/compute_engines/local/nodes.py @@ -1,8 +1,9 @@ -from datetime import timedelta +from datetime import datetime, timedelta from typing import Optional import pyarrow as pa +from feast.data_source import DataSource from feast.infra.compute_engines.dag.context import 
ExecutionContext from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue from feast.infra.compute_engines.local.backends.base import DataFrameBackend @@ -15,14 +16,40 @@ class LocalSourceReadNode(LocalNode): - def __init__(self, name: str, feature_view, task): + def __init__( + self, + name: str, + source: DataSource, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + ): super().__init__(name) - self.feature_view = feature_view - self.task = task + self.source = source + self.start_time = start_time + self.end_time = end_time def execute(self, context: ExecutionContext) -> ArrowTableValue: - # TODO : Implement the logic to read from offline store - return ArrowTableValue(data=pa.Table.from_pandas(context.entity_df)) + offline_store = context.offline_store + ( + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + ) = context.column_info + + # ๐Ÿ“ฅ Reuse Feast's robust query resolver + retrieval_job = offline_store.pull_all_from_table_or_query( + config=context.repo_config, + data_source=self.source, + join_key_columns=join_key_columns, + feature_name_columns=feature_name_columns, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + start_date=self.start_time, + end_date=self.end_time, + ) + arrow_table = retrieval_job.to_arrow() + return ArrowTableValue(data=arrow_table) class LocalJoinNode(LocalNode): diff --git a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py index 453cee7fda5..944feccf903 100644 --- a/sdk/python/feast/infra/compute_engines/spark/feature_builder.py +++ b/sdk/python/feast/infra/compute_engines/spark/feature_builder.py @@ -9,9 +9,8 @@ SparkAggregationNode, SparkDedupNode, SparkFilterNode, - SparkHistoricalRetrievalReadNode, SparkJoinNode, - SparkMaterializationReadNode, + SparkReadNode, SparkTransformationNode, SparkWriteNode, ) @@ -27,12 +26,10 @@ def __init__( self.spark_session = spark_session def build_source_node(self): - if isinstance(self.task, MaterializationTask): - node = SparkMaterializationReadNode("source", self.task) - else: - node = SparkHistoricalRetrievalReadNode( - "source", self.task, self.spark_session - ) + source = self.feature_view.batch_source + start_time = self.task.start_time + end_time = self.task.end_time + node = SparkReadNode("source", source, start_time, end_time) self.nodes.append(node) return node diff --git a/sdk/python/feast/infra/compute_engines/spark/node.py b/sdk/python/feast/infra/compute_engines/spark/node.py index e0215081bcf..0c1c1476613 100644 --- a/sdk/python/feast/infra/compute_engines/spark/node.py +++ b/sdk/python/feast/infra/compute_engines/spark/node.py @@ -1,4 +1,4 @@ -from datetime import timedelta +from datetime import datetime, timedelta from typing import List, Optional, Union, cast from pyspark.sql import DataFrame, SparkSession, Window @@ -6,8 +6,7 @@ from feast import BatchFeatureView, StreamFeatureView from feast.aggregation import Aggregation -from feast.infra.common.materialization_job import MaterializationTask -from feast.infra.common.retrieval_task import HistoricalRetrievalTask +from feast.data_source import DataSource from feast.infra.compute_engines.dag.context import ExecutionContext from feast.infra.compute_engines.dag.model import DAGFormat from feast.infra.compute_engines.dag.node import DAGNode @@ -23,7 +22,6 @@ from feast.infra.offline_stores.offline_utils import ( 
infer_event_timestamp_from_entity_df, ) -from feast.utils import _get_fields_with_aliases ENTITY_TS_ALIAS = "__entity_event_timestamp" @@ -49,18 +47,21 @@ def rename_entity_ts_column( return entity_df -class SparkMaterializationReadNode(DAGNode): +class SparkReadNode(DAGNode): def __init__( - self, name: str, task: Union[MaterializationTask, HistoricalRetrievalTask] + self, + name: str, + source: DataSource, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, ): super().__init__(name) - self.task = task + self.source = source + self.start_time = start_time + self.end_time = end_time def execute(self, context: ExecutionContext) -> DAGValue: offline_store = context.offline_store - start_time = self.task.start_time - end_time = self.task.end_time - ( join_key_columns, feature_name_columns, @@ -69,15 +70,15 @@ def execute(self, context: ExecutionContext) -> DAGValue: ) = context.column_info # ๐Ÿ“ฅ Reuse Feast's robust query resolver - retrieval_job = offline_store.pull_latest_from_table_or_query( + retrieval_job = offline_store.pull_all_from_table_or_query( config=context.repo_config, - data_source=self.task.feature_view.batch_source, + data_source=self.source, join_key_columns=join_key_columns, feature_name_columns=feature_name_columns, timestamp_field=timestamp_field, created_timestamp_column=created_timestamp_column, - start_date=start_time, - end_date=end_time, + start_date=self.start_time, + end_date=self.end_time, ) spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df() @@ -88,74 +89,8 @@ def execute(self, context: ExecutionContext) -> DAGValue: "source": "feature_view_batch_source", "timestamp_field": timestamp_field, "created_timestamp_column": created_timestamp_column, - "start_date": start_time, - "end_date": end_time, - }, - ) - - -class SparkHistoricalRetrievalReadNode(DAGNode): - def __init__( - self, name: str, task: HistoricalRetrievalTask, spark_session: SparkSession - ): - super().__init__(name) - self.task = task - self.spark_session = spark_session - - def execute(self, context: ExecutionContext) -> DAGValue: - """ - Read data from the offline store on the Spark engine. - TODO: Some functionality is duplicated with SparkMaterializationReadNode and spark get_historical_features. 
- Args: - context: SparkExecutionContext - Returns: DAGValue - """ - fv = self.task.feature_view - source = fv.batch_source - - ( - join_key_columns, - feature_name_columns, - timestamp_field, - created_timestamp_column, - ) = context.column_info - - # TODO: Use pull_all_from_table_or_query when it supports not filtering by timestamp - # retrieval_job = offline_store.pull_all_from_table_or_query( - # config=context.repo_config, - # data_source=source, - # join_key_columns=join_key_columns, - # feature_name_columns=feature_name_columns, - # timestamp_field=timestamp_field, - # start_date=min_ts, - # end_date=max_ts, - # ) - # spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df() - - columns = join_key_columns + feature_name_columns + [timestamp_field] - if created_timestamp_column: - columns.append(created_timestamp_column) - - (fields_with_aliases, aliases) = _get_fields_with_aliases( - fields=columns, - field_mappings=source.field_mapping, - ) - fields_with_alias_string = ", ".join(fields_with_aliases) - - from_expression = source.get_table_query_string() - - query = f""" - SELECT {fields_with_alias_string} - FROM {from_expression} - """ - spark_df = self.spark_session.sql(query) - - return DAGValue( - data=spark_df, - format=DAGFormat.SPARK, - metadata={ - "source": "feature_view_batch_source", - "timestamp_field": timestamp_field, + "start_date": self.start_time, + "end_date": self.end_time, }, ) @@ -227,7 +162,12 @@ def execute(self, context: ExecutionContext) -> DAGValue: feature_df: DataFrame = feature_value.data entity_df = context.entity_df - assert entity_df is not None, "entity_df must be set in ExecutionContext" + if entity_df is None: + return DAGValue( + data=feature_df, + format=DAGFormat.SPARK, + metadata={"joined_on": None}, + ) # Get timestamp fields from feature view join_keys, feature_cols, ts_col, created_ts_col = context.column_info @@ -272,13 +212,13 @@ def execute(self, context: ExecutionContext) -> DAGValue: if ENTITY_TS_ALIAS in input_df.columns: filtered_df = filtered_df.filter(F.col(ts_col) <= F.col(ENTITY_TS_ALIAS)) - # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl - if self.ttl: - ttl_seconds = int(self.ttl.total_seconds()) - lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr( - f"INTERVAL {ttl_seconds} seconds" - ) - filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound) + # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl + if self.ttl: + ttl_seconds = int(self.ttl.total_seconds()) + lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr( + f"INTERVAL {ttl_seconds} seconds" + ) + filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound) # Optional custom filter condition if self.filter_condition: diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py index f0516b594ee..2c4bc8cdbc3 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery.py +++ b/sdk/python/feast/infra/offline_stores/bigquery.py @@ -52,6 +52,7 @@ BigQuerySource, SavedDatasetBigQueryStorage, ) +from .offline_utils import get_timestamp_filter_sql try: from google.api_core import client_info as http_client_info @@ -188,8 +189,9 @@ def pull_all_from_table_or_query( join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, - start_date: datetime, - end_date: datetime, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> RetrievalJob: assert isinstance(config.offline_store, 
BigQueryOfflineStoreConfig) assert isinstance(data_source, BigQuerySource) @@ -201,15 +203,26 @@ def pull_all_from_table_or_query( project=project_id, location=config.offline_store.location, ) + + timestamp_fields = [timestamp_field] + if created_timestamp_column: + timestamp_fields.append(created_timestamp_column) field_string = ", ".join( BigQueryOfflineStore._escape_query_columns(join_key_columns) + BigQueryOfflineStore._escape_query_columns(feature_name_columns) - + [timestamp_field] + + timestamp_fields + ) + timestamp_filter = get_timestamp_filter_sql( + start_date, + end_date, + timestamp_field, + quote_fields=False, + cast_style="timestamp_func", ) query = f""" SELECT {field_string} FROM {from_expression} - WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}') + WHERE {timestamp_filter} """ return BigQueryRetrievalJob( query=query, diff --git a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py index f49bfddb81d..6f92f45793b 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py +++ b/sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py @@ -36,6 +36,7 @@ RetrievalJob, RetrievalMetadata, ) +from feast.infra.offline_stores.offline_utils import get_timestamp_filter_sql from feast.infra.registry.base_registry import BaseRegistry from feast.infra.utils import aws_utils from feast.repo_config import FeastConfigBaseModel, RepoConfig @@ -131,15 +132,19 @@ def pull_all_from_table_or_query( join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, - start_date: datetime, - end_date: datetime, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> RetrievalJob: assert isinstance(config.offline_store, AthenaOfflineStoreConfig) assert isinstance(data_source, AthenaSource) from_expression = data_source.get_table_query_string(config) + timestamp_fields = [timestamp_field] + if created_timestamp_column: + timestamp_fields.append(created_timestamp_column) field_string = ", ".join( - join_key_columns + feature_name_columns + [timestamp_field] + join_key_columns + feature_name_columns + timestamp_fields ) athena_client = aws_utils.get_athena_data_client(config.offline_store.region) @@ -147,11 +152,30 @@ def pull_all_from_table_or_query( date_partition_column = data_source.date_partition_column + start_date_str = None + if start_date: + start_date_str = start_date.astimezone(tz=timezone.utc).strftime( + "%Y-%m-%d %H:%M:%S.%f" + )[:-3] + end_date_str = None + if end_date: + end_date_str = end_date.astimezone(tz=timezone.utc).strftime( + "%Y-%m-%d %H:%M:%S.%f" + )[:-3] + + timestamp_filter = get_timestamp_filter_sql( + start_date_str, + end_date_str, + timestamp_field, + date_partition_column, + cast_style="raw", + quote_fields=False, + ) + query = f""" SELECT {field_string} FROM {from_expression} - WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date.astimezone(tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]}' AND TIMESTAMP '{end_date.astimezone(tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]}' - {"AND " + date_partition_column + " >= '" + start_date.strftime("%Y-%m-%d") + "' AND " + date_partition_column + " <= '" + end_date.strftime("%Y-%m-%d") + "' " if date_partition_column != "" and date_partition_column is not None else ""} + WHERE {timestamp_filter} """ return AthenaRetrievalJob( 
diff --git a/sdk/python/feast/infra/offline_stores/contrib/couchbase_offline_store/couchbase.py b/sdk/python/feast/infra/offline_stores/contrib/couchbase_offline_store/couchbase.py index a90d6c2172b..54921e9515e 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/couchbase_offline_store/couchbase.py +++ b/sdk/python/feast/infra/offline_stores/contrib/couchbase_offline_store/couchbase.py @@ -42,6 +42,7 @@ from feast.saved_dataset import SavedDatasetStorage from ... import offline_utils +from ...offline_utils import get_timestamp_filter_sql from .couchbase_source import ( CouchbaseColumnarSource, SavedDatasetCouchbaseColumnarStorage, @@ -228,8 +229,9 @@ def pull_all_from_table_or_query( join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, - start_date: datetime, - end_date: datetime, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> RetrievalJob: """ Fetch all rows from the specified table or query within the time range. @@ -243,17 +245,28 @@ def pull_all_from_table_or_query( assert isinstance(data_source, CouchbaseColumnarSource) from_expression = data_source.get_table_query_string() + timestamp_fields = [timestamp_field] + if created_timestamp_column: + timestamp_fields.append(created_timestamp_column) field_string = ", ".join( - join_key_columns + feature_name_columns + [timestamp_field] + join_key_columns + feature_name_columns + timestamp_fields + ) + start_date_normalized = ( + f"`{normalize_timestamp(start_date)}`" if start_date else None + ) + end_date_normalized = f"`{normalize_timestamp(end_date)}`" if end_date else None + timestamp_filter = get_timestamp_filter_sql( + start_date_normalized, + end_date_normalized, + timestamp_field, + cast_style="raw", + quote_fields=False, ) - - start_date_normalized = normalize_timestamp(start_date) - end_date_normalized = normalize_timestamp(end_date) query = f""" SELECT {field_string} FROM {from_expression} - WHERE `{timestamp_field}` BETWEEN '{start_date_normalized}' AND '{end_date_normalized}' + WHERE {timestamp_filter} """ return CouchbaseColumnarRetrievalJob( diff --git a/sdk/python/feast/infra/offline_stores/contrib/mssql_offline_store/mssql.py b/sdk/python/feast/infra/offline_stores/contrib/mssql_offline_store/mssql.py index 875d584568b..4821aa8dcb6 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/mssql_offline_store/mssql.py +++ b/sdk/python/feast/infra/offline_stores/contrib/mssql_offline_store/mssql.py @@ -177,8 +177,9 @@ def pull_all_from_table_or_query( join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, - start_date: datetime, - end_date: datetime, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> RetrievalJob: return pull_all_from_table_or_query_ibis( config=config, @@ -186,6 +187,7 @@ def pull_all_from_table_or_query( join_key_columns=join_key_columns, feature_name_columns=feature_name_columns, timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, start_date=start_date, end_date=end_date, data_source_reader=_build_data_source_reader(config), diff --git a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py index dc5a7b30976..1a75bb7e178 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py +++ 
b/sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py @@ -34,6 +34,7 @@ RetrievalJob, RetrievalMetadata, ) +from feast.infra.offline_stores.offline_utils import get_timestamp_filter_sql from feast.infra.registry.base_registry import BaseRegistry from feast.infra.utils.postgres.connection_utils import ( _get_conn, @@ -226,24 +227,34 @@ def pull_all_from_table_or_query( join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, - start_date: datetime, - end_date: datetime, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> RetrievalJob: assert isinstance(config.offline_store, PostgreSQLOfflineStoreConfig) assert isinstance(data_source, PostgreSQLSource) from_expression = data_source.get_table_query_string() + timestamp_fields = [timestamp_field] + if created_timestamp_column: + timestamp_fields.append(created_timestamp_column) field_string = ", ".join( - join_key_columns + feature_name_columns + [timestamp_field] + join_key_columns + feature_name_columns + timestamp_fields ) - start_date = start_date.astimezone(tz=timezone.utc) - end_date = end_date.astimezone(tz=timezone.utc) + timestamp_filter = get_timestamp_filter_sql( + start_date, + end_date, + timestamp_field, + tz=timezone.utc, + cast_style="timestamptz", + date_time_separator=" ", # backwards compatibility but inconsistent with other offline stores + ) query = f""" SELECT {field_string} FROM {from_expression} AS paftoq_alias - WHERE "{timestamp_field}" BETWEEN '{start_date}'::timestamptz AND '{end_date}'::timestamptz + WHERE {timestamp_filter} """ return PostgreSQLRetrievalJob( diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py index 41c180f5c3c..806610cae7e 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py +++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py @@ -29,6 +29,7 @@ RetrievalJob, RetrievalMetadata, ) +from feast.infra.offline_stores.offline_utils import get_timestamp_filter_sql from feast.infra.registry.base_registry import BaseRegistry from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage @@ -269,8 +270,9 @@ def pull_all_from_table_or_query( join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, - start_date: datetime, - end_date: datetime, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> RetrievalJob: """ Note that join_key_columns, feature_name_columns, timestamp_field, and @@ -288,21 +290,26 @@ def pull_all_from_table_or_query( spark_session = get_spark_session_or_start_new_with_repoconfig( store_config=config.offline_store ) + + timestamp_fields = [timestamp_field] + if created_timestamp_column: + timestamp_fields.append(created_timestamp_column) (fields_with_aliases, aliases) = _get_fields_with_aliases( - fields=join_key_columns + feature_name_columns + [timestamp_field], + fields=join_key_columns + feature_name_columns + timestamp_fields, field_mappings=data_source.field_mapping, ) fields_with_alias_string = ", ".join(fields_with_aliases) from_expression = data_source.get_table_query_string() - start_date = start_date.astimezone(tz=timezone.utc) - end_date = end_date.astimezone(tz=timezone.utc) + timestamp_filter = 
get_timestamp_filter_sql( + start_date, end_date, timestamp_field, tz=timezone.utc, quote_fields=False + ) query = f""" SELECT {fields_with_alias_string} FROM {from_expression} - WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}' + WHERE {timestamp_filter} """ return SparkRetrievalJob( diff --git a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py index 9667f4e4720..7f7b91d1d23 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py +++ b/sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py @@ -31,6 +31,7 @@ RetrievalJob, RetrievalMetadata, ) +from feast.infra.offline_stores.offline_utils import get_timestamp_filter_sql from feast.infra.registry.base_registry import BaseRegistry from feast.on_demand_feature_view import OnDemandFeatureView from feast.repo_config import FeastConfigBaseModel, RepoConfig @@ -405,21 +406,30 @@ def pull_all_from_table_or_query( join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, - start_date: datetime, - end_date: datetime, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> RetrievalJob: assert isinstance(config.offline_store, TrinoOfflineStoreConfig) assert isinstance(data_source, TrinoSource) from_expression = data_source.get_table_query_string() client = _get_trino_client(config=config) + + timestamp_fields = [timestamp_field] + if created_timestamp_column: + timestamp_fields.append(created_timestamp_column) field_string = ", ".join( - join_key_columns + feature_name_columns + [timestamp_field] + join_key_columns + feature_name_columns + timestamp_fields + ) + + timestamp_filter = get_timestamp_filter_sql( + start_date, end_date, timestamp_field, quote_fields=False ) query = f""" SELECT {field_string} FROM {from_expression} - WHERE {timestamp_field} BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}' + WHERE {timestamp_filter} """ return TrinoRetrievalJob( query=query, diff --git a/sdk/python/feast/infra/offline_stores/dask.py b/sdk/python/feast/infra/offline_stores/dask.py index 01efc492f7c..ea857996966 100644 --- a/sdk/python/feast/infra/offline_stores/dask.py +++ b/sdk/python/feast/infra/offline_stores/dask.py @@ -314,8 +314,8 @@ def pull_latest_from_table_or_query( feature_name_columns: List[str], timestamp_field: str, created_timestamp_column: Optional[str], - start_date: datetime, - end_date: datetime, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> RetrievalJob: assert isinstance(config.offline_store, DaskOfflineStoreConfig) assert isinstance(data_source, FileSource) @@ -359,10 +359,19 @@ def evaluate_offline_job(): source_df = source_df.sort_values(by=timestamp_field, npartitions=1) - source_df = source_df[ - (source_df[timestamp_field] >= start_date) - & (source_df[timestamp_field] < end_date) - ] + # TODO: The old implementation is inclusive of start_date and exclusive of end_date. + # Which is inconsistent with other offline stores. 
+ if start_date or end_date: + if start_date and end_date: + source_df = source_df[ + source_df[timestamp_field].between( + start_date, end_date, inclusive="left" + ) + ] + elif start_date: + source_df = source_df[source_df[timestamp_field] >= start_date] + elif end_date: + source_df = source_df[source_df[timestamp_field] < end_date] source_df = source_df.persist() @@ -393,8 +402,9 @@ def pull_all_from_table_or_query( join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, - start_date: datetime, - end_date: datetime, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> RetrievalJob: assert isinstance(config.offline_store, DaskOfflineStoreConfig) assert isinstance(data_source, FileSource) @@ -406,7 +416,7 @@ def pull_all_from_table_or_query( + [timestamp_field], # avoid deduplication feature_name_columns=feature_name_columns, timestamp_field=timestamp_field, - created_timestamp_column=None, + created_timestamp_column=created_timestamp_column, start_date=start_date, end_date=end_date, ) diff --git a/sdk/python/feast/infra/offline_stores/duckdb.py b/sdk/python/feast/infra/offline_stores/duckdb.py index b2e3c03cb55..7bf96129d0b 100644 --- a/sdk/python/feast/infra/offline_stores/duckdb.py +++ b/sdk/python/feast/infra/offline_stores/duckdb.py @@ -179,8 +179,9 @@ def pull_all_from_table_or_query( join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, - start_date: datetime, - end_date: datetime, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> RetrievalJob: return pull_all_from_table_or_query_ibis( config=config, @@ -188,6 +189,7 @@ def pull_all_from_table_or_query( join_key_columns=join_key_columns, feature_name_columns=feature_name_columns, timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, start_date=start_date, end_date=end_date, data_source_reader=_read_data_source, diff --git a/sdk/python/feast/infra/offline_stores/ibis.py b/sdk/python/feast/infra/offline_stores/ibis.py index 66d00ca6292..95c5afef2db 100644 --- a/sdk/python/feast/infra/offline_stores/ibis.py +++ b/sdk/python/feast/infra/offline_stores/ibis.py @@ -260,16 +260,22 @@ def pull_all_from_table_or_query_ibis( join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, - start_date: datetime, - end_date: datetime, data_source_reader: Callable[[DataSource, str], Table], data_source_writer: Callable[[pyarrow.Table, DataSource, str], None], + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, staging_location: Optional[str] = None, staging_location_endpoint_override: Optional[str] = None, ) -> RetrievalJob: - fields = join_key_columns + feature_name_columns + [timestamp_field] - start_date = start_date.astimezone(tz=timezone.utc) - end_date = end_date.astimezone(tz=timezone.utc) + timestamp_fields = [timestamp_field] + if created_timestamp_column: + timestamp_fields.append(created_timestamp_column) + fields = join_key_columns + feature_name_columns + timestamp_fields + if start_date: + start_date = start_date.astimezone(tz=timezone.utc) + if end_date: + end_date = end_date.astimezone(tz=timezone.utc) table = data_source_reader(data_source, str(config.repo_path)) @@ -281,8 +287,12 @@ def pull_all_from_table_or_query_ibis( table = table.filter( ibis.and_( - table[timestamp_field] 
>= ibis.literal(start_date), - table[timestamp_field] <= ibis.literal(end_date), + table[timestamp_field] >= ibis.literal(start_date) + if start_date + else ibis.literal(True), + table[timestamp_field] <= ibis.literal(end_date) + if end_date + else ibis.literal(True), ) ) diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py index 69d6bb278b7..73794f67a17 100644 --- a/sdk/python/feast/infra/offline_stores/offline_store.py +++ b/sdk/python/feast/infra/offline_stores/offline_store.py @@ -294,8 +294,9 @@ def pull_all_from_table_or_query( join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, - start_date: datetime, - end_date: datetime, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> RetrievalJob: """ Extracts all the entity rows (i.e. the combination of join key columns, feature columns, and @@ -309,9 +310,10 @@ def pull_all_from_table_or_query( data_source: The data source from which the entity rows will be extracted. join_key_columns: The columns of the join keys. feature_name_columns: The columns of the features. - timestamp_field: The timestamp column. - start_date: The start of the time range. - end_date: The end of the time range. + timestamp_field: The timestamp column, used to filter rows to the requested time range. + created_timestamp_column (Optional): The column indicating when the row was created; included in the pulled columns when provided. + start_date (Optional): The start of the time range. + end_date (Optional): The end of the time range. Returns: A RetrievalJob that can be executed to get the entity rows. diff --git a/sdk/python/feast/infra/offline_stores/offline_utils.py b/sdk/python/feast/infra/offline_stores/offline_utils.py index 5b12636782f..e951434e2a3 100644 --- a/sdk/python/feast/infra/offline_stores/offline_utils.py +++ b/sdk/python/feast/infra/offline_stores/offline_utils.py @@ -1,7 +1,7 @@ import uuid from dataclasses import asdict, dataclass -from datetime import datetime, timedelta -from typing import Any, Dict, KeysView, List, Optional, Set, Tuple +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, KeysView, List, Literal, Optional, Set, Tuple, Union import numpy as np import pandas as pd @@ -266,3 +266,91 @@ def enclose_in_backticks(value): return [f"`{v}`" for v in value] else: return f"`{value}`" + + +def get_timestamp_filter_sql( + start_date: Optional[Union[datetime, str]] = None, + end_date: Optional[Union[datetime, str]] = None, + timestamp_field: Optional[str] = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL, + date_partition_column: Optional[str] = None, + tz: Optional[timezone] = None, + cast_style: Literal[ + "timestamp", "timestamp_func", "timestamptz", "raw" + ] = "timestamp", + date_time_separator: str = "T", + quote_fields: bool = True, +) -> str: + """ + Returns SQL filter condition (no WHERE) with flexible timestamp casting. + + Args: + start_date: datetime or ISO 8601 string + end_date: datetime or ISO 8601 string + timestamp_field: main timestamp column + date_partition_column: optional partition column (for pruning) + tz: optional timezone for datetime inputs + cast_style: one of: + - "timestamp": TIMESTAMP '...' → common SQL engines (Snowflake, Redshift, etc.) + - "timestamp_func": TIMESTAMP('...') → BigQuery, Couchbase, etc. + - "timestamptz": '...'::timestamptz → PostgreSQL + - "raw": '...'
→ no cast, string only + date_time_separator: separator for datetime strings (default is "T") + (e.g. "2023-10-01T00:00:00" or "2023-10-01 00:00:00") + quote_fields: whether to quote the timestamp and partition column names + + Returns: + SQL filter string without WHERE + """ + + def quote_column_if_needed(column: Optional[str]) -> Optional[str]: + if not column or not quote_fields: + return column + return f'"{column}"' + + def format_casted_ts(val: Union[str, datetime]) -> str: + if isinstance(val, datetime): + if tz: + val = val.astimezone(tz) + val_str = val.isoformat(sep=date_time_separator) + else: + val_str = val + + if cast_style == "timestamp": + return f"TIMESTAMP '{val_str}'" + elif cast_style == "timestamp_func": + return f"TIMESTAMP('{val_str}')" + elif cast_style == "timestamptz": + return f"'{val_str}'::{cast_style}" + else: + return f"'{val_str}'" + + def format_date(val: Union[str, datetime]) -> str: + if isinstance(val, datetime): + if tz: + val = val.astimezone(tz) + return val.strftime("%Y-%m-%d") + return val + + ts_field = quote_column_if_needed(timestamp_field) + dp_field = quote_column_if_needed(date_partition_column) + + filters = [] + + # Timestamp filters + if start_date and end_date: + filters.append( + f"{ts_field} BETWEEN {format_casted_ts(start_date)} AND {format_casted_ts(end_date)}" + ) + elif start_date: + filters.append(f"{ts_field} >= {format_casted_ts(start_date)}") + elif end_date: + filters.append(f"{ts_field} <= {format_casted_ts(end_date)}") + + # Partition pruning + if date_partition_column: + if start_date: + filters.append(f"{dp_field} >= '{format_date(start_date)}'") + if end_date: + filters.append(f"{dp_field} <= '{format_date(end_date)}'") + + return " AND ".join(filters) if filters else "" diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index ed76f830f3b..4ed8e6309c4 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -33,6 +33,7 @@ RetrievalJob, RetrievalMetadata, ) +from feast.infra.offline_stores.offline_utils import get_timestamp_filter_sql from feast.infra.offline_stores.redshift_source import ( RedshiftLoggingDestination, SavedDatasetRedshiftStorage, @@ -157,15 +158,19 @@ def pull_all_from_table_or_query( join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, - start_date: datetime, - end_date: datetime, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> RetrievalJob: assert isinstance(config.offline_store, RedshiftOfflineStoreConfig) assert isinstance(data_source, RedshiftSource) from_expression = data_source.get_table_query_string() + timestamp_fields = [timestamp_field] + if created_timestamp_column: + timestamp_fields.append(created_timestamp_column) field_string = ", ".join( - join_key_columns + feature_name_columns + [timestamp_field] + join_key_columns + feature_name_columns + timestamp_fields ) redshift_client = aws_utils.get_redshift_data_client( config.offline_store.region ) s3_resource = aws_utils.get_s3_resource(config.offline_store.region) - start_date = start_date.astimezone(tz=timezone.utc) - end_date = end_date.astimezone(tz=timezone.utc) + timestamp_filter = get_timestamp_filter_sql( + start_date, + end_date, + timestamp_field, + tz=timezone.utc, + ) query = f""" SELECT {field_string} FROM {from_expression} - WHERE {timestamp_field} BETWEEN
TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}' + WHERE {timestamp_filter} """ return RedshiftRetrievalJob( diff --git a/sdk/python/feast/infra/offline_stores/remote.py b/sdk/python/feast/infra/offline_stores/remote.py index d11fb4673db..41985b9bba0 100644 --- a/sdk/python/feast/infra/offline_stores/remote.py +++ b/sdk/python/feast/infra/offline_stores/remote.py @@ -234,8 +234,9 @@ def pull_all_from_table_or_query( join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, - start_date: datetime, - end_date: datetime, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> RetrievalJob: assert isinstance(config.offline_store, RemoteOfflineStoreConfig) @@ -253,8 +254,9 @@ def pull_all_from_table_or_query( "join_key_columns": join_key_columns, "feature_name_columns": feature_name_columns, "timestamp_field": timestamp_field, - "start_date": start_date.isoformat(), - "end_date": end_date.isoformat(), + "created_timestamp_column": created_timestamp_column, + "start_date": start_date.isoformat() if start_date else None, + "end_date": end_date.isoformat() if end_date else None, } return RemoteRetrievalJob( diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index 101685cec6f..3a39d0ea6db 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -37,6 +37,7 @@ RetrievalJob, RetrievalMetadata, ) +from feast.infra.offline_stores.offline_utils import get_timestamp_filter_sql from feast.infra.offline_stores.snowflake_source import ( SavedDatasetSnowflakeStorage, SnowflakeLoggingDestination, @@ -229,8 +230,9 @@ def pull_all_from_table_or_query( join_key_columns: List[str], feature_name_columns: List[str], timestamp_field: str, - start_date: datetime, - end_date: datetime, + created_timestamp_column: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, ) -> RetrievalJob: assert isinstance(config.offline_store, SnowflakeOfflineStoreConfig) assert isinstance(data_source, SnowflakeSource) @@ -241,22 +243,26 @@ def pull_all_from_table_or_query( if not data_source.database and data_source.schema and data_source.table: from_expression = f'"{config.offline_store.database}".{from_expression}' + timestamp_fields = [timestamp_field] + if created_timestamp_column: + timestamp_fields.append(created_timestamp_column) field_string = ( '"' - + '", "'.join(join_key_columns + feature_name_columns + [timestamp_field]) + + '", "'.join(join_key_columns + feature_name_columns + timestamp_fields) + '"' ) with GetSnowflakeConnection(config.offline_store) as conn: snowflake_conn = conn - start_date = start_date.astimezone(tz=timezone.utc) - end_date = end_date.astimezone(tz=timezone.utc) + timestamp_filter = get_timestamp_filter_sql( + start_date, end_date, timestamp_field, tz=timezone.utc + ) query = f""" SELECT {field_string} FROM {from_expression} - WHERE "{timestamp_field}" BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}' + WHERE {timestamp_filter} """ return SnowflakeRetrievalJob( diff --git a/sdk/python/feast/offline_server.py b/sdk/python/feast/offline_server.py index f3642e5812e..f3215ca0e47 100644 --- a/sdk/python/feast/offline_server.py +++ b/sdk/python/feast/offline_server.py @@ -354,13 +354,15 @@ def pull_all_from_table_or_query(self, command: dict): assert_permissions(data_source, actions=[AuthzedAction.READ_OFFLINE]) 
return self.offline_store.pull_all_from_table_or_query( - self.store.config, - data_source, - command["join_key_columns"], - command["feature_name_columns"], - command["timestamp_field"], - utils.make_tzaware(datetime.fromisoformat(command["start_date"])), - utils.make_tzaware(datetime.fromisoformat(command["end_date"])), + config=self.store.config, + data_source=data_source, + join_key_columns=command["join_key_columns"], + feature_name_columns=command["feature_name_columns"], + timestamp_field=command["timestamp_field"], + start_date=utils.make_tzaware( + datetime.fromisoformat(command["start_date"]) + ), + end_date=utils.make_tzaware(datetime.fromisoformat(command["end_date"])), ) def _validate_pull_latest_from_table_or_query_parameters(self, command: dict): diff --git a/sdk/python/tests/integration/compute_engines/spark/test_compute.py b/sdk/python/tests/integration/compute_engines/spark/test_compute.py index c6aef9e5701..15b6e850c65 100644 --- a/sdk/python/tests/integration/compute_engines/spark/test_compute.py +++ b/sdk/python/tests/integration/compute_engines/spark/test_compute.py @@ -161,8 +161,6 @@ def transform_feature(df: DataFrame) -> DataFrame: feature_view=driver_stats_fv, full_feature_name=False, registry=registry, - start_time=now - timedelta(days=1), - end_time=now, ) # 🧪 Run SparkComputeEngine @@ -228,7 +226,7 @@ def tqdm_builder(length): task = MaterializationTask( project=spark_environment.project, feature_view=driver_stats_fv, - start_time=now - timedelta(days=1), + start_time=now - timedelta(days=2), end_time=now, tqdm_builder=tqdm_builder, )
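
The `cast_style` options on the new `get_timestamp_filter_sql` helper in `offline_utils.py` are what let each store keep its own SQL dialect while sharing one code path. A minimal sketch of the strings it produces, assuming this patch is applied; the column name and dates below are placeholders:

```python
from datetime import datetime, timezone

from feast.infra.offline_stores.offline_utils import get_timestamp_filter_sql

start = datetime(2023, 10, 1, tzinfo=timezone.utc)
end = datetime(2023, 10, 2, tzinfo=timezone.utc)

# Default style: double-quoted column, TIMESTAMP literal (Snowflake, Redshift)
print(get_timestamp_filter_sql(start, end, "event_timestamp"))
# "event_timestamp" BETWEEN TIMESTAMP '2023-10-01T00:00:00+00:00' AND TIMESTAMP '2023-10-02T00:00:00+00:00'

# Function-style cast with an unquoted column (BigQuery)
print(
    get_timestamp_filter_sql(
        start, end, "event_timestamp", cast_style="timestamp_func", quote_fields=False
    )
)
# event_timestamp BETWEEN TIMESTAMP('2023-10-01T00:00:00+00:00') AND TIMESTAMP('2023-10-02T00:00:00+00:00')

# Open-ended range: only the lower bound is emitted
print(
    get_timestamp_filter_sql(
        start_date=start, timestamp_field="event_timestamp", quote_fields=False
    )
)
# event_timestamp >= TIMESTAMP '2023-10-01T00:00:00+00:00'
```

When both bounds are omitted, the helper returns an empty string, which is why callers interpolate it directly into their `WHERE` clause only after passing at least one bound or accept an unfiltered scan.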
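
With the dedicated historical-read nodes removed, both `LocalSourceReadNode` and `SparkReadNode` reduce a source read to a single `pull_all_from_table_or_query` call with optional bounds. The following condensed sketch is illustrative, not code from this patch: the function name and its parameters stand in for what the read nodes pull from `ExecutionContext.column_info`.

```python
from datetime import datetime
from typing import Optional


def read_batch_source(
    offline_store,
    repo_config,
    feature_view,
    column_info,
    start_time: Optional[datetime] = None,
    end_time: Optional[datetime] = None,
):
    """Illustrative condensation of what the unified read nodes do."""
    join_keys, feature_cols, ts_field, created_ts_col = column_info
    job = offline_store.pull_all_from_table_or_query(
        config=repo_config,
        data_source=feature_view.batch_source,
        join_key_columns=join_keys,
        feature_name_columns=feature_cols,
        timestamp_field=ts_field,
        created_timestamp_column=created_ts_col,  # now pulled alongside the event timestamp
        start_date=start_time,  # None means no lower bound
        end_date=end_time,      # None means no upper bound
    )
    return job.to_arrow()
```

This is also why `HistoricalRetrievalTask.start_time` and `end_time` can default to `None`: a historical retrieval without explicit bounds simply pulls the full history and relies on the downstream filter and dedup nodes.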
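
The Dask store keeps its historical half-open interval, which the TODO in `dask.py` flags as inconsistent with the closed `BETWEEN` filters used by the SQL stores. A small standalone pandas check (not part of the patch) of what `inclusive="left"` keeps:

```python
import pandas as pd

ts = pd.Series(pd.to_datetime(["2024-01-01", "2024-01-02", "2024-01-03"]))
start, end = pd.Timestamp("2024-01-01"), pd.Timestamp("2024-01-03")

# inclusive="left" keeps rows at start_date and drops rows at end_date,
# i.e. the old `>= start` / `< end` behaviour the Dask store preserves.
mask = ts.between(start, end, inclusive="left")
print(mask.tolist())  # [True, True, False]
```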