213 changes: 140 additions & 73 deletions sdk/python/feast/infra/online_stores/dynamodb.py
@@ -53,11 +53,13 @@ class DynamoDBOnlineStoreConfig(FeastConfigBaseModel):
type: Literal["dynamodb"] = "dynamodb"
"""Online store type selector"""

batch_size: int = 40
"""Number of items to retrieve in a DynamoDB BatchGetItem call."""
batch_size: int = 100
"""Number of items to retrieve in a DynamoDB BatchGetItem call.
DynamoDB supports up to 100 items per BatchGetItem request."""

endpoint_url: Union[str, None] = None
"""DynamoDB local development endpoint Url, i.e. http://localhost:8000"""
"""DynamoDB endpoint URL. Use for local development (e.g., http://localhost:8000)
or VPC endpoints for improved latency."""

region: StrictStr
"""AWS Region Name"""
@@ -74,30 +76,33 @@ class DynamoDBOnlineStoreConfig(FeastConfigBaseModel):
session_based_auth: bool = False
"""AWS session based client authentication"""

max_pool_connections: int = 10
"""Max number of connections for async Dynamodb operations"""
max_pool_connections: int = 50
"""Max number of connections for async Dynamodb operations.
Increase for high-throughput workloads."""

keepalive_timeout: float = 12.0
"""Keep-alive timeout in seconds for async Dynamodb connections."""
keepalive_timeout: float = 30.0
"""Keep-alive timeout in seconds for async Dynamodb connections.
Higher values help reuse connections under sustained load."""

connect_timeout: Union[int, float] = 60
connect_timeout: Union[int, float] = 5
"""The time in seconds until a timeout exception is thrown when attempting to make
an async connection."""
an async connection. Lower values enable faster failure detection."""

read_timeout: Union[int, float] = 60
read_timeout: Union[int, float] = 10
"""The time in seconds until a timeout exception is thrown when attempting to read
from an async connection."""
from an async connection. Lower values enable faster failure detection."""

total_max_retry_attempts: Union[int, None] = None
total_max_retry_attempts: Union[int, None] = 3
"""Maximum number of total attempts that will be made on a single request.

Maps to `retries.total_max_attempts` in botocore.config.Config.
"""

retry_mode: Union[Literal["legacy", "standard", "adaptive"], None] = None
retry_mode: Union[Literal["legacy", "standard", "adaptive"], None] = "adaptive"
"""The type of retry mode (aio)botocore should use.

Maps to `retries.mode` in botocore.config.Config.
'adaptive' mode provides intelligent retry with client-side rate limiting.
"""


@@ -111,16 +116,22 @@ class DynamoDBOnlineStore(OnlineStore):
_aioboto_session: Async boto session.
_aioboto_client: Async boto client.
_aioboto_context_stack: Async context stack.
_type_deserializer: Cached TypeDeserializer instance for performance.
"""

_dynamodb_client = None
_dynamodb_resource = None
# Class-level cached TypeDeserializer to avoid per-request instantiation
_type_deserializer: Optional[TypeDeserializer] = None

def __init__(self):
super().__init__()
self._aioboto_session = None
self._aioboto_client = None
self._aioboto_context_stack = None
# Initialize cached TypeDeserializer if not already done
if DynamoDBOnlineStore._type_deserializer is None:
DynamoDBOnlineStore._type_deserializer = TypeDeserializer()

async def initialize(self, config: RepoConfig):
online_config = config.online_store
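
As background on the cached deserializer above: boto3's `TypeDeserializer` unwraps DynamoDB's typed attribute values into plain Python objects, and re-constructing it per request is pure overhead. A small sketch of the behavior being cached:

```python
from decimal import Decimal

from boto3.dynamodb.types import TypeDeserializer

deserializer = TypeDeserializer()  # construct once, reuse everywhere

# DynamoDB returns attributes in a typed wire format; deserialize unwraps it.
assert deserializer.deserialize({"S": "driver_1001"}) == "driver_1001"
assert deserializer.deserialize({"N": "42"}) == Decimal("42")
```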
@@ -133,6 +144,7 @@ async def initialize(self, config: RepoConfig):
online_config.read_timeout,
online_config.total_max_retry_attempts,
online_config.retry_mode,
online_config.endpoint_url,
)

async def close(self):
@@ -153,6 +165,7 @@ async def _get_aiodynamodb_client(
read_timeout: Union[int, float],
total_max_retry_attempts: Union[int, None],
retry_mode: Union[Literal["legacy", "standard", "adaptive"], None],
endpoint_url: Optional[str] = None,
):
if self._aioboto_client is None:
logger.debug("initializing the aiobotocore dynamodb client")
@@ -163,16 +176,23 @@ async def _get_aiodynamodb_client(
if retry_mode is not None:
retries["mode"] = retry_mode

client_context = self._get_aioboto_session().create_client(
"dynamodb",
region_name=region,
config=AioConfig(
# Build client kwargs, including endpoint_url for VPC endpoints or local testing
client_kwargs: Dict[str, Any] = {
"region_name": region,
"config": AioConfig(
max_pool_connections=max_pool_connections,
connect_timeout=connect_timeout,
read_timeout=read_timeout,
retries=retries if retries else None,
connector_args={"keepalive_timeout": keepalive_timeout},
),
}
if endpoint_url:
client_kwargs["endpoint_url"] = endpoint_url

client_context = self._get_aioboto_session().create_client(
"dynamodb",
**client_kwargs,
)
self._aioboto_context_stack = contextlib.AsyncExitStack()
self._aioboto_client = (
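
With the new defaults in effect, the conditional build above is equivalent to passing the fully populated config below (a sketch using aiobotocore's `AioConfig`, which this module already imports):

```python
from aiobotocore.config import AioConfig

# What the kwargs construction produces when every new default applies.
aio_config = AioConfig(
    max_pool_connections=50,
    connect_timeout=5,
    read_timeout=10,
    retries={"total_max_attempts": 3, "mode": "adaptive"},
    connector_args={"keepalive_timeout": 30.0},
)
```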
@@ -431,6 +451,7 @@ async def online_write_batch_async(
online_config.read_timeout,
online_config.total_max_retry_attempts,
online_config.retry_mode,
online_config.endpoint_url,
)
await dynamo_write_items_async(client, table_name, items)

@@ -448,6 +469,7 @@ def online_read(
config: The RepoConfig for the current FeatureStore.
table: Feast FeatureView.
entity_keys: a list of entity keys that should be read from the FeatureStore.
requested_features: Optional list of feature names to retrieve.
"""
online_config = config.online_store
assert isinstance(online_config, DynamoDBOnlineStoreConfig)
@@ -479,7 +501,9 @@ def online_read(
RequestItems=batch_entity_ids,
)
batch_result = self._process_batch_get_response(
table_instance.name, response, entity_ids, batch
table_instance.name,
response,
batch,
)
result.extend(batch_result)
return result
@@ -513,7 +537,8 @@ async def online_read_async(
entity_ids_iter = iter(entity_ids)
table_name = _get_table_name(online_config, config, table)

deserialize = TypeDeserializer().deserialize
# Use cached TypeDeserializer for better performance
deserialize = self._type_deserializer.deserialize

def to_tbl_resp(raw_client_response):
return {
@@ -542,6 +567,7 @@ def to_tbl_resp(raw_client_response):
online_config.read_timeout,
online_config.total_max_retry_attempts,
online_config.retry_mode,
online_config.endpoint_url,
)
response_batches = await asyncio.gather(
*[
@@ -557,7 +583,6 @@ def to_tbl_resp(raw_client_response):
result_batch = self._process_batch_get_response(
table_name,
response,
entity_ids,
batch,
to_tbl_response=to_tbl_resp,
)
@@ -589,26 +614,6 @@ def _get_dynamodb_resource(
)
return self._dynamodb_resource

def _sort_dynamodb_response(
self,
responses: list,
order: list,
to_tbl_response: Callable = lambda raw_dict: raw_dict,
) -> Any:
"""DynamoDB Batch Get Item doesn't return items in a particular order."""
# Assign an index to order
order_with_index = {value: idx for idx, value in enumerate(order)}
# Sort table responses by index
table_responses_ordered: Any = [
(order_with_index[tbl_res["entity_id"]], tbl_res)
for tbl_res in map(to_tbl_response, responses)
]
table_responses_ordered = sorted(
table_responses_ordered, key=lambda tup: tup[0]
)
_, table_responses_ordered = zip(*table_responses_ordered)
return table_responses_ordered

def _write_batch_non_duplicates(
self,
table_instance,
@@ -630,44 +635,106 @@ def _write_batch_non_duplicates(
progress(1)

def _process_batch_get_response(
self, table_name, response, entity_ids, batch, **sort_kwargs
):
response = response.get("Responses")
table_responses = response.get(table_name)
self,
table_name: str,
response: Dict[str, Any],
batch: List[str],
to_tbl_response: Callable = lambda raw_dict: raw_dict,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
"""Process batch get response using O(1) dictionary lookup.

batch_result = []
if table_responses:
table_responses = self._sort_dynamodb_response(
table_responses, entity_ids, **sort_kwargs
)
entity_idx = 0
for tbl_res in table_responses:
entity_id = tbl_res["entity_id"]
while entity_id != batch[entity_idx]:
batch_result.append((None, None))
entity_idx += 1
res = {}
for feature_name, value_bin in tbl_res["values"].items():
DynamoDB BatchGetItem doesn't return items in a particular order,
so we use a dictionary for O(1) lookup instead of O(n log n) sorting.

This method:
- Uses dictionary lookup instead of sorting for response ordering
- Pre-allocates the result list with None values
- Minimizes object creation in the hot path

Args:
table_name: Name of the DynamoDB table
response: Raw response from DynamoDB batch_get_item
batch: List of entity_ids in the order they should be returned
to_tbl_response: Function to transform raw DynamoDB response items
(used for async client responses that need deserialization)

Returns:
List of (timestamp, features) tuples in the same order as batch
"""
responses_data = response.get("Responses")
if not responses_data:
# No responses at all, return all None tuples
return [(None, None)] * len(batch)

table_responses = responses_data.get(table_name)
if not table_responses:
# No responses for this table, return all None tuples
return [(None, None)] * len(batch)

# Build a dictionary for O(1) lookup instead of O(n log n) sorting
response_dict: Dict[str, Any] = {
tbl_res["entity_id"]: tbl_res
for tbl_res in map(to_tbl_response, table_responses)
}

# Pre-allocate result list with None tuples (faster than appending)
batch_size = len(batch)
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [
(None, None)
] * batch_size

# Process each entity in batch order using O(1) dict lookup
for idx, entity_id in enumerate(batch):
tbl_res = response_dict.get(entity_id)
if tbl_res is not None:
# Parse feature values
features: Dict[str, ValueProto] = {}
values_data = tbl_res["values"]
for feature_name, value_bin in values_data.items():
val = ValueProto()
val.ParseFromString(value_bin.value)
res[feature_name] = val
batch_result.append((datetime.fromisoformat(tbl_res["event_ts"]), res))
entity_idx += 1
# Not all entities in a batch may have responses
# Pad with remaining values in batch that were not found
batch_size_nones = ((None, None),) * (len(batch) - len(batch_result))
batch_result.extend(batch_size_nones)
return batch_result
features[feature_name] = val

# Parse timestamp and set result
result[idx] = (
datetime.fromisoformat(tbl_res["event_ts"]),
features,
)

return result
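
The ordering contract is easiest to see in isolation. A minimal sketch with hypothetical entity IDs and stand-in payloads (real items carry serialized protobuf values):

```python
batch = ["e1", "e2", "e3"]
fetched = {"e3": "features_e3", "e1": "features_e1"}  # e2 absent in DynamoDB

result = [(None, None)] * len(batch)        # pre-allocated padding
for idx, entity_id in enumerate(batch):
    hit = fetched.get(entity_id)            # O(1) lookup, no sorting pass
    if hit is not None:
        result[idx] = ("2024-01-01T00:00:00", hit)

# Batch order is preserved and the missing entity stays padded:
assert result == [
    ("2024-01-01T00:00:00", "features_e1"),
    (None, None),
    ("2024-01-01T00:00:00", "features_e3"),
]
```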

@staticmethod
def _to_entity_ids(config: RepoConfig, entity_keys: List[EntityKeyProto]):
return [
compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
for entity_key in entity_keys
]
def _to_entity_ids(
config: RepoConfig, entity_keys: List[EntityKeyProto]
) -> List[str]:
"""Convert entity keys to entity IDs with caching for repeated entities.

This method caches entity_id computations within a single request to avoid
redundant hashing for duplicate entity keys.
"""
# Use a cache to avoid recomputing entity_ids for duplicate entity keys
# The cache key is the serialized proto (which is deterministic)
entity_id_cache: Dict[bytes, str] = {}
entity_ids: List[str] = []
serialization_version = config.entity_key_serialization_version

for entity_key in entity_keys:
# Use serialized proto as cache key
cache_key = entity_key.SerializeToString()

if cache_key in entity_id_cache:
# Cache hit - reuse computed entity_id
entity_ids.append(entity_id_cache[cache_key])
else:
# Cache miss - compute and store
entity_id = compute_entity_id(
entity_key,
entity_key_serialization_version=serialization_version,
)
entity_id_cache[cache_key] = entity_id
entity_ids.append(entity_id)

return entity_ids
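
The same memoization pattern in isolation, with `hashlib` standing in for `compute_entity_id` (hypothetical keys, for illustration only):

```python
import hashlib
from typing import Dict, List

def to_ids(serialized_keys: List[bytes]) -> List[str]:
    cache: Dict[bytes, str] = {}   # one cache per request, as above
    ids: List[str] = []
    for key in serialized_keys:
        if key not in cache:
            cache[key] = hashlib.sha256(key).hexdigest()  # compute once
        ids.append(cache[key])
    return ids

keys = [b"driver:1001", b"driver:1002", b"driver:1001"]  # duplicate key
ids = to_ids(keys)
assert ids[0] == ids[2]  # duplicate hashed once, result reused
```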

@staticmethod
def _to_resource_batch_get_payload(online_config, table_name, batch):
@@ -64,10 +64,17 @@ def test_dynamodb_online_store_config_default():
aws_region = "us-west-2"
dynamodb_store_config = DynamoDBOnlineStoreConfig(region=aws_region)
assert dynamodb_store_config.type == "dynamodb"
assert dynamodb_store_config.batch_size == 40
assert dynamodb_store_config.batch_size == 100
assert dynamodb_store_config.endpoint_url is None
assert dynamodb_store_config.region == aws_region
assert dynamodb_store_config.table_name_template == "{project}.{table_name}"
# Verify other optimized defaults
assert dynamodb_store_config.max_pool_connections == 50
assert dynamodb_store_config.keepalive_timeout == 30.0
assert dynamodb_store_config.connect_timeout == 5
assert dynamodb_store_config.read_timeout == 10
assert dynamodb_store_config.total_max_retry_attempts == 3
assert dynamodb_store_config.retry_mode == "adaptive"


def test_dynamodb_online_store_config_custom_params():