diff --git a/.secrets.baseline b/.secrets.baseline index 96bf780809c..bdbf199f662 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -1426,7 +1426,7 @@ "filename": "sdk/python/tests/unit/infra/utils/snowflake/test_snowflake_utils.py", "hashed_secret": "a94a8fe5ccb19ba61c4c0873d391e987982fbbd3", "is_verified": false, - "line_number": 10 + "line_number": 14 } ], "sdk/python/tests/unit/local_feast_tests/test_init.py": [ @@ -1539,5 +1539,5 @@ } ] }, - "generated_at": "2026-04-07T15:56:56Z" + "generated_at": "2026-04-09T03:30:18Z" } diff --git a/sdk/python/feast/infra/registry/snowflake.py b/sdk/python/feast/infra/registry/snowflake.py index 68bf376e650..248051ff16c 100644 --- a/sdk/python/feast/infra/registry/snowflake.py +++ b/sdk/python/feast/infra/registry/snowflake.py @@ -139,8 +139,9 @@ def __init__( with GetSnowflakeConnection(self.registry_config) as conn: sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_creation.sql" with open(sql_function_file, "r") as file: - sql_file = file.read() - sql_cmds = sql_file.split(";") + sql_cmds = [ + cmd.strip() for cmd in file.read().split(";") if cmd.strip() + ] for command in sql_cmds: query = command.replace("REGISTRY_PATH", f"{self.registry_path}") execute_snowflake_statement(conn, query) @@ -224,9 +225,10 @@ def teardown(self): with GetSnowflakeConnection(self.registry_config) as conn: sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_deletion.sql" with open(sql_function_file, "r") as file: - sqlFile = file.read() - sqlCommands = sqlFile.split(";") - for command in sqlCommands: + sql_cmds = [ + cmd.strip() for cmd in file.read().split(";") if cmd.strip() + ] + for command in sql_cmds: query = command.replace("REGISTRY_PATH", f"{self.registry_path}") execute_snowflake_statement(conn, query) diff --git a/sdk/python/feast/infra/utils/snowflake/registry/snowflake_table_deletion.sql b/sdk/python/feast/infra/utils/snowflake/registry/snowflake_table_deletion.sql index 780424abd17..dde984c3823 100644 --- a/sdk/python/feast/infra/utils/snowflake/registry/snowflake_table_deletion.sql +++ b/sdk/python/feast/infra/utils/snowflake/registry/snowflake_table_deletion.sql @@ -1,3 +1,5 @@ +DROP TABLE IF EXISTS REGISTRY_PATH."PROJECTS"; + DROP TABLE IF EXISTS REGISTRY_PATH."DATA_SOURCES"; DROP TABLE IF EXISTS REGISTRY_PATH."ENTITIES"; @@ -16,6 +18,6 @@ DROP TABLE IF EXISTS REGISTRY_PATH."SAVED_DATASETS"; DROP TABLE IF EXISTS REGISTRY_PATH."STREAM_FEATURE_VIEWS"; -DROP TABLE IF EXISTS REGISTRY_PATH."VALIDATION_REFERENCES" +DROP TABLE IF EXISTS REGISTRY_PATH."VALIDATION_REFERENCES"; -DROP TABLE IF EXISTS REGISTRY_PATH."PERMISSIONS" +DROP TABLE IF EXISTS REGISTRY_PATH."PERMISSIONS"; diff --git a/sdk/python/tests/unit/infra/utils/snowflake/test_snowflake_utils.py b/sdk/python/tests/unit/infra/utils/snowflake/test_snowflake_utils.py index 8ae6ec63ba5..14b42c9783b 100644 --- a/sdk/python/tests/unit/infra/utils/snowflake/test_snowflake_utils.py +++ b/sdk/python/tests/unit/infra/utils/snowflake/test_snowflake_utils.py @@ -1,11 +1,15 @@ import tempfile from typing import Optional +from unittest.mock import MagicMock import pytest from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa -from feast.infra.utils.snowflake.snowflake_utils import parse_private_key_path +from feast.infra.utils.snowflake.snowflake_utils import ( + execute_snowflake_statement, + parse_private_key_path, +) PRIVATE_KEY_PASSPHRASE = "test" @@ -69,3 +73,38 @@ def test_parse_private_key_path_key_path_encrypted(encrypted_private_key): f.name, None, ) + + +class TestExecuteSnowflakeStatement: + def test_empty_query_is_passed_through_to_execute(self): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_executed_cursor = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_cursor.execute.return_value = mock_executed_cursor + + result = execute_snowflake_statement(mock_conn, "") + + assert result is mock_executed_cursor + mock_cursor.execute.assert_called_once_with("") + + def test_valid_query_executes_and_returns_cursor(self): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_executed_cursor = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_cursor.execute.return_value = mock_executed_cursor + + result = execute_snowflake_statement(mock_conn, "SELECT 1") + + assert result is mock_executed_cursor + mock_cursor.execute.assert_called_once_with("SELECT 1") + + def test_valid_query_raises_on_none_cursor(self): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_conn.cursor.return_value = mock_cursor + mock_cursor.execute.return_value = None + + with pytest.raises(Exception, match="Snowflake query failed"): + execute_snowflake_statement(mock_conn, "SELECT 1")