From 8a2cfab25e2d498d6cead543594d7a2200b01a5e Mon Sep 17 00:00:00 2001 From: ntkathole Date: Thu, 18 Sep 2025 14:01:07 +0530 Subject: [PATCH] feat: Added kuberay support Signed-off-by: ntkathole --- docs/reference/compute-engine/ray.md | 17 +- docs/reference/offline-stores/ray.md | 124 +++- .../infra/compute_engines/ray/compute.py | 41 +- .../feast/infra/compute_engines/ray/config.py | 19 +- .../feast/infra/compute_engines/ray/job.py | 7 +- .../feast/infra/compute_engines/ray/nodes.py | 13 +- .../contrib/ray_offline_store/ray.py | 130 ++-- sdk/python/feast/infra/ray_initializer.py | 660 ++++++++++++++++++ sdk/python/feast/infra/ray_shared_utils.py | 17 +- .../ray_compute/ray_shared_utils.py | 10 +- .../ray_compute/test_compute.py | 1 - .../compute_engines/ray_compute/test_nodes.py | 106 +++ 12 files changed, 1000 insertions(+), 145 deletions(-) create mode 100644 sdk/python/feast/infra/ray_initializer.py diff --git a/docs/reference/compute-engine/ray.md b/docs/reference/compute-engine/ray.md index 4ecc449e40b..5547901b873 100644 --- a/docs/reference/compute-engine/ray.md +++ b/docs/reference/compute-engine/ray.md @@ -62,11 +62,22 @@ batch_engine: | `max_parallelism_multiplier` | int | 2 | Parallelism as multiple of CPU cores | | `target_partition_size_mb` | int | 64 | Target partition size (MB) | | `window_size_for_joins` | string | "1H" | Time window for distributed joins | -| `ray_address` | string | None | Ray cluster address (None = local Ray) | +| `ray_address` | string | None | Ray cluster address (triggers REMOTE mode) | +| `use_kuberay` | boolean | None | Enable KubeRay mode (overrides ray_address) | +| `kuberay_conf` | dict | None | **KubeRay configuration dict** with keys: `cluster_name` (required), `namespace` (default: "default"), `auth_token`, `auth_server`, `skip_tls` (default: false) | +| `enable_ray_logging` | boolean | false | Enable Ray progress bars and logging | | `enable_distributed_joins` | boolean | true | Enable distributed joins for large datasets | | `staging_location` | string | None | Remote path for batch materialization jobs | -| `ray_conf` | dict | None | Ray configuration parameters | -| `execution_timeout_seconds` | int | None | Timeout for job execution in seconds | +| `ray_conf` | dict | None | Ray configuration parameters (memory, CPU limits) | + +### Mode Detection Precedence + +The Ray compute engine automatically detects the execution mode: + +1. **Environment Variables** → KubeRay mode (if `FEAST_RAY_USE_KUBERAY=true`) +2. **Config `kuberay_conf`** → KubeRay mode +3. **Config `ray_address`** → Remote mode +4. **Default** → Local mode ## Usage Examples diff --git a/docs/reference/offline-stores/ray.md b/docs/reference/offline-stores/ray.md index 58f62c34ece..a46102ee132 100644 --- a/docs/reference/offline-stores/ray.md +++ b/docs/reference/offline-stores/ray.md @@ -9,10 +9,11 @@ The Ray offline store is a data I/O implementation that leverages [Ray](https:// The Ray offline store provides: - Ray-based data reading from file sources (Parquet, CSV, etc.) -- Support for both local and distributed Ray clusters +- Support for local, remote, and KubeRay (Kubernetes-managed) clusters - Integration with various storage backends (local files, S3, GCS, HDFS) - Efficient data filtering and column selection - Timestamp-based data processing with timezone awareness +- Enterprise-ready KubeRay cluster support via CodeFlare SDK ## Functionality Matrix @@ -59,9 +60,15 @@ For complex feature processing, historical feature retrieval, and distributed jo ## Configuration -The Ray offline store can be configured in your `feature_store.yaml` file. Below are two main configuration patterns: +The Ray offline store can be configured in your `feature_store.yaml` file. It supports **three execution modes**: -### Basic Ray Offline Store +1. **LOCAL**: Ray runs locally on the same machine (default) +2. **REMOTE**: Connects to a remote Ray cluster via `ray_address` +3. **KUBERAY**: Connects to Ray clusters on Kubernetes via CodeFlare SDK + +### Execution Modes + +#### Local Mode (Default) For simple data I/O operations without distributed processing: @@ -72,7 +79,44 @@ provider: local offline_store: type: ray storage_path: data/ray_storage # Optional: Path for storing datasets - ray_address: localhost:10001 # Optional: Ray cluster address +``` + +#### Remote Ray Cluster + +Connect to an existing Ray cluster: + +```yaml +offline_store: + type: ray + storage_path: s3://my-bucket/feast-data + ray_address: "ray://my-cluster.example.com:10001" +``` + +#### KubeRay Cluster (Kubernetes) + +Connect to Ray clusters on Kubernetes using CodeFlare SDK: + +```yaml +offline_store: + type: ray + storage_path: s3://my-bucket/feast-data + use_kuberay: true + kuberay_conf: + cluster_name: "feast-ray-cluster" + namespace: "feast-system" + auth_token: "${RAY_AUTH_TOKEN}" + auth_server: "https://api.openshift.com:6443" + skip_tls: false + enable_ray_logging: false +``` + +**Environment Variables** (alternative to config file): +```bash +export FEAST_RAY_USE_KUBERAY=true +export FEAST_RAY_CLUSTER_NAME=feast-ray-cluster +export FEAST_RAY_AUTH_TOKEN=your-token +export FEAST_RAY_AUTH_SERVER=https://api.openshift.com:6443 +export FEAST_RAY_NAMESPACE=feast-system ``` ### Ray Offline Store + Compute Engine @@ -175,8 +219,29 @@ batch_engine: |--------|------|---------|-------------| | `type` | string | Required | Must be `feast.offline_stores.contrib.ray_offline_store.ray.RayOfflineStore` or `ray` | | `storage_path` | string | None | Path for storing temporary files and datasets | -| `ray_address` | string | None | Address of the Ray cluster (e.g., "localhost:10001") | +| `ray_address` | string | None | Ray cluster address (triggers REMOTE mode, e.g., "ray://host:10001") | +| `use_kuberay` | boolean | None | Enable KubeRay mode (overrides ray_address) | +| `kuberay_conf` | dict | None | **KubeRay configuration dict** with keys: `cluster_name` (required), `namespace` (default: "default"), `auth_token`, `auth_server`, `skip_tls` (default: false) | +| `enable_ray_logging` | boolean | false | Enable Ray progress bars and verbose logging | | `ray_conf` | dict | None | Ray initialization parameters for resource management (e.g., memory, CPU limits) | +| `broadcast_join_threshold_mb` | int | 100 | Size threshold for broadcast joins (MB) | +| `enable_distributed_joins` | boolean | true | Enable distributed joins for large datasets | +| `max_parallelism_multiplier` | int | 2 | Parallelism as multiple of CPU cores | +| `target_partition_size_mb` | int | 64 | Target partition size (MB) | +| `window_size_for_joins` | string | "1H" | Time window for distributed joins | + +#### Mode Detection Precedence + +The Ray offline store automatically detects the execution mode using the following precedence: + +1. **Environment Variables** (highest priority) + - `FEAST_RAY_USE_KUBERAY`, `FEAST_RAY_CLUSTER_NAME`, etc. +2. **Config `kuberay_conf`** + - If present → KubeRay mode +3. **Config `ray_address`** + - If present → Remote mode +4. **Default** + - Local mode (lowest priority) #### Ray Compute Engine Options @@ -385,6 +450,8 @@ job.persist(hdfs_storage, allow_overwrite=True) ### Using Ray Cluster +#### Standard Ray Cluster + To use Ray in cluster mode for distributed data access: 1. Start a Ray cluster: @@ -406,6 +473,53 @@ offline_store: ray start --address='head-node-ip:10001' ``` +#### KubeRay Cluster (Kubernetes) + +To use Feast with Ray clusters on Kubernetes via CodeFlare SDK: + +**Prerequisites:** +- KubeRay cluster deployed on Kubernetes +- CodeFlare SDK installed: `pip install codeflare-sdk` +- Access credentials for the Kubernetes cluster + +**Configuration:** + +1. Using configuration file: +```yaml +offline_store: + type: ray + use_kuberay: true + storage_path: s3://my-bucket/feast-data + kuberay_conf: + cluster_name: "feast-ray-cluster" + namespace: "feast-system" + auth_token: "${RAY_AUTH_TOKEN}" + auth_server: "https://api.openshift.com:6443" + skip_tls: false + enable_ray_logging: false +``` + +2. Using environment variables: +```bash +export FEAST_RAY_USE_KUBERAY=true +export FEAST_RAY_CLUSTER_NAME=feast-ray-cluster +export FEAST_RAY_AUTH_TOKEN=your-k8s-token +export FEAST_RAY_AUTH_SERVER=https://api.openshift.com:6443 +export FEAST_RAY_NAMESPACE=feast-system +export FEAST_RAY_SKIP_TLS=false + +# Then use standard Feast code +python your_feast_script.py +``` + +**Features:** +- The CodeFlare SDK handles cluster connection and authentication +- Automatic TLS certificate management +- Authentication with Kubernetes clusters +- Namespace isolation +- Secure communication between client and Ray cluster +- Automatic cluster discovery + ### Data Source Validation The Ray offline store validates data sources to ensure compatibility: diff --git a/sdk/python/feast/infra/compute_engines/ray/compute.py b/sdk/python/feast/infra/compute_engines/ray/compute.py index 24d98cae7fb..a5c1b3caab5 100644 --- a/sdk/python/feast/infra/compute_engines/ray/compute.py +++ b/sdk/python/feast/infra/compute_engines/ray/compute.py @@ -2,8 +2,6 @@ from datetime import datetime from typing import Sequence, Union -import ray - from feast import ( BatchFeatureView, Entity, @@ -26,6 +24,10 @@ ) from feast.infra.compute_engines.ray.utils import write_to_online_store from feast.infra.offline_stores.offline_store import RetrievalJob +from feast.infra.ray_initializer import ( + ensure_ray_initialized, + get_ray_wrapper, +) from feast.infra.registry.base_registry import BaseRegistry logger = logging.getLogger(__name__) @@ -58,37 +60,7 @@ def __init__( def _ensure_ray_initialized(self): """Ensure Ray is initialized with proper configuration.""" - if not ray.is_initialized(): - if self.config.ray_address: - ray.init( - address=self.config.ray_address, - ignore_reinit_error=True, - include_dashboard=False, - ) - else: - ray_init_args = { - "ignore_reinit_error": True, - "include_dashboard": False, - } - - # Add configuration from ray_conf if provided - if self.config.ray_conf: - ray_init_args.update(self.config.ray_conf) - - ray.init(**ray_init_args) - - # Configure Ray context for optimal performance - from ray.data.context import DatasetContext - - ctx = DatasetContext.get_current() - ctx.enable_tensor_extension_casting = False - - # Log Ray cluster information - cluster_resources = ray.cluster_resources() - logger.info( - f"Ray cluster initialized with {cluster_resources.get('CPU', 0)} CPUs, " - f"{cluster_resources.get('memory', 0) / (1024**3):.1f}GB memory" - ) + ensure_ray_initialized(self.config) def update( self, @@ -230,7 +202,8 @@ def _materialize_from_offline_store( # Write to sink_source using Ray data try: - ray_dataset = ray.data.from_arrow(arrow_table) + ray_wrapper = get_ray_wrapper() + ray_dataset = ray_wrapper.from_arrow(arrow_table) ray_dataset.write_parquet(sink_source.path) except Exception as e: logger.error( diff --git a/sdk/python/feast/infra/compute_engines/ray/config.py b/sdk/python/feast/infra/compute_engines/ray/config.py index c6d74d262dd..bb6b63a05c5 100644 --- a/sdk/python/feast/infra/compute_engines/ray/config.py +++ b/sdk/python/feast/infra/compute_engines/ray/config.py @@ -46,9 +46,6 @@ class RayComputeEngineConfig(FeastConfigBaseModel): enable_optimization: bool = True """Enable automatic performance optimizations.""" - execution_timeout_seconds: Optional[int] = None - """Timeout for job execution in seconds.""" - @property def window_size_timedelta(self) -> timedelta: """Convert window size string to timedelta.""" @@ -64,3 +61,19 @@ def window_size_timedelta(self) -> timedelta: else: # Default to 1 hour return timedelta(hours=1) + + # KubeRay/CodeFlare SDK configurations + use_kuberay: Optional[bool] = None + """Whether to use KubeRay/CodeFlare SDK for Ray cluster management""" + + cluster_name: Optional[str] = None + """Name of the KubeRay cluster to connect to (required for KubeRay mode)""" + + auth_token: Optional[str] = None + """Authentication token for Ray cluster connection (for secure clusters)""" + + kuberay_conf: Optional[Dict[str, Any]] = None + """KubeRay/CodeFlare configuration parameters (passed to CodeFlare SDK)""" + + enable_ray_logging: bool = False + """Enable Ray progress bars and verbose logging""" diff --git a/sdk/python/feast/infra/compute_engines/ray/job.py b/sdk/python/feast/infra/compute_engines/ray/job.py index b2e88f1d5c5..06eea4e5d88 100644 --- a/sdk/python/feast/infra/compute_engines/ray/job.py +++ b/sdk/python/feast/infra/compute_engines/ray/job.py @@ -5,7 +5,6 @@ import pandas as pd import pyarrow as pa -import ray from ray.data import Dataset from feast import OnDemandFeatureView @@ -21,6 +20,7 @@ from feast.infra.compute_engines.dag.value import DAGValue from feast.infra.offline_stores.file_source import SavedDatasetFileStorage from feast.infra.offline_stores.offline_store import RetrievalJob, RetrievalMetadata +from feast.infra.ray_initializer import get_ray_wrapper from feast.repo_config import RepoConfig from feast.saved_dataset import SavedDatasetStorage @@ -69,10 +69,11 @@ def _ensure_executed(self) -> DAGValue: self._result_dataset = result.data else: # If result is not a Ray Dataset, convert it + ray_wrapper = get_ray_wrapper() if isinstance(result.data, pd.DataFrame): - self._result_dataset = ray.data.from_pandas(result.data) + self._result_dataset = ray_wrapper.from_pandas(result.data) elif isinstance(result.data, pa.Table): - self._result_dataset = ray.data.from_arrow(result.data) + self._result_dataset = ray_wrapper.from_arrow(result.data) else: raise ValueError( f"Unsupported result type: {type(result.data)}" diff --git a/sdk/python/feast/infra/compute_engines/ray/nodes.py b/sdk/python/feast/infra/compute_engines/ray/nodes.py index eaf48847113..54f75433843 100644 --- a/sdk/python/feast/infra/compute_engines/ray/nodes.py +++ b/sdk/python/feast/infra/compute_engines/ray/nodes.py @@ -23,6 +23,7 @@ write_to_online_store, ) from feast.infra.compute_engines.utils import create_offline_store_retrieval_job +from feast.infra.ray_initializer import get_ray_wrapper from feast.infra.ray_shared_utils import ( apply_field_mapping, broadcast_join, @@ -72,10 +73,12 @@ def execute(self, context: ExecutionContext) -> DAGValue: else: try: arrow_table = retrieval_job.to_arrow() - ray_dataset = ray.data.from_arrow(arrow_table) + ray_wrapper = get_ray_wrapper() + ray_dataset = ray_wrapper.from_arrow(arrow_table) except Exception: df = retrieval_job.to_df() - ray_dataset = ray.data.from_pandas(df) + ray_wrapper = get_ray_wrapper() + ray_dataset = ray_wrapper.from_pandas(df) field_mapping = getattr(self.source, "field_mapping", None) if field_mapping: @@ -130,7 +133,8 @@ def execute(self, context: ExecutionContext) -> DAGValue: entity_df = context.entity_df if isinstance(entity_df, pd.DataFrame): - entity_dataset = ray.data.from_pandas(entity_df) + ray_wrapper = get_ray_wrapper() + entity_dataset = ray_wrapper.from_pandas(entity_df) else: entity_dataset = entity_df @@ -423,7 +427,8 @@ def _fallback_pandas_aggregation(self, dataset: Dataset, agg_dict: dict) -> Data result_df = result_df.reset_index() # Convert back to Ray Dataset - return ray.data.from_pandas(result_df) + ray_wrapper = get_ray_wrapper() + return ray_wrapper.from_pandas(result_df) else: return dataset diff --git a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py index 1ec2853cf95..8d6582c7e02 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/ray.py @@ -11,7 +11,6 @@ import pandas as pd import pyarrow as pa import ray -import ray.data from ray.data import Dataset from ray.data.context import DatasetContext @@ -39,6 +38,10 @@ get_pyarrow_schema_from_batch_source, infer_event_timestamp_from_entity_df, ) +from feast.infra.ray_initializer import ( + ensure_ray_initialized, + get_ray_wrapper, +) from feast.infra.ray_shared_utils import ( _build_required_columns, apply_field_mapping, @@ -338,6 +341,19 @@ class RayOfflineStoreConfig(FeastConfigBaseModel): # Ray configuration for resource management (memory, CPU limits) ray_conf: Optional[Dict[str, Any]] = None + # KubeRay/CodeFlare SDK configurations + use_kuberay: Optional[bool] = None + """Whether to use KubeRay/CodeFlare SDK for Ray cluster management""" + + cluster_name: Optional[str] = None + """Name of the KubeRay cluster to connect to (required for KubeRay mode)""" + + auth_token: Optional[str] = None + """Authentication token for Ray cluster connection (for secure clusters)""" + + kuberay_conf: Optional[Dict[str, Any]] = None + """KubeRay/CodeFlare configuration parameters (passed to CodeFlare SDK)""" + class RayResourceManager: """ @@ -350,10 +366,18 @@ def __init__(self, config: Optional[RayOfflineStoreConfig] = None) -> None: Initialize the resource manager with cluster resource information. """ self.config = config or RayOfflineStoreConfig() + + if not ray.is_initialized(): + self.cluster_resources = {"CPU": 4, "memory": 8 * 1024**3} + self.available_memory = 8 * 1024**3 + self.available_cpus = 4 + self.num_nodes = 1 + return + self.cluster_resources = ray.cluster_resources() self.available_memory = self.cluster_resources.get("memory", 8 * 1024**3) self.available_cpus = int(self.cluster_resources.get("CPU", 4)) - self.num_nodes = len(ray.nodes()) if ray.is_initialized() else 1 + self.num_nodes = len(ray.nodes()) def configure_ray_context(self) -> None: """ @@ -919,7 +943,7 @@ def _create_metadata(self) -> RetrievalMetadata: else: try: result = self._resolve() - if isinstance(result, Dataset): + if is_ray_data(result): timestamp_col = _safe_infer_event_timestamp_column( result, "event_timestamp" ) @@ -964,11 +988,12 @@ def _get_ray_dataset(self) -> Dataset: return self._cached_dataset result = self._resolve() - if isinstance(result, Dataset): + if is_ray_data(result): self._cached_dataset = result return result elif isinstance(result, pd.DataFrame): - self._cached_dataset = ray.data.from_pandas(result) + ray_wrapper = get_ray_wrapper() + self._cached_dataset = ray_wrapper.from_pandas(result) return self._cached_dataset else: raise ValueError(f"Unsupported result type: {type(result)}") @@ -1225,82 +1250,13 @@ def _suppress_ray_logging() -> None: @staticmethod def _ensure_ray_initialized(config: Optional[RepoConfig] = None) -> None: """Ensure Ray is initialized with proper configuration.""" - ray_config = None - if config and hasattr(config, "offline_store"): - ray_config = config.offline_store - if isinstance(ray_config, RayOfflineStoreConfig): - if not ray_config.enable_ray_logging: - RayOfflineStore._suppress_ray_logging() - - if not ray.is_initialized(): - ray_init_kwargs: Dict[str, Any] = { - "ignore_reinit_error": True, - "include_dashboard": False, - } - - if ( - ray_config - and isinstance(ray_config, RayOfflineStoreConfig) - and not ray_config.enable_ray_logging - ): - ray_init_kwargs.update( - { - "log_to_driver": False, - "logging_level": "ERROR", - } - ) - - if config and hasattr(config, "offline_store"): - if isinstance(ray_config, RayOfflineStoreConfig): - if ray_config.ray_address: - ray_init_kwargs["address"] = ray_config.ray_address - else: - ray_init_kwargs.update( - { - "_node_ip_address": os.getenv( - "RAY_NODE_IP", "127.0.0.1" - ), - "num_cpus": os.cpu_count() or 4, - } - ) - - if ray_config.ray_conf: - ray_init_kwargs.update(ray_config.ray_conf) - else: - pass # Use default initialization - - ray.init(**ray_init_kwargs) - - ctx = DatasetContext.get_current() - ctx.shuffle_strategy = "sort" # type: ignore - ctx.enable_tensor_extension_casting = False - - if ( - ray_config - and isinstance(ray_config, RayOfflineStoreConfig) - and not ray_config.enable_ray_logging - ): - RayOfflineStore._suppress_ray_logging() - - if ray.is_initialized(): - cluster_resources = ray.cluster_resources() - if ( - not ray_config - or not isinstance(ray_config, RayOfflineStoreConfig) - or ray_config.enable_ray_logging - ): - logger.info( - f"Ray cluster initialized with {cluster_resources.get('CPU', 0)} CPUs, " - f"{cluster_resources.get('memory', 0) / (1024**3):.1f}GB memory" - ) + ensure_ray_initialized(config) def _init_ray(self, config: RepoConfig) -> None: ray_config = config.offline_store assert isinstance(ray_config, RayOfflineStoreConfig) - RayOfflineStore._ensure_ray_initialized(config) - if not ray_config.enable_ray_logging: - RayOfflineStore._suppress_ray_logging() + RayOfflineStore._ensure_ray_initialized(config) if self._resource_manager is None: self._resource_manager = RayResourceManager(ray_config) @@ -1378,12 +1334,13 @@ def offline_write_batch( batch_source_path = feature_view.batch_source.file_options.uri feature_path = FileSource.get_uri_for_file_path(repo_path, batch_source_path) - ds = ray.data.from_arrow(table) + ray_wrapper = get_ray_wrapper() + ds = ray_wrapper.from_arrow(table) try: if feature_path.endswith(".parquet"): if os.path.exists(feature_path): - existing_ds = ray.data.read_parquet(feature_path) + existing_ds = ray_wrapper.read_parquet(feature_path) combined_ds = existing_ds.union(ds) combined_ds.write_parquet(feature_path) else: @@ -1408,7 +1365,7 @@ def offline_write_batch( df.to_parquet(feature_path, index=False) else: os.makedirs(feature_path, exist_ok=True) - ds_fallback = ray.data.from_pandas(df) + ds_fallback = ray_wrapper.from_pandas(df) ds_fallback.write_parquet(feature_path) if progress: @@ -1776,10 +1733,11 @@ def write_logged_features( absolute_path = FileSource.get_uri_for_file_path(repo_path, destination.path) try: + ray_wrapper = get_ray_wrapper() if isinstance(data, Path): - ds = ray.data.read_parquet(str(data)) + ds = ray_wrapper.read_parquet(str(data)) else: - ds = ray.data.from_arrow(data) + ds = ray_wrapper.from_arrow(data) # Normalize feature timestamp precision to seconds to match test expectations during write # Note: Don't normalize __log_timestamp as it's used for time range filtering @@ -1831,7 +1789,8 @@ def _create_filtered_dataset( end_date: Optional[datetime] = None, ) -> Dataset: """Helper method to create a filtered dataset based on timestamp range.""" - ds = ray.data.read_parquet(source_path) + ray_wrapper = get_ray_wrapper() + ds = ray_wrapper.read_parquet(source_path) try: col_names = ds.schema().names @@ -1888,11 +1847,12 @@ def get_historical_features( store._init_ray(config) # Load entity_df as Ray dataset for distributed processing + ray_wrapper = get_ray_wrapper() if isinstance(entity_df, str): - entity_ds = ray.data.read_csv(entity_df) + entity_ds = ray_wrapper.read_csv(entity_df) entity_df_sample = entity_ds.limit(1000).to_pandas() else: - entity_ds = ray.data.from_pandas(entity_df) + entity_ds = ray_wrapper.from_pandas(entity_df) entity_df_sample = entity_df.copy() entity_ds = ensure_timestamp_compatibility(entity_ds, ["event_timestamp"]) @@ -1970,7 +1930,7 @@ def get_historical_features( # Read from the resolved data source source_path = store._get_source_path(source_info.data_source, config) - feature_ds = ray.data.read_parquet(source_path) + feature_ds = ray_wrapper.read_parquet(source_path) logger.info( f"Reading feature view {fv.name}: {source_info.source_description}" ) diff --git a/sdk/python/feast/infra/ray_initializer.py b/sdk/python/feast/infra/ray_initializer.py new file mode 100644 index 00000000000..f28d8d37aed --- /dev/null +++ b/sdk/python/feast/infra/ray_initializer.py @@ -0,0 +1,660 @@ +# Copyright 2025 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Centralized Ray Initialization Module for Feast. + +This module combines configuration management and initialization logic for a +complete, self-contained Ray setup system. +""" + +import logging +import os +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +import ray +from ray.data.context import DatasetContext + +logger = logging.getLogger(__name__) + + +class RayExecutionMode(Enum): + """Ray execution modes supported by Feast.""" + + LOCAL = "local" + REMOTE = "remote" + KUBERAY = "kuberay" + + +class RayConfigManager: + """ + Manages Ray configuration and execution mode determination. + + Supports three main scenarios: + 1. Local Ray: Single-machine development and testing + 2. Remote Ray: Connect to existing Ray standalone cluster + 3. KubeRay: Ray on Kubernetes with CodeFlare SDK + + The manager determines execution mode based on configuration precedence: + 1. Environment variable FEAST_RAY_EXECUTION_MODE (highest) + 2. KubeRay mode (use_kuberay=True or cluster_name specified) + 3. Remote mode (ray_address specified) + 4. Local mode (default fallback) + """ + + def __init__(self, config: Optional[Union[Dict[str, Any], object]] = None): + """ + Initialize Ray configuration manager. + + Args: + config: Ray configuration (RayOfflineStoreConfig, RayComputeEngineConfig, or dict) + """ + self.config = config or {} + self._execution_mode: Optional[RayExecutionMode] = None + self._codeflare_config: Optional[Dict[str, Any]] = None + + def determine_execution_mode(self) -> RayExecutionMode: + """ + Determine the appropriate Ray execution mode based on configuration. + + Precedence (highest to lowest): + 1. Environment variable FEAST_RAY_EXECUTION_MODE (explicit override) + 2. KubeRay mode (use_kuberay=True or cluster_name specified) + 3. Remote mode (ray_address specified) + 4. Local mode (default fallback) + + Returns: + RayExecutionMode enum value + """ + if self._execution_mode is not None: + return self._execution_mode + + # 1. Check environment variable override first (highest precedence) + env_mode = os.getenv("FEAST_RAY_EXECUTION_MODE", "").lower() + if env_mode in ["local", "remote", "kuberay"]: + self._execution_mode = RayExecutionMode(env_mode) + logger.info( + f"Ray execution mode set via FEAST_RAY_EXECUTION_MODE: {env_mode}" + ) + return self._execution_mode + + # 2. Check for KubeRay configuration (second highest precedence) + use_kuberay = self._get_config_value("use_kuberay") + + # Check for cluster_name in kuberay_conf + kuberay_conf = self._get_config_value("kuberay_conf", {}) or {} + cluster_name = kuberay_conf.get("cluster_name") + + # Environment variables can enable KubeRay + if os.getenv("FEAST_USE_KUBERAY", "").lower() == "true": + use_kuberay = True + if os.getenv("FEAST_RAY_CLUSTER_NAME"): + cluster_name = os.getenv("FEAST_RAY_CLUSTER_NAME") + + # KubeRay takes precedence over remote/local if configured + if use_kuberay or cluster_name: + self._execution_mode = RayExecutionMode.KUBERAY + reason = [] + if use_kuberay: + reason.append("use_kuberay=True") + if cluster_name: + reason.append(f"cluster_name='{cluster_name}'") + logger.info(f"Ray execution mode: KubeRay ({', '.join(reason)})") + return self._execution_mode + + # 3. Check for remote Ray configuration (third precedence) + ray_address = self._get_config_value("ray_address") or os.getenv("RAY_ADDRESS") + + if ray_address: + self._execution_mode = RayExecutionMode.REMOTE + logger.info(f"Ray execution mode: Remote (ray_address='{ray_address}')") + return self._execution_mode + + # 4. Default to local Ray (lowest precedence - fallback) + self._execution_mode = RayExecutionMode.LOCAL + logger.info( + "Ray execution mode: Local (default - no KubeRay or remote configuration found)" + ) + return self._execution_mode + + def get_kuberay_config(self) -> Dict[str, Any]: + """ + Get KubeRay/CodeFlare SDK configuration. + + Returns: + Dictionary of KubeRay configuration with passthrough settings + """ + if self._codeflare_config is not None: + return self._codeflare_config + + # Get passthrough configuration from kuberay_conf first + kuberay_conf = self._get_config_value("kuberay_conf", {}) or {} + + config = { + "use_kuberay": ( + os.getenv("FEAST_USE_KUBERAY", "").lower() == "true" + or self._get_config_value("use_kuberay", False) + ), + # Get values from kuberay_conf or environment variables + "cluster_name": ( + os.getenv("FEAST_RAY_CLUSTER_NAME") or kuberay_conf.get("cluster_name") + ), + "namespace": ( + os.getenv("FEAST_RAY_NAMESPACE") + or kuberay_conf.get("namespace", "default") + ), + } + + # Add authentication configuration from kuberay_conf or environment variables + auth_token = ( + os.getenv("FEAST_RAY_AUTH_TOKEN") + or os.getenv("RAY_AUTH_TOKEN") + or kuberay_conf.get("auth_token") + ) + if auth_token: + config["auth_token"] = auth_token + + # Add authentication server URL + auth_server = ( + os.getenv("FEAST_RAY_AUTH_SERVER") + or os.getenv("RAY_AUTH_SERVER") + or kuberay_conf.get("auth_server") + ) + if auth_server: + config["auth_server"] = auth_server + + # Add skip TLS verification setting + skip_tls = os.getenv( + "FEAST_RAY_SKIP_TLS", "" + ).lower() == "true" or kuberay_conf.get("skip_tls", False) + config["skip_tls"] = skip_tls + + # Add any additional configuration from kuberay_conf + for key, value in kuberay_conf.items(): + if key not in config: # Don't override already processed keys + config[key] = value + + self._codeflare_config = config + return config + + def _get_config_value(self, key: str, default: Any = None) -> Any: + """ + Get configuration value from config object or dictionary. + + Args: + key: Configuration key + default: Default value if key not found + + Returns: + Configuration value + """ + if hasattr(self.config, key): + return getattr(self.config, key) + elif isinstance(self.config, dict): + return self.config.get(key, default) + else: + return default + + +class StandardRayWrapper: + """Wrapper for Ray Native operations.""" + + def read_parquet(self, path: Union[str, List[str]]) -> Any: + """Read parquet files using standard Ray.""" + return ray.data.read_parquet(path) + + def read_csv(self, path: Union[str, List[str]]) -> Any: + """Read CSV files using standard Ray.""" + return ray.data.read_csv(path) + + def from_pandas(self, df: Any) -> Any: + """Create dataset from pandas DataFrame using standard Ray.""" + return ray.data.from_pandas(df) + + def from_arrow(self, table: Any) -> Any: + """Create dataset from Arrow table using standard Ray.""" + return ray.data.from_arrow(table) + + +class CodeFlareRayWrapper: + """Wrapper for Ray operations on KubeRay clusters using CodeFlare SDK.""" + + def __init__( + self, + cluster_name: str, + namespace: str, + auth_token: str, + auth_server: str, + skip_tls: bool = False, + enable_logging: bool = False, + ): + """Initialize CodeFlare Ray wrapper with cluster connection parameters.""" + self.cluster_name = cluster_name + self.namespace = namespace + self.auth_token = auth_token + self.auth_server = auth_server + self.skip_tls = skip_tls + self.enable_logging = enable_logging + self.cluster = None + + # Authenticate and setup Ray connection + self._authenticate_codeflare() + self._setup_ray_connection() + + def _authenticate_codeflare(self): + """Authenticate with CodeFlare SDK.""" + try: + from codeflare_sdk import TokenAuthentication + + auth = TokenAuthentication( + token=self.auth_token, + server=self.auth_server, + skip_tls=self.skip_tls, + ) + auth.login() + except Exception as e: + logger.error(f"CodeFlare authentication failed: {e}") + raise + + def _setup_ray_connection(self): + """Setup Ray connection to KubeRay cluster using TLS certificates.""" + try: + from codeflare_sdk import generate_cert, get_cluster + + self.cluster = get_cluster( + cluster_name=self.cluster_name, namespace=self.namespace + ) + if self.cluster is None: + raise RuntimeError( + f"Failed to find KubeRay cluster '{self.cluster_name}' in namespace '{self.namespace}'" + ) + generate_cert.generate_tls_cert(self.cluster_name, self.namespace) + generate_cert.export_env(self.cluster_name, self.namespace) + + cluster_uri = self.cluster.cluster_uri() + runtime_env = { + "pip": ["feast"], + "env_vars": {"RAY_DISABLE_IMPORT_WARNING": "1"}, + } + + ray.shutdown() + + logging_level = "INFO" if self.enable_logging else "ERROR" + + ray.init( + address=cluster_uri, + ignore_reinit_error=True, + logging_level=logging_level, + log_to_driver=self.enable_logging, + runtime_env=runtime_env, + ) + + logger.info(f"Ray connected successfully to cluster: {self.cluster_name}") + + except Exception as e: + logger.error(f"Ray connection failed: {e}") + raise + + # Ray Data API methods - wrapped in @ray.remote to execute on cluster workers + def read_parquet(self, path: Union[str, List[str]]) -> Any: + """Read parquet files - runs remotely on KubeRay cluster workers.""" + from feast.infra.ray_shared_utils import RemoteDatasetProxy + + @ray.remote + def _remote_read_parquet(file_path): + import ray + + return ray.data.read_parquet(file_path) + + return RemoteDatasetProxy(_remote_read_parquet.remote(path)) + + def read_csv(self, path: Union[str, List[str]]) -> Any: + """Read CSV files - runs remotely on KubeRay cluster workers.""" + from feast.infra.ray_shared_utils import RemoteDatasetProxy + + @ray.remote + def _remote_read_csv(file_path): + import ray + + return ray.data.read_csv(file_path) + + return RemoteDatasetProxy(_remote_read_csv.remote(path)) + + def from_pandas(self, df: Any) -> Any: + """Create dataset from pandas DataFrame - runs remotely on KubeRay cluster workers.""" + from feast.infra.ray_shared_utils import RemoteDatasetProxy + + @ray.remote + def _remote_from_pandas(dataframe): + import ray + + return ray.data.from_pandas(dataframe) + + return RemoteDatasetProxy(_remote_from_pandas.remote(df)) + + def from_arrow(self, table: Any) -> Any: + """Create dataset from Arrow table - runs remotely on KubeRay cluster workers.""" + from feast.infra.ray_shared_utils import RemoteDatasetProxy + + @ray.remote + def _remote_from_arrow(arrow_table): + import ray + + return ray.data.from_arrow(arrow_table) + + return RemoteDatasetProxy(_remote_from_arrow.remote(table)) + + +# Global state tracking +_ray_initialized = False +_ray_wrapper: Optional[Union[StandardRayWrapper, CodeFlareRayWrapper]] = None + + +def _suppress_ray_logging() -> None: + """Suppress Ray and Ray Data logging completely.""" + import warnings + + # Suppress Ray warnings + warnings.filterwarnings("ignore", category=DeprecationWarning, module="ray") + warnings.filterwarnings("ignore", category=UserWarning, module="ray") + + # Set environment variables to suppress Ray output + os.environ["RAY_DISABLE_IMPORT_WARNING"] = "1" + os.environ["RAY_SUPPRESS_UNVERIFIED_TLS_WARNING"] = "1" + os.environ["RAY_LOG_LEVEL"] = "ERROR" + os.environ["RAY_DATA_LOG_LEVEL"] = "ERROR" + os.environ["RAY_DISABLE_PROGRESS_BARS"] = "1" + + # Suppress all Ray-related loggers + ray_loggers = [ + "ray", + "ray.data", + "ray.data.dataset", + "ray.data.context", + "ray.data._internal.streaming_executor", + "ray.data._internal.execution", + "ray.data._internal", + "ray.tune", + "ray.serve", + "ray.util", + "ray._private", + ] + for logger_name in ray_loggers: + logging.getLogger(logger_name).setLevel(logging.ERROR) + + # Configure DatasetContext to disable progress bars + try: + ctx = DatasetContext.get_current() + ctx.enable_progress_bars = False + if hasattr(ctx, "verbose_progress"): + ctx.verbose_progress = False + except Exception: + pass # Ignore if Ray Data is not available + + +def _initialize_local_ray(config: Any, enable_logging: bool = False) -> None: + """ + Initialize Ray in local mode. + + Args: + config: Configuration object (RayOfflineStoreConfig or RayComputeEngineConfig) + enable_logging: Whether to enable Ray logging + """ + logger.info("Initializing Ray in LOCAL mode") + + ray_init_kwargs: Dict[str, Any] = { + "ignore_reinit_error": True, + "include_dashboard": False, + } + + if enable_logging: + ray_init_kwargs.update( + { + "log_to_driver": True, + "logging_level": "INFO", + } + ) + else: + ray_init_kwargs.update( + { + "log_to_driver": False, + "logging_level": "ERROR", + } + ) + _suppress_ray_logging() + + # Add local configuration + ray_init_kwargs.update( + { + "_node_ip_address": os.getenv("RAY_NODE_IP", "127.0.0.1"), + "num_cpus": os.cpu_count() or 4, + } + ) + + # Merge with user-provided ray_conf if available + if hasattr(config, "ray_conf") and config.ray_conf: + ray_init_kwargs.update(config.ray_conf) + + # Initialize Ray + ray.init(**ray_init_kwargs) + + # Configure DatasetContext + ctx = DatasetContext.get_current() + ctx.shuffle_strategy = "sort" # type: ignore + ctx.enable_tensor_extension_casting = False + + # Log cluster info + if enable_logging: + cluster_resources = ray.cluster_resources() + logger.info( + f"Ray local cluster initialized with {cluster_resources.get('CPU', 0)} CPUs, " + f"{cluster_resources.get('memory', 0) / (1024**3):.1f}GB memory" + ) + + +def _initialize_remote_ray(config: Any, enable_logging: bool = False) -> None: + """ + Initialize Ray in remote mode (connect to existing Ray cluster). + + Args: + config: Configuration object with ray_address + enable_logging: Whether to enable Ray logging + """ + ray_address = getattr(config, "ray_address", None) + if not ray_address: + ray_address = os.getenv("RAY_ADDRESS") + + if not ray_address: + raise ValueError("ray_address must be specified for remote Ray mode") + + logger.info(f"Initializing Ray in REMOTE mode, connecting to: {ray_address}") + + ray_init_kwargs: Dict[str, Any] = { + "address": ray_address, + "ignore_reinit_error": True, + "include_dashboard": False, + } + + if enable_logging: + ray_init_kwargs.update( + { + "log_to_driver": True, + "logging_level": "INFO", + } + ) + else: + ray_init_kwargs.update( + { + "log_to_driver": False, + "logging_level": "ERROR", + } + ) + _suppress_ray_logging() + + # Merge with user-provided ray_conf if available + if hasattr(config, "ray_conf") and config.ray_conf: + ray_init_kwargs.update(config.ray_conf) + + # Initialize Ray + ray.init(**ray_init_kwargs) + + # Configure DatasetContext + ctx = DatasetContext.get_current() + ctx.shuffle_strategy = "sort" # type: ignore + ctx.enable_tensor_extension_casting = False + + # Log cluster info + if enable_logging: + cluster_resources = ray.cluster_resources() + logger.info( + f"Ray remote cluster initialized with {cluster_resources.get('CPU', 0)} CPUs, " + f"{cluster_resources.get('memory', 0) / (1024**3):.1f}GB memory" + ) + + +def _initialize_kuberay(config: Any, enable_logging: bool = False) -> None: + """ + Initialize Ray in KubeRay mode using CodeFlare SDK. + + Args: + config: Configuration object with KubeRay settings + enable_logging: Whether to enable Ray logging + """ + global _ray_wrapper + + logger.info("Initializing Ray in KUBERAY mode using CodeFlare SDK") + + if not enable_logging: + _suppress_ray_logging() + + # Get KubeRay configuration + config_manager = RayConfigManager(config) + kuberay_config = config_manager.get_kuberay_config() + + # Initialize CodeFlare Ray wrapper - this connects to the cluster + _ray_wrapper = CodeFlareRayWrapper( + cluster_name=kuberay_config["cluster_name"], + namespace=kuberay_config["namespace"], + auth_token=kuberay_config["auth_token"], + auth_server=kuberay_config["auth_server"], + skip_tls=kuberay_config.get("skip_tls", False), + enable_logging=enable_logging, + ) + + logger.info("KubeRay cluster connection established via CodeFlare SDK") + + +def ensure_ray_initialized( + config: Optional[Any] = None, force_reinit: bool = False +) -> None: + """ + Ensure Ray is initialized with appropriate configuration. + + This is the main entry point for Ray initialization across all Feast components. + It automatically detects the execution mode and initializes Ray accordingly. + + Args: + config: Configuration object (RayOfflineStoreConfig, RayComputeEngineConfig, or RepoConfig) + force_reinit: If True, reinitialize Ray even if already initialized + + Raises: + ValueError: If configuration is invalid or required parameters are missing + """ + global _ray_initialized + + # Check if already initialized + if _ray_initialized and not force_reinit: + logger.debug("Ray already initialized, skipping initialization") + return + + # Extract Ray-specific config if RepoConfig is provided + ray_config = config + if config and hasattr(config, "offline_store"): + ray_config = config.offline_store + elif config and hasattr(config, "batch_engine"): + ray_config = config.batch_engine + + # Determine enable_logging setting + enable_logging = ( + getattr(ray_config, "enable_ray_logging", False) if ray_config else False + ) + + # Use RayConfigManager to determine execution mode + config_manager = RayConfigManager(ray_config) + execution_mode = config_manager.determine_execution_mode() + + logger.info(f"Ray execution mode detected: {execution_mode.value}") + + # Check if Ray is already initialized (from external source) + if ray.is_initialized() and not force_reinit: + logger.info("Ray is already initialized externally, using existing cluster") + # Configure DatasetContext even if Ray is already initialized + ctx = DatasetContext.get_current() + ctx.shuffle_strategy = "sort" # type: ignore + ctx.enable_tensor_extension_casting = False + if not enable_logging: + _suppress_ray_logging() + _ray_initialized = True + return + + # Initialize based on execution mode + try: + if execution_mode == RayExecutionMode.KUBERAY: + _initialize_kuberay(ray_config, enable_logging) + elif execution_mode == RayExecutionMode.REMOTE: + _initialize_remote_ray(ray_config, enable_logging) + else: # LOCAL + _initialize_local_ray(ray_config, enable_logging) + + _ray_initialized = True + logger.info(f"Ray initialized successfully in {execution_mode.value} mode") + + except Exception as e: + logger.error(f"Failed to initialize Ray in {execution_mode.value} mode: {e}") + raise + + +def get_ray_wrapper() -> Union[StandardRayWrapper, CodeFlareRayWrapper]: + """ + Get the appropriate Ray wrapper based on current initialization mode. + + Returns: + StandardRayWrapper for local/remote modes, CodeFlareRayWrapper for KubeRay mode + """ + global _ray_wrapper + + if _ray_wrapper is None: + # Return a standard Ray wrapper for local/remote modes + _ray_wrapper = StandardRayWrapper() + + return _ray_wrapper + + +def is_ray_initialized() -> bool: + """Check if Ray has been initialized via this module.""" + return _ray_initialized + + +def shutdown_ray() -> None: + """Shutdown Ray and reset initialization state.""" + global _ray_initialized, _ray_wrapper + + if ray.is_initialized(): + logger.info("Shutting down Ray") + ray.shutdown() + + _ray_initialized = False + _ray_wrapper = None + logger.info("Ray shutdown complete") diff --git a/sdk/python/feast/infra/ray_shared_utils.py b/sdk/python/feast/infra/ray_shared_utils.py index 9e9254fbfae..9614623294f 100644 --- a/sdk/python/feast/infra/ray_shared_utils.py +++ b/sdk/python/feast/infra/ray_shared_utils.py @@ -49,7 +49,12 @@ def to_arrow(self) -> pa.Table: @ray.remote def _remote_to_arrow(dataset): - return dataset.to_arrow() + arrow_refs = dataset.to_arrow_refs() + if arrow_refs: + tables = ray.get(arrow_refs) + return pa.concat_tables(tables) + else: + return pa.Table.from_pydict({}) result_ref = _remote_to_arrow.remote(self._dataset_ref) return ray.get(result_ref) @@ -124,6 +129,16 @@ def _remote_take(dataset, num): result_ref = _remote_take.remote(self._dataset_ref, n) return ray.get(result_ref) + def size_bytes(self) -> int: + """Execute size_bytes remotely and return result.""" + + @ray.remote + def _remote_size_bytes(dataset): + return dataset.size_bytes() + + result_ref = _remote_size_bytes.remote(self._dataset_ref) + return ray.get(result_ref) + def __getattr__(self, name): """Catch any method calls that we haven't explicitly implemented.""" raise AttributeError(f"RemoteDatasetProxy has no attribute '{name}'") diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/ray_shared_utils.py b/sdk/python/tests/integration/compute_engines/ray_compute/ray_shared_utils.py index 9e9aabc4f90..6b28949f401 100644 --- a/sdk/python/tests/integration/compute_engines/ray_compute/ray_shared_utils.py +++ b/sdk/python/tests/integration/compute_engines/ray_compute/ray_shared_utils.py @@ -9,10 +9,10 @@ import pandas as pd import pytest -import ray from feast import Entity, FileSource from feast.data_source import DataSource +from feast.infra.ray_initializer import shutdown_ray from feast.utils import _utc_now from tests.integration.feature_repos.repo_configuration import ( construct_test_environment, @@ -126,8 +126,7 @@ def cleanup_ray_environment(ray_environment): # Ensure Ray is shut down completely try: - if ray.is_initialized(): - ray.shutdown() + shutdown_ray() time.sleep(0.2) # Brief pause to ensure clean shutdown except Exception as e: print(f"Warning: Ray shutdown failed: {e}") @@ -147,9 +146,8 @@ def create_ray_environment(): def ray_environment() -> Generator: """Pytest fixture to provide a Ray environment for tests with automatic cleanup.""" try: - if ray.is_initialized(): - ray.shutdown() - time.sleep(0.2) + shutdown_ray() + time.sleep(0.2) except Exception: pass diff --git a/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py b/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py index e7060b4a756..d6b06c9c080 100644 --- a/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py +++ b/sdk/python/tests/integration/compute_engines/ray_compute/test_compute.py @@ -162,7 +162,6 @@ def test_ray_compute_engine_config(): window_size_for_joins="2H", max_workers=4, enable_optimization=True, - execution_timeout_seconds=3600, ) assert config.type == "ray.engine" diff --git a/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py index e8c40d43099..0da4fcb956e 100644 --- a/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py +++ b/sdk/python/tests/unit/infra/compute_engines/ray_compute/test_nodes.py @@ -1,4 +1,6 @@ +import os from datetime import datetime, timedelta +from unittest.mock import patch import pandas as pd import pytest @@ -18,6 +20,12 @@ RayReadNode, RayTransformationNode, ) +from feast.infra.ray_initializer import ( + RayConfigManager, + RayExecutionMode, + ensure_ray_initialized, + get_ray_wrapper, +) class DummyInputNode(DAGNode): @@ -317,3 +325,101 @@ def test_ray_config_validation(): # Test invalid window size defaults to 1 hour config_invalid = RayComputeEngineConfig(window_size_for_joins="invalid") assert config_invalid.window_size_timedelta == timedelta(hours=1) + + +def test_ray_initialization_and_kuberay_modes(): + """ + Comprehensive test for Ray initialization modes and KubeRay configuration. + + Tests: Mode detection (LOCAL/REMOTE/KUBERAY), config parsing, defaults, + environment variables, mode precedence, and Ray wrapper instantiation. + """ + # Test LOCAL mode (default) + config_local = RayComputeEngineConfig() + assert ( + RayConfigManager(config_local).determine_execution_mode() + == RayExecutionMode.LOCAL + ) + + # Test REMOTE mode + config_remote = RayComputeEngineConfig(ray_address="ray://localhost:10001") + manager_remote = RayConfigManager(config_remote) + assert manager_remote.determine_execution_mode() == RayExecutionMode.REMOTE + # Test execution mode caching + assert manager_remote.determine_execution_mode() == RayExecutionMode.REMOTE + + # Test KUBERAY mode with full config + config_kuberay = RayComputeEngineConfig( + use_kuberay=True, + kuberay_conf={ + "cluster_name": "feast-cluster", + "namespace": "feast-system", + "auth_token": "test-token", + "auth_server": "https://api.example.com", + "skip_tls": True, + }, + ) + manager_kuberay = RayConfigManager(config_kuberay) + assert manager_kuberay.determine_execution_mode() == RayExecutionMode.KUBERAY + kuberay_config = manager_kuberay.get_kuberay_config() + assert kuberay_config["cluster_name"] == "feast-cluster" + assert kuberay_config["namespace"] == "feast-system" + assert kuberay_config["auth_token"] == "test-token" + assert kuberay_config["skip_tls"] is True + + # Test KubeRay defaults + config_defaults = RayComputeEngineConfig( + use_kuberay=True, kuberay_conf={"cluster_name": "test-cluster"} + ) + defaults_config = RayConfigManager(config_defaults).get_kuberay_config() + assert defaults_config["namespace"] == "default" + assert defaults_config["skip_tls"] is False + + # Test mode precedence - KUBERAY overrides REMOTE + config_precedence = RayComputeEngineConfig( + ray_address="ray://localhost:10001", + use_kuberay=True, + kuberay_conf={"cluster_name": "test-cluster"}, + ) + assert ( + RayConfigManager(config_precedence).determine_execution_mode() + == RayExecutionMode.KUBERAY + ) + + # Test environment variable support + with patch.dict( + os.environ, + { + "FEAST_RAY_CLUSTER_NAME": "env-cluster", + "FEAST_RAY_NAMESPACE": "env-namespace", + "FEAST_RAY_AUTH_TOKEN": "env-token", + }, + ): + env_config = RayConfigManager( + RayComputeEngineConfig(use_kuberay=True, kuberay_conf={}) + ).get_kuberay_config() + assert env_config["cluster_name"] == "env-cluster" + assert env_config["namespace"] == "env-namespace" + assert env_config["auth_token"] == "env-token" + + # Test Ray wrapper instantiation + from feast.infra.ray_initializer import StandardRayWrapper + + wrapper = get_ray_wrapper() + assert isinstance(wrapper, StandardRayWrapper) + + config_custom = RayComputeEngineConfig( + enable_ray_logging=True, + max_workers=4, + broadcast_join_threshold_mb=200, + ray_conf={"num_cpus": 4}, + ) + assert config_custom.enable_ray_logging is True + assert config_custom.max_workers == 4 + assert config_custom.broadcast_join_threshold_mb == 200 + assert config_custom.ray_conf["num_cpus"] == 4 + + with patch("feast.infra.ray_initializer.ray") as mock_ray: + mock_ray.is_initialized.return_value = True + ensure_ray_initialized(config_local) + mock_ray.init.assert_not_called()