From 7cfd9842a79e8aee8929ca329b99c4f4ab12be7c Mon Sep 17 00:00:00 2001
From: Aron Carroll
Date: Wed, 16 Oct 2024 07:40:06 -0700
Subject: [PATCH 1/2] Configure read timeout based on `wait` parameter (#373)

This fixes a bug in the new `wait` implementation where the default read
timeout for the HTTP client is shorter than the timeout on the server. This
results in the client erroring before the server has had the opportunity to
respond with a partial prediction.

This commit now provides a custom timeout for the `predictions.create`
request based on the `wait` parameter provided. We add a 500ms buffer to the
timeout to account for some discrepancy between server and client timings.

I attempted to refactor the shared code between models, deployments &
predictions but gave up. We now have a single function that creates the
`headers` and `timeout` params and passes them in at the various call sites.
---
 replicate/deployment.py | 52 +++++++++++++++++------------------
 replicate/model.py      | 37 ++++++++++--------------
 replicate/prediction.py | 62 ++++++++++++++++++++++++++++++-----------
 3 files changed, 86 insertions(+), 65 deletions(-)

diff --git a/replicate/deployment.py b/replicate/deployment.py
index 8dfb6e7..21fc990 100644
--- a/replicate/deployment.py
+++ b/replicate/deployment.py
@@ -8,7 +8,7 @@
 from replicate.prediction import (
     Prediction,
     _create_prediction_body,
-    _create_prediction_headers,
+    _create_prediction_request_params,
     _json_to_prediction,
 )
 from replicate.resource import Namespace, Resource
@@ -421,21 +421,25 @@ def create(
         Create a new prediction with the deployment.
         """
 
+        wait = params.pop("wait", None)
         file_encoding_strategy = params.pop("file_encoding_strategy", None)
+
         if input is not None:
             input = encode_json(
                 input,
                 client=self._client,
                 file_encoding_strategy=file_encoding_strategy,
             )
-        headers = _create_prediction_headers(wait=params.pop("wait", None))
-        body = _create_prediction_body(version=None, input=input, **params)
 
+        body = _create_prediction_body(version=None, input=input, **params)
+        extras = _create_prediction_request_params(
+            wait=wait,
+        )
         resp = self._client._request(
             "POST",
             f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions",
             json=body,
-            headers=headers,
+            **extras,
         )
 
         return _json_to_prediction(self._client, resp.json())
@@ -449,6 +453,7 @@ async def async_create(
         Create a new prediction with the deployment.
         """
 
+        wait = params.pop("wait", None)
         file_encoding_strategy = params.pop("file_encoding_strategy", None)
         if input is not None:
             input = await async_encode_json(
@@ -456,14 +461,16 @@
                 input,
                 client=self._client,
                 file_encoding_strategy=file_encoding_strategy,
             )
-        headers = _create_prediction_headers(wait=params.pop("wait", None))
-        body = _create_prediction_body(version=None, input=input, **params)
 
+        body = _create_prediction_body(version=None, input=input, **params)
+        extras = _create_prediction_request_params(
+            wait=wait,
+        )
         resp = await self._client._async_request(
             "POST",
             f"/v1/deployments/{self._deployment.owner}/{self._deployment.name}/predictions",
             json=body,
-            headers=headers,
+            **extras,
         )
 
         return _json_to_prediction(self._client, resp.json())
@@ -484,24 +491,20 @@ def create(
         Create a new prediction with the deployment.
""" - url = _create_prediction_url_from_deployment(deployment) - + wait = params.pop("wait", None) file_encoding_strategy = params.pop("file_encoding_strategy", None) + + url = _create_prediction_url_from_deployment(deployment) if input is not None: input = encode_json( input, client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(wait=params.pop("wait", None)) - body = _create_prediction_body(version=None, input=input, **params) - resp = self._client._request( - "POST", - url, - json=body, - headers=headers, - ) + body = _create_prediction_body(version=None, input=input, **params) + extras = _create_prediction_request_params(wait=wait) + resp = self._client._request("POST", url, json=body, **extras) return _json_to_prediction(self._client, resp.json()) @@ -515,9 +518,10 @@ async def async_create( Create a new prediction with the deployment. """ - url = _create_prediction_url_from_deployment(deployment) - + wait = params.pop("wait", None) file_encoding_strategy = params.pop("file_encoding_strategy", None) + + url = _create_prediction_url_from_deployment(deployment) if input is not None: input = await async_encode_json( input, @@ -525,15 +529,9 @@ async def async_create( file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(wait=params.pop("wait", None)) body = _create_prediction_body(version=None, input=input, **params) - - resp = await self._client._async_request( - "POST", - url, - json=body, - headers=headers, - ) + extras = _create_prediction_request_params(wait=wait) + resp = await self._client._async_request("POST", url, json=body, **extras) return _json_to_prediction(self._client, resp.json()) diff --git a/replicate/model.py b/replicate/model.py index 1cf144a..a52459e 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -9,7 +9,7 @@ from replicate.prediction import ( Prediction, _create_prediction_body, - _create_prediction_headers, + _create_prediction_request_params, _json_to_prediction, ) from replicate.resource import Namespace, Resource @@ -389,24 +389,20 @@ def create( Create a new prediction with the deployment. """ - url = _create_prediction_url_from_model(model) - + wait = params.pop("wait", None) file_encoding_strategy = params.pop("file_encoding_strategy", None) + + path = _create_prediction_path_from_model(model) if input is not None: input = encode_json( input, client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(wait=params.pop("wait", None)) - body = _create_prediction_body(version=None, input=input, **params) - resp = self._client._request( - "POST", - url, - json=body, - headers=headers, - ) + body = _create_prediction_body(version=None, input=input, **params) + extras = _create_prediction_request_params(wait=wait) + resp = self._client._request("POST", path, json=body, **extras) return _json_to_prediction(self._client, resp.json()) @@ -420,24 +416,21 @@ async def async_create( Create a new prediction with the deployment. 
""" - url = _create_prediction_url_from_model(model) - + wait = params.pop("wait", None) file_encoding_strategy = params.pop("file_encoding_strategy", None) + + path = _create_prediction_path_from_model(model) + if input is not None: input = await async_encode_json( input, client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(wait=params.pop("wait", None)) - body = _create_prediction_body(version=None, input=input, **params) - resp = await self._client._async_request( - "POST", - url, - json=body, - headers=headers, - ) + body = _create_prediction_body(version=None, input=input, **params) + extras = _create_prediction_request_params(wait=wait) + resp = await self._client._async_request("POST", path, json=body, **extras) return _json_to_prediction(self._client, resp.json()) @@ -522,7 +515,7 @@ def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model: return model -def _create_prediction_url_from_model( +def _create_prediction_path_from_model( model: Union[str, Tuple[str, str], "Model"], ) -> str: owner, name = None, None diff --git a/replicate/prediction.py b/replicate/prediction.py index aa3e45c..b4ff047 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -16,6 +16,7 @@ overload, ) +import httpx from typing_extensions import NotRequired, TypedDict, Unpack from replicate.exceptions import ModelError, ReplicateError @@ -446,6 +447,9 @@ def create( # type: ignore Create a new prediction for the specified model, version, or deployment. """ + wait = params.pop("wait", None) + file_encoding_strategy = params.pop("file_encoding_strategy", None) + if args: version = args[0] if len(args) > 0 else None input = args[1] if len(args) > 1 else input @@ -477,26 +481,20 @@ def create( # type: ignore **params, ) - file_encoding_strategy = params.pop("file_encoding_strategy", None) if input is not None: input = encode_json( input, client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(wait=params.pop("wait", None)) + body = _create_prediction_body( version, input, **params, ) - - resp = self._client._request( - "POST", - "/v1/predictions", - headers=headers, - json=body, - ) + extras = _create_prediction_request_params(wait=wait) + resp = self._client._request("POST", "/v1/predictions", json=body, **extras) return _json_to_prediction(self._client, resp.json()) @@ -538,6 +536,8 @@ async def async_create( # type: ignore """ Create a new prediction for the specified model, version, or deployment. 
""" + wait = params.pop("wait", None) + file_encoding_strategy = params.pop("file_encoding_strategy", None) if args: version = args[0] if len(args) > 0 else None @@ -570,25 +570,21 @@ async def async_create( # type: ignore **params, ) - file_encoding_strategy = params.pop("file_encoding_strategy", None) if input is not None: input = await async_encode_json( input, client=self._client, file_encoding_strategy=file_encoding_strategy, ) - headers = _create_prediction_headers(wait=params.pop("wait", None)) + body = _create_prediction_body( version, input, **params, ) - + extras = _create_prediction_request_params(wait=wait) resp = await self._client._async_request( - "POST", - "/v1/predictions", - headers=headers, - json=body, + "POST", "/v1/predictions", json=body, **extras ) return _json_to_prediction(self._client, resp.json()) @@ -628,6 +624,40 @@ async def async_cancel(self, id: str) -> Prediction: return _json_to_prediction(self._client, resp.json()) +class CreatePredictionRequestParams(TypedDict): + headers: NotRequired[Optional[dict]] + timeout: NotRequired[Optional[httpx.Timeout]] + + +def _create_prediction_request_params( + wait: Optional[Union[int, bool]], +) -> CreatePredictionRequestParams: + timeout = _create_prediction_timeout(wait=wait) + headers = _create_prediction_headers(wait=wait) + + return { + "headers": headers, + "timeout": timeout, + } + + +def _create_prediction_timeout( + *, wait: Optional[Union[int, bool]] = None +) -> Union[httpx.Timeout, None]: + """ + Returns an `httpx.Timeout` instances appropriate for the optional + `Prefer: wait=x` header that can be provided with the request. This + will ensure that we give the server enough time to respond with + a partial prediction in the event that the request times out. + """ + + if not wait: + return None + + read_timeout = 60.0 if isinstance(wait, bool) else wait + return httpx.Timeout(5.0, read=read_timeout + 0.5) + + def _create_prediction_headers( *, wait: Optional[Union[int, bool]] = None, From 23bd9031310d3bad08445b3d0ae6e6a72db232f6 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Wed, 16 Oct 2024 07:49:45 -0700 Subject: [PATCH 2/2] v1.0.2 (#376) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 705c9fb..8441d4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "replicate" -version = "1.0.1" +version = "1.0.2" description = "Python client for Replicate" readme = "README.md" license = { file = "LICENSE" }