diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml
index 70abce32dfe..d82d882ff26 100644
--- a/.github/workflows/integration_tests.yml
+++ b/.github/workflows/integration_tests.yml
@@ -16,6 +16,16 @@ jobs:
     env:
       OS: ${{ matrix.os }}
       PYTHON: ${{ matrix.python-version }}
+    services:
+      redis:
+        image: redis
+        ports:
+          - 6379:6379
+        options: >-
+          --health-cmd "redis-cli ping"
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
     steps:
       - uses: actions/checkout@v2
       - name: Setup Python
@@ -51,4 +61,4 @@ jobs:
           flags: integrationtests
           env_vars: OS,PYTHON
           fail_ci_if_error: true
-          verbose: true
\ No newline at end of file
+          verbose: true
diff --git a/.github/workflows/pr_integration_tests.yml b/.github/workflows/pr_integration_tests.yml
index 8172c62b30d..240d87069ec 100644
--- a/.github/workflows/pr_integration_tests.yml
+++ b/.github/workflows/pr_integration_tests.yml
@@ -18,6 +18,9 @@ jobs:
       matrix:
         python-version: [ 3.7, 3.8, 3.9 ]
         os: [ ubuntu-latest ]
+    env:
+      OS: ${{ matrix.os }}
+      PYTHON: ${{ matrix.python-version }}
     services:
       redis:
         image: redis
@@ -28,9 +31,6 @@
           --health-interval 10s
           --health-timeout 5s
           --health-retries 5
-    env:
-      OS: ${{ matrix.os }}
-      PYTHON: ${{ matrix.python-version }}
     steps:
       - uses: actions/checkout@v2
         with:
@@ -64,9 +64,6 @@
         run: make install-python-ci-dependencies
       - name: Test python
         run: FEAST_TELEMETRY=False pytest --cov=./ --cov-report=xml --verbose --color=yes sdk/python/tests --integration
-        env:
-          REDIS_TYPE: REDIS
-          REDIS_CONNECTION_STRING: localhost:6379,db=0
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v1
         with:
diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py
index 89825d82057..5e5a1f5ec18 100644
--- a/sdk/python/feast/feature_store.py
+++ b/sdk/python/feast/feature_store.py
@@ -541,7 +541,10 @@ def get_online_features(
                 table, union_of_entity_keys, entity_name_to_join_key_map
             )
             read_rows = provider.online_read(
-                project=self.project, table=table, entity_keys=entity_keys,
+                project=self.project,
+                table=table,
+                entity_keys=entity_keys,
+                requested_features=requested_features,
             )
             for row_idx, read_row in enumerate(read_rows):
                 row_ts, feature_data = read_row
diff --git a/sdk/python/feast/infra/gcp.py b/sdk/python/feast/infra/gcp.py
index 7299dd99f1a..dce073e1d39 100644
--- a/sdk/python/feast/infra/gcp.py
+++ b/sdk/python/feast/infra/gcp.py
@@ -128,6 +128,7 @@ def online_read(
         project: str,
         table: Union[FeatureTable, FeatureView],
         entity_keys: List[EntityKeyProto],
+        requested_features: List[str] = None,
     ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
         client = self._initialize_client()
 
diff --git a/sdk/python/feast/infra/local.py b/sdk/python/feast/infra/local.py
index f9516dc7819..3827e6aea60 100644
--- a/sdk/python/feast/infra/local.py
+++ b/sdk/python/feast/infra/local.py
@@ -131,6 +131,7 @@ def online_read(
         project: str,
         table: Union[FeatureTable, FeatureView],
         entity_keys: List[EntityKeyProto],
+        requested_features: List[str] = None,
     ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
         conn = self._get_conn()
 
diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py
index 05dac141c85..6a378d0759c 100644
--- a/sdk/python/feast/infra/provider.py
+++ b/sdk/python/feast/infra/provider.py
@@ -125,6 +125,7 @@ def online_read(
         project: str,
         table: Union[FeatureTable, FeatureView],
         entity_keys: List[EntityKeyProto],
+        requested_features: List[str] = None,
     ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
         """
         Read feature values given an Entity Key. This is a low level interface, not
@@ -144,6 +145,10 @@ def get_provider(config: RepoConfig, repo_path: Path) -> Provider:
         from feast.infra.gcp import GcpProvider
 
         return GcpProvider(config)
+    elif config.provider == "redis":
+        from feast.infra.redis import RedisProvider
+
+        return RedisProvider(config)
     elif config.provider == "local":
         from feast.infra.local import LocalProvider
 
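The new `requested_features` argument is threaded from `FeatureStore.get_online_features` down through every provider. It defaults to `None`, meaning "fetch everything": the GCP and local providers (and the test `FooProvider` further down) simply accept and ignore it, while the Redis store below uses it to read only the named hash fields. A dependency-free sketch of the intended semantics — the plain-dict store is a hypothetical stand-in, not a Feast type:

```python
# Toy illustration of the online_read contract: None/empty requested_features
# means "all features"; otherwise only the named subset is returned. This
# mirrors the fallback in RedisProvider.online_read below.
from typing import Dict, List, Optional


def read_features(
    stored: Dict[str, float], requested_features: Optional[List[str]] = None
) -> Dict[str, float]:
    if not requested_features:  # same fallback the Redis provider uses
        requested_features = list(stored)
    return {name: stored[name] for name in requested_features}


assert read_features({"lat": 1.0, "lon": 2.0}, ["lat"]) == {"lat": 1.0}
assert read_features({"lat": 1.0, "lon": 2.0}) == {"lat": 1.0, "lon": 2.0}
```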
diff --git a/sdk/python/feast/infra/redis.py b/sdk/python/feast/infra/redis.py
new file mode 100644
index 00000000000..f4200918b3f
--- /dev/null
+++ b/sdk/python/feast/infra/redis.py
@@ -0,0 +1,281 @@
+import json
+import struct
+from datetime import datetime
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import mmh3
+import pandas as pd
+from google.protobuf.timestamp_pb2 import Timestamp
+
+try:
+    from redis import Redis
+    from rediscluster import RedisCluster
+except ImportError as e:
+    from feast.errors import FeastExtrasDependencyImportError
+
+    raise FeastExtrasDependencyImportError("redis", str(e))
+
+from tqdm import tqdm
+
+from feast import FeatureTable, utils
+from feast.entity import Entity
+from feast.feature_view import FeatureView
+from feast.infra.offline_stores.helpers import get_offline_store_from_config
+from feast.infra.provider import (
+    Provider,
+    RetrievalJob,
+    _convert_arrow_to_proto,
+    _get_column_names,
+    _run_field_mapping,
+)
+from feast.protos.feast.storage.Redis_pb2 import RedisKeyV2 as RedisKeyProto
+from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
+from feast.protos.feast.types.Value_pb2 import Value as ValueProto
+from feast.registry import Registry
+from feast.repo_config import RedisOnlineStoreConfig, RedisType, RepoConfig
+
+EX_SECONDS = 253402300799
+
+
+class RedisProvider(Provider):
+    _redis_type: Optional[RedisType]
+    _connection_string: str
+
+    def __init__(self, config: RepoConfig):
+        assert isinstance(config.online_store, RedisOnlineStoreConfig)
+        if config.online_store.redis_type:
+            self._redis_type = config.online_store.redis_type
+        if config.online_store.connection_string:
+            self._connection_string = config.online_store.connection_string
+        self.offline_store = get_offline_store_from_config(config.offline_store)
+
+    def update_infra(
+        self,
+        project: str,
+        tables_to_delete: Sequence[Union[FeatureTable, FeatureView]],
+        tables_to_keep: Sequence[Union[FeatureTable, FeatureView]],
+        entities_to_delete: Sequence[Entity],
+        entities_to_keep: Sequence[Entity],
+        partial: bool,
+    ):
+        pass
+
+    def teardown_infra(
+        self,
+        project: str,
+        tables: Sequence[Union[FeatureTable, FeatureView]],
+        entities: Sequence[Entity],
+    ) -> None:
+        # According to repo_operations.py, we can delete the whole project here
+        client = self._get_client()
+
+        tables_join_keys = [[e for e in t.entities] for t in tables]
+        for table_join_keys in tables_join_keys:
+            redis_key_bin = _redis_key(
+                project, EntityKeyProto(join_keys=table_join_keys)
+            )
+            keys = [k for k in client.scan_iter(match=f"{redis_key_bin}*", count=100)]
+            if keys:
+                client.unlink(*keys)
+
+    def online_write_batch(
+        self,
+        project: str,
+        table: Union[FeatureTable, FeatureView],
+        data: List[
+            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
+        ],
+        progress: Optional[Callable[[int], Any]],
+    ) -> None:
+        client = self._get_client()
+
+        entity_hset = {}
+        feature_view = table.name
+
+        ex = Timestamp()
+        ex.seconds = EX_SECONDS
+        ex_str = ex.SerializeToString()
+
+        for entity_key, values, timestamp, created_ts in data:
+            redis_key_bin = _redis_key(project, entity_key)
+            ts = Timestamp()
+            ts.seconds = int(utils.make_tzaware(timestamp).timestamp())
+            entity_hset[f"_ts:{feature_view}"] = ts.SerializeToString()
+            entity_hset[f"_ex:{feature_view}"] = ex_str
+
+            for feature_name, val in values.items():
+                f_key = _mmh3(f"{feature_view}:{feature_name}")
+                entity_hset[f_key] = val.SerializeToString()
+
+            client.hset(redis_key_bin, mapping=entity_hset)
+            if progress:
+                progress(1)
+
+    def online_read(
+        self,
+        project: str,
+        table: Union[FeatureTable, FeatureView],
+        entity_keys: List[EntityKeyProto],
+        requested_features: List[str] = None,
+    ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
+
+        client = self._get_client()
+        feature_view = table.name
+
+        result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
+
+        if not requested_features:
+            requested_features = [f.name for f in table.features]
+
+        for entity_key in entity_keys:
+            redis_key_bin = _redis_key(project, entity_key)
+            hset_keys = [_mmh3(f"{feature_view}:{k}") for k in requested_features]
+            ts_key = f"_ts:{feature_view}"
+            hset_keys.append(ts_key)
+            values = client.hmget(redis_key_bin, hset_keys)
+            requested_features.append(ts_key)
+            res_val = dict(zip(requested_features, values))
+
+            res_ts = Timestamp()
+            ts_val = res_val.pop(ts_key)
+            if ts_val:
+                res_ts.ParseFromString(ts_val)
+
+            res = {}
+            for feature_name, val_bin in res_val.items():
+                val = ValueProto()
+                if val_bin:
+                    val.ParseFromString(val_bin)
+                res[feature_name] = val
+
+            if not res:
+                result.append((None, None))
+            else:
+                timestamp = datetime.fromtimestamp(res_ts.seconds)
+                result.append((timestamp, res))
+        return result
+
+    def materialize_single_feature_view(
+        self,
+        feature_view: FeatureView,
+        start_date: datetime,
+        end_date: datetime,
+        registry: Registry,
+        project: str,
+        tqdm_builder: Callable[[int], tqdm],
+    ) -> None:
+        entities = []
+        for entity_name in feature_view.entities:
+            entities.append(registry.get_entity(entity_name, project))
+
+        (
+            join_key_columns,
+            feature_name_columns,
+            event_timestamp_column,
+            created_timestamp_column,
+        ) = _get_column_names(feature_view, entities)
+
+        start_date = utils.make_tzaware(start_date)
+        end_date = utils.make_tzaware(end_date)
+
+        table = self.offline_store.pull_latest_from_table_or_query(
+            data_source=feature_view.input,
+            join_key_columns=join_key_columns,
+            feature_name_columns=feature_name_columns,
+            event_timestamp_column=event_timestamp_column,
+            created_timestamp_column=created_timestamp_column,
+            start_date=start_date,
+            end_date=end_date,
+        )
+
+        if feature_view.input.field_mapping is not None:
+            table = _run_field_mapping(table, feature_view.input.field_mapping)
+
+        join_keys = [entity.join_key for entity in entities]
+        rows_to_write = _convert_arrow_to_proto(table, feature_view, join_keys)
+
+        with tqdm_builder(len(rows_to_write)) as pbar:
+            self.online_write_batch(
+                project, feature_view, rows_to_write, lambda x: pbar.update(x)
+            )
+
+        feature_view.materialization_intervals.append((start_date, end_date))
+        registry.apply_feature_view(feature_view, project)
+
+    def _parse_connection_string(self):
+        """
+        Parses the Redis connection string, which uses this format
+        for RedisCluster:
+            redis1:6379,redis2:6379,decode_responses=true,skip_full_coverage_check=true,ssl=true,password=...
+        and this format for Redis:
+            redis_master:6379,db=0,ssl=true,password=...
+        """
+        connection_string = self._connection_string
+        startup_nodes = [
+            dict(zip(["host", "port"], c.split(":")))
+            for c in connection_string.split(",")
+            if "=" not in c
+        ]
+        params = {}
+        for c in connection_string.split(","):
+            if "=" in c:
+                kv = c.split("=")
+                try:
+                    kv[1] = json.loads(kv[1])
+                except json.JSONDecodeError:
+                    ...
+
+                it = iter(kv)
+                params.update(dict(zip(it, it)))
+
+        return startup_nodes, params
+
+    def _get_client(self):
+        """
+        Creates the Redis client (RedisCluster or Redis), depending on configuration.
+        """
+        startup_nodes, kwargs = self._parse_connection_string()
+        if self._redis_type == RedisType.redis_cluster:
+            kwargs["startup_nodes"] = startup_nodes
+            return RedisCluster(**kwargs)
+        else:
+            kwargs["host"] = startup_nodes[0]["host"]
+            kwargs["port"] = startup_nodes[0]["port"]
+            return Redis(**kwargs)
+
+    def get_historical_features(
+        self,
+        config: RepoConfig,
+        feature_views: List[FeatureView],
+        feature_refs: List[str],
+        entity_df: Union[pd.DataFrame, str],
+        registry: Registry,
+        project: str,
+    ) -> RetrievalJob:
+        return self.offline_store.get_historical_features(
+            config=config,
+            feature_views=feature_views,
+            feature_refs=feature_refs,
+            entity_df=entity_df,
+            registry=registry,
+            project=project,
+        )
+
+
+def _redis_key(project: str, entity_key: EntityKeyProto):
+    redis_key = RedisKeyProto(
+        project=project,
+        entity_names=entity_key.join_keys,
+        entity_values=entity_key.entity_values,
+    )
+    return redis_key.SerializeToString()
+
+
+def _mmh3(key: str):
+    """
+    Calculate a murmur3_32 hash that matches the Scala version, which uses little endian:
+        https://stackoverflow.com/questions/29932956/murmur3-hash-different-result-between-python-and-java-implementation
+        https://stackoverflow.com/questions/13141787/convert-decimal-int-to-little-endian-string-x-x
+    """
+    key_hash = mmh3.hash(key, signed=False)
+    return bytes.fromhex(struct.pack("<Q", key_hash).hex()[:8])
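Two review notes on this file: `entity_hset` is initialized once outside the write loop, so an entity row missing some features would silently inherit the previous row's serialized values, and `online_read` appends `ts_key` to the caller's `requested_features` list on every iteration, mutating its input across entities. On the storage layout itself: each entity maps to one Redis hash whose key is a serialized `RedisKeyV2` proto and whose field names are 4-byte little-endian murmur3_32 digests of `"<feature_view>:<feature_name>"`, with the event timestamp stored alongside under the plain-text field `"_ts:<feature_view>"`. A standalone replica of the field-key scheme, assuming only `mmh3` is installed (the view and feature names are illustrative):

```python
# Replica of _mmh3 above: murmur3_32 of "view:feature", packed little-endian
# and truncated to the first 4 bytes, giving a compact, stable hash field.
import struct

import mmh3


def field_key(feature_view: str, feature_name: str) -> bytes:
    key_hash = mmh3.hash(f"{feature_view}:{feature_name}", signed=False)
    return bytes.fromhex(struct.pack("<Q", key_hash).hex()[:8])


# Stable across runs and matching the Scala implementation's byte order.
assert len(field_key("driver_locations", "lat")) == 4
```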
diff --git a/setup.py b/setup.py
--- a/setup.py
+++ b/setup.py
@@ -54,6 +54,10 @@
     "google-cloud-core==1.4.*",
 ]
 
+REDIS_REQUIRED = [
+    "redis-py-cluster==2.1.2",
+]
+
 CI_REQUIRED = [
     "cryptography==3.3.2",
     "flake8",
@@ -83,6 +87,7 @@
     "google-cloud-datastore>=2.1.*",
     "google-cloud-storage>=1.20.*",
     "google-cloud-core==1.4.*",
+    "redis-py-cluster==2.1.2",
 ]
 
 # README file from Feast repo root directory
@@ -192,6 +197,7 @@ def run(self):
         "dev": ["mypy-protobuf==1.*", "grpcio-testing==1.*"],
         "ci": CI_REQUIRED,
         "gcp": GCP_REQUIRED,
+        "redis": REDIS_REQUIRED,
     },
     include_package_data=True,
     license="Apache",
diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py
index b8705eef99c..b313ad9cc7d 100644
--- a/sdk/python/tests/foo_provider.py
+++ b/sdk/python/tests/foo_provider.py
@@ -70,6 +70,7 @@ def online_read(
         project: str,
         table: Union[FeatureTable, FeatureView],
         entity_keys: List[EntityKeyProto],
+        requested_features: List[str] = None,
     ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
         pass
 
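One typing nit that applies to every signature touched in this PR: the new parameter is annotated `List[str] = None`, which strict mypy rejects because `None` is not a `List[str]`. A stricter spelling, sketched for the abstract `Provider.online_read` (a suggestion, not part of the diff):

```python
# Suggested spelling only; mirrors the abstract signature in provider.py.
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union

from feast import FeatureTable
from feast.feature_view import FeatureView
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto


def online_read(
    self,
    project: str,
    table: Union[FeatureTable, FeatureView],
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    ...
```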
+ """ + connection_string = self._connection_string + startup_nodes = [ + dict(zip(["host", "port"], c.split(":"))) + for c in connection_string.split(",") + if "=" not in c + ] + params = {} + for c in connection_string.split(","): + if "=" in c: + kv = c.split("=") + try: + kv[1] = json.loads(kv[1]) + except json.JSONDecodeError: + ... + + it = iter(kv) + params.update(dict(zip(it, it))) + + return startup_nodes, params + + def _get_client(self): + """ + Creates the Redis client RedisCluster or Redis depending on configuration + """ + startup_nodes, kwargs = self._parse_connection_string() + if self._redis_type == RedisType.redis_cluster: + kwargs["startup_nodes"] = startup_nodes + return RedisCluster(**kwargs) + else: + kwargs["host"] = startup_nodes[0]["host"] + kwargs["port"] = startup_nodes[0]["port"] + return Redis(**kwargs) + + def get_historical_features( + self, + config: RepoConfig, + feature_views: List[FeatureView], + feature_refs: List[str], + entity_df: Union[pd.DataFrame, str], + registry: Registry, + project: str, + ) -> RetrievalJob: + return self.offline_store.get_historical_features( + config=config, + feature_views=feature_views, + feature_refs=feature_refs, + entity_df=entity_df, + registry=registry, + project=project, + ) + + +def _redis_key(project: str, entity_key: EntityKeyProto): + redis_key = RedisKeyProto( + project=project, + entity_names=entity_key.join_keys, + entity_values=entity_key.entity_values, + ) + return redis_key.SerializeToString() + + +def _mmh3(key: str): + """ + Calculate murmur3_32 hash which is equal to scala version which is using little endian: + https://stackoverflow.com/questions/29932956/murmur3-hash-different-result-between-python-and-java-implementation + https://stackoverflow.com/questions/13141787/convert-decimal-int-to-little-endian-string-x-x + """ + key_hash = mmh3.hash(key, signed=False) + return bytes.fromhex(struct.pack("=2.1.*", "google-cloud-storage>=1.20.*", "google-cloud-core==1.4.*", + "redis-py-cluster==2.1.2", ] # README file from Feast repo root directory @@ -192,6 +197,7 @@ def run(self): "dev": ["mypy-protobuf==1.*", "grpcio-testing==1.*"], "ci": CI_REQUIRED, "gcp": GCP_REQUIRED, + "redis": REDIS_REQUIRED, }, include_package_data=True, license="Apache", diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index b8705eef99c..b313ad9cc7d 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -70,6 +70,7 @@ def online_read( project: str, table: Union[FeatureTable, FeatureView], entity_keys: List[EntityKeyProto], + requested_features: List[str] = None, ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: pass diff --git a/sdk/python/tests/test_cli_redis.py b/sdk/python/tests/test_cli_redis.py new file mode 100644 index 00000000000..ed330461559 --- /dev/null +++ b/sdk/python/tests/test_cli_redis.py @@ -0,0 +1,60 @@ +import random +import string +import tempfile +from pathlib import Path +from textwrap import dedent + +import pytest + +from feast.feature_store import FeatureStore +from tests.cli_utils import CliRunner +from tests.online_read_write_test import basic_rw_test + + +@pytest.mark.integration +def test_basic() -> None: + project_id = "".join( + random.choice(string.ascii_lowercase + string.digits) for _ in range(10) + ) + runner = CliRunner() + with tempfile.TemporaryDirectory() as repo_dir_name, tempfile.TemporaryDirectory() as data_dir_name: + + repo_path = Path(repo_dir_name) + data_path = Path(data_dir_name) + + repo_config = 
repo_path / "feature_store.yaml" + + repo_config.write_text( + dedent( + f""" + project: {project_id} + registry: {data_path / "registry.db"} + provider: redis + offline_store: + type: bigquery + online_store: + redis_type: redis + connection_string: localhost:6379,db=0 + """ + ) + ) + + repo_example = repo_path / "example.py" + repo_example.write_text( + (Path(__file__).parent / "example_feature_repo_1.py").read_text() + ) + + result = runner.run(["apply"], cwd=repo_path) + assert result.returncode == 0 + + # Doing another apply should be a no op, and should not cause errors + result = runner.run(["apply"], cwd=repo_path) + assert result.returncode == 0 + + basic_rw_test( + FeatureStore(repo_path=str(repo_path), config=None), + view_name="driver_locations", + ) + + result = runner.run(["teardown"], cwd=repo_path) + assert result.returncode == 0 diff --git a/sdk/python/tests/test_offline_online_store_consistency.py b/sdk/python/tests/test_offline_online_store_consistency.py index c39fb427541..b6d2e399e05 100644 --- a/sdk/python/tests/test_offline_online_store_consistency.py +++ b/sdk/python/tests/test_offline_online_store_consistency.py @@ -19,6 +19,8 @@ from feast.feature_view import FeatureView from feast.repo_config import ( DatastoreOnlineStoreConfig, + RedisOnlineStoreConfig, + RedisType, RepoConfig, SqliteOnlineStoreConfig, ) @@ -146,6 +148,42 @@ def prep_local_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]: yield fs, fv +@contextlib.contextmanager +def prep_redis_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]: + with tempfile.NamedTemporaryFile(suffix=".parquet") as f: + df = create_dataset() + f.close() + df.to_parquet(f.name) + file_source = FileSource( + file_format=ParquetFormat(), + file_url=f"file://{f.name}", + event_timestamp_column="ts", + created_timestamp_column="created_ts", + date_partition_column="", + field_mapping={"ts_1": "ts", "id": "driver_id"}, + ) + fv = get_feature_view(file_source) + e = Entity( + name="driver", + description="id for driver", + join_key="driver_id", + value_type=ValueType.INT32, + ) + with tempfile.TemporaryDirectory() as repo_dir_name, tempfile.TemporaryDirectory(): + config = RepoConfig( + registry=str(Path(repo_dir_name) / "registry.db"), + project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}", + provider="redis", + online_store=RedisOnlineStoreConfig( + redis_type=RedisType.redis, connection_string="localhost:6379,db=0", + ), + ) + fs = FeatureStore(config=config) + fs.apply([fv, e]) + + yield fs, fv + + # Checks that both offline & online store values are as expected def check_offline_and_online_features( fs: FeatureStore, @@ -221,6 +259,12 @@ def test_bq_offline_online_store_consistency(bq_source_type: str): run_offline_online_store_consistency_test(fs, fv) +@pytest.mark.integration +def test_redis_offline_online_store_consistency(): + with prep_redis_fs_and_fv() as (fs, fv): + run_offline_online_store_consistency_test(fs, fv) + + def test_local_offline_online_store_consistency(): with prep_local_fs_and_fv() as (fs, fv): run_offline_online_store_consistency_test(fs, fv)