Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,7 @@
"filename": "sdk/python/tests/unit/infra/utils/snowflake/test_snowflake_utils.py",
"hashed_secret": "a94a8fe5ccb19ba61c4c0873d391e987982fbbd3",
"is_verified": false,
"line_number": 10
"line_number": 14
}
],
"sdk/python/tests/unit/local_feast_tests/test_init.py": [
Expand Down Expand Up @@ -1539,5 +1539,5 @@
}
]
},
"generated_at": "2026-04-07T15:56:56Z"
"generated_at": "2026-04-09T03:30:18Z"
}
12 changes: 7 additions & 5 deletions sdk/python/feast/infra/registry/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,9 @@ def __init__(
with GetSnowflakeConnection(self.registry_config) as conn:
sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_creation.sql"
with open(sql_function_file, "r") as file:
sql_file = file.read()
sql_cmds = sql_file.split(";")
sql_cmds = [
cmd.strip() for cmd in file.read().split(";") if cmd.strip()
]
for command in sql_cmds:
query = command.replace("REGISTRY_PATH", f"{self.registry_path}")
execute_snowflake_statement(conn, query)
Expand Down Expand Up @@ -224,9 +225,10 @@ def teardown(self):
with GetSnowflakeConnection(self.registry_config) as conn:
sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_deletion.sql"
with open(sql_function_file, "r") as file:
sqlFile = file.read()
sqlCommands = sqlFile.split(";")
for command in sqlCommands:
sql_cmds = [
cmd.strip() for cmd in file.read().split(";") if cmd.strip()
]
for command in sql_cmds:
query = command.replace("REGISTRY_PATH", f"{self.registry_path}")
execute_snowflake_statement(conn, query)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
DROP TABLE IF EXISTS REGISTRY_PATH."PROJECTS";

DROP TABLE IF EXISTS REGISTRY_PATH."DATA_SOURCES";

DROP TABLE IF EXISTS REGISTRY_PATH."ENTITIES";
Expand All @@ -16,6 +18,6 @@ DROP TABLE IF EXISTS REGISTRY_PATH."SAVED_DATASETS";

DROP TABLE IF EXISTS REGISTRY_PATH."STREAM_FEATURE_VIEWS";

DROP TABLE IF EXISTS REGISTRY_PATH."VALIDATION_REFERENCES"
DROP TABLE IF EXISTS REGISTRY_PATH."VALIDATION_REFERENCES";

DROP TABLE IF EXISTS REGISTRY_PATH."PERMISSIONS"
DROP TABLE IF EXISTS REGISTRY_PATH."PERMISSIONS";
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import tempfile
from typing import Optional
from unittest.mock import MagicMock

import pytest
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

from feast.infra.utils.snowflake.snowflake_utils import parse_private_key_path
from feast.infra.utils.snowflake.snowflake_utils import (
execute_snowflake_statement,
parse_private_key_path,
)

PRIVATE_KEY_PASSPHRASE = "test"

Expand Down Expand Up @@ -69,3 +73,38 @@ def test_parse_private_key_path_key_path_encrypted(encrypted_private_key):
f.name,
None,
)


class TestExecuteSnowflakeStatement:
def test_empty_query_is_passed_through_to_execute(self):
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_executed_cursor = MagicMock()
mock_conn.cursor.return_value = mock_cursor
mock_cursor.execute.return_value = mock_executed_cursor

result = execute_snowflake_statement(mock_conn, "")

assert result is mock_executed_cursor
mock_cursor.execute.assert_called_once_with("")

def test_valid_query_executes_and_returns_cursor(self):
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_executed_cursor = MagicMock()
mock_conn.cursor.return_value = mock_cursor
mock_cursor.execute.return_value = mock_executed_cursor

result = execute_snowflake_statement(mock_conn, "SELECT 1")

assert result is mock_executed_cursor
mock_cursor.execute.assert_called_once_with("SELECT 1")

def test_valid_query_raises_on_none_cursor(self):
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_conn.cursor.return_value = mock_cursor
mock_cursor.execute.return_value = None

with pytest.raises(Exception, match="Snowflake query failed"):
execute_snowflake_statement(mock_conn, "SELECT 1")