From 86c5e14b1f7356574de913a20958160c854f3b1c Mon Sep 17 00:00:00 2001
From: Aron Carroll
Date: Mon, 28 Oct 2024 06:31:51 -0700
Subject: [PATCH 1/4] Document the FileOutput object in the README (#388)

---
 README.md | 193 ++++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 144 insertions(+), 49 deletions(-)

diff --git a/README.md b/README.md
index 64a7253d..8c535275 100644
--- a/README.md
+++ b/README.md
@@ -40,46 +40,7 @@ replacing the model identifier and input with your own:
    input={"prompt": "a 19th century portrait of a wombat gentleman"}
)

-['https://replicate.com/api/models/stability-ai/stable-diffusion/files/50fcac81-865d-499e-81ac-49de0cb79264/out-0.png']
-```
-
-> [!TIP]
-> You can also use the Replicate client asynchronously by prepending `async_` to the method name.
->
-> Here's an example of how to run several predictions concurrently and wait for them all to complete:
->
-> ```python
-> import asyncio
-> import replicate
->
-> # https://replicate.com/stability-ai/sdxl
-> model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
-> prompts = [
->     f"A chariot pulled by a team of {count} rainbow unicorns"
->     for count in ["two", "four", "six", "eight"]
-> ]
->
-> async with asyncio.TaskGroup() as tg:
->     tasks = [
->         tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
->         for prompt in prompts
->     ]
->
-> results = await asyncio.gather(*tasks)
-> print(results)
-> ```
-
-To run a model that takes a file input you can pass either
-a URL to a publicly accessible file on the Internet
-or a handle to a file on your local device.
-
-```python
->>> output = replicate.run(
-        "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
-        input={ "image": open("path/to/mystery.jpg") }
-    )
-
-"an astronaut riding a horse"
+[<replicate.helpers.FileOutput object at 0x...>]
```

`replicate.run` raises `ModelError` if the prediction fails.
@@ -99,6 +60,55 @@ except ModelError as e
    print("Failed prediction: " + e.prediction.id)
```

+> [!NOTE]
+> By default the Replicate client will hold the connection open for up to 60 seconds while waiting
+> for the prediction to complete. This is designed to optimize getting the model output back to the
+> client as quickly as possible. For models that output files the file data will be inlined into
+> the response as a data-uri.
+>
+> The timeout can be configured by passing `wait=x` to `replicate.run()` where `x` is a timeout
+> in seconds between 1 and 60. To disable the sync mode and the data-uri response you can pass
+> `wait=False` to `replicate.run()`.
+
+## AsyncIO support
+
+You can also use the Replicate client asynchronously by prepending `async_` to the method name.
+
+Here's an example of how to run several predictions concurrently and wait for them all to complete:
+
+```python
+import asyncio
+import replicate
+
+# https://replicate.com/stability-ai/sdxl
+model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
+prompts = [
+    f"A chariot pulled by a team of {count} rainbow unicorns"
+    for count in ["two", "four", "six", "eight"]
+]
+
+async with asyncio.TaskGroup() as tg:
+    tasks = [
+        tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
+        for prompt in prompts
+    ]
+
+results = await asyncio.gather(*tasks)
+print(results)
+```
+
+To run a model that takes a file input you can pass either
+a URL to a publicly accessible file on the Internet
+or a handle to a file on your local device.
+
+```python
+>>> output = replicate.run(
+        "andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
+        input={ "image": open("path/to/mystery.jpg", "rb") }
+    )
+
+"an astronaut riding a horse"
+```

## Run a model and stream its output
@@ -176,7 +186,7 @@ iteration: 30, render:loss: -1.3994140625
'succeeded'

>>> prediction.output
-'https://.../output.png'
+<replicate.helpers.FileOutput object at 0x...>
```

## Run a model in the background and get a webhook
@@ -217,8 +227,9 @@ iterator = replicate.run(
    input={"prompts": "san francisco sunset"}
)

-for image in iterator:
-    display(image)
+for index, image in enumerate(iterator):
+    with open(f"file_{index}.png", "wb") as file:
+        file.write(image.read())
```

## Cancel a prediction
@@ -263,20 +274,104 @@ if page1.next:

## Load output files

-Output files are returned as HTTPS URLs. You can load an output file as a buffer:
+Output files are returned as `FileOutput` objects:

```python
import replicate
-from PIL import Image
-from urllib.request import urlretrieve
+from PIL import Image # pip install pillow

-out = replicate.run(
+output = replicate.run(
    "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
    input={"prompt": "wavy colorful abstract patterns, oceans"}
)

-urlretrieve(out[0], "/tmp/out.png")
-background = Image.open("/tmp/out.png")
+# This has a .read() method that returns the binary data.
+with open("my_output.png", "wb") as file:
+    file.write(output[0].read())
+
+# Being file-like, it can also be passed directly to libraries such as Pillow.
+background = Image.open(output[0])
+```
+
+### FileOutput
+
+`FileOutput` is a file-like object returned from the `replicate.run()` method that makes it easier to work with models
+that output files. It implements `Iterator` and `AsyncIterator` for reading the file data in chunks as well
+as `read()` and `aread()` to read the entire file into memory.
+
+Lastly, the underlying data source is available on the `url` attribute.
+
+> [!NOTE]
+> The `url` attribute can vary between a remote URL and a data-uri depending on whether the server has
+> optimized the request. For small files (<5MB) using the synchronous API, data-uris will be returned to
+> remove the need to make subsequent requests for the file data. To disable this pass `wait=False`
+> to the `replicate.run()` function.
+
+To access the file URL:
+
+```python
+print(output.url) #=> "data:image/png;base64,xyz123..." or "https://delivery.replicate.com/..."
+```
+
+To consume the file directly:
+
+```python
+with open('output.bin', 'wb') as file:
+    file.write(output.read())
+```
+
+Or, for very large files, the data can be streamed in chunks:
+
+```python
+with open(file_path, 'wb') as file:
+    for chunk in output:
+        file.write(chunk)
+```
+
+Each of these methods has an equivalent `asyncio` API.
+
+```python
+import aiofiles # pip install aiofiles
+
+async with aiofiles.open(filename, 'wb') as file:
+    await file.write(await output.aread())
+
+async with aiofiles.open(filename, 'wb') as file:
+    async for chunk in output:
+        await file.write(chunk)
+```
+
+Common web frameworks can stream `FileOutput` responses directly, since they all accept `Iterator` types:
+
+**Django**
+
+```python
+@condition(etag_func=None)
+def stream_response(request):
+    output = replicate.run("black-forest-labs/flux-schnell", input={...}, use_file_output=True)
+    return HttpResponse(output, content_type='image/webp')
+```
+
+**FastAPI**
+
+```python
+@app.get("/")
+async def main():
+    output = replicate.run("black-forest-labs/flux-schnell", input={...}, use_file_output=True)
+    return StreamingResponse(output)
+```
+
+**Flask**
+
+```python
+@app.route('/stream')
+def streamed_response():
+    output = replicate.run("black-forest-labs/flux-schnell", input={...}, use_file_output=True)
+    return app.response_class(stream_with_context(output))
+```
+
+You can opt out of `FileOutput` by passing `use_file_output=False` to the `replicate.run()` method.
+
+```python
+output = replicate.run("acmecorp/acme-model", use_file_output=False)
+```

## List models

From 4fdd78fb6ffe3f2e5d388d109ba12612dbc5552c Mon Sep 17 00:00:00 2001
From: Aron Carroll
Date: Wed, 30 Oct 2024 04:32:11 -0700
Subject: [PATCH 2/4] Remove mention of data URLs from README (#391)

This commit incorporates the changes from #386 as well as removing mention
of the data URLs added in 86c5e14.

---
 README.md | 77 ++++++++++++++++++++++++++++---------------------------
 1 file changed, 39 insertions(+), 38 deletions(-)

diff --git a/README.md b/README.md
index 8c535275..ff119975 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,21 @@
This is a Python client for [Replicate](https://replicate.com). It lets you run models from your Python code or Jupyter notebook, and do various other things on Replicate.

+## Breaking Changes in 1.0.0
+
+The 1.0.0 release contains breaking changes:
+
+- The `replicate.run()` method now returns `FileOutput`s instead of URL strings by default for models that output files. `FileOutput` implements an iterable interface similar to `httpx.Response`, making it easier to work with files efficiently.
+
+To revert to the previous behavior, you can opt out of `FileOutput` by passing `use_file_output=False` to `replicate.run()`:
+
+```python
+output = replicate.run("acmecorp/acme-model", use_file_output=False)
+```
+
+In most cases, updating existing applications to call `output.url` should resolve any issues. But we recommend using the `FileOutput` objects directly, as we have further improvements planned for this API and this approach is guaranteed to give the fastest results.
+
+> [!TIP]
>
> [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1K91q4p-OhL96FHBAVLsv9FlwFdu6Pn3c)
@@ -30,17 +45,18 @@ We recommend not adding the token directly to your source code, because you don'

## Run a model

-Create a new Python file and add the following code,
-replacing the model identifier and input with your own:
+Create a new Python file and add the following code, replacing the model identifier and input with your own:

```python
>>> import replicate
->>> replicate.run(
-    "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478",
-    input={"prompt": "a 19th century portrait of a wombat gentleman"}
+>>> outputs = replicate.run(
+    "black-forest-labs/flux-schnell",
+    input={"prompt": "astronaut riding a rocket like a horse"}
)

-[<replicate.helpers.FileOutput object at 0x...>]
+>>> for index, output in enumerate(outputs):
+        with open(f"output_{index}.webp", "wb") as file:
+            file.write(output.read())
```

`replicate.run` raises `ModelError` if the prediction fails.
@@ -63,12 +79,10 @@ except ModelError as e

> [!NOTE]
> By default the Replicate client will hold the connection open for up to 60 seconds while waiting
> for the prediction to complete. This is designed to optimize getting the model output back to the
-> client as quickly as possible. For models that output files the file data will be inlined into
-> the response as a data-uri.
+> client as quickly as possible.
>
> The timeout can be configured by passing `wait=x` to `replicate.run()` where `x` is a timeout
-> in seconds between 1 and 60. To disable the sync mode and the data-uri response you can pass
-> `wait=False` to `replicate.run()`.
+> in seconds between 1 and 60. To disable the sync mode you can pass `wait=False`.

## AsyncIO support
@@ -152,7 +166,7 @@ For more information, see

## Run a model in the background

-You can start a model and run it in the background:
+You can start a model and run it in the background using async mode:

```python
>>> model = replicate.models.get("kvfrans/clipdraw")
@@ -187,6 +201,9 @@ iteration: 30, render:loss: -1.3994140625

>>> prediction.output
<replicate.helpers.FileOutput object at 0x...>
+
+>>> with open("output.png", "wb") as file:
+        file.write(prediction.output.read())
```

## Run a model in the background and get a webhook
@@ -295,19 +312,12 @@ background = Image.open(output[0])

### FileOutput

-`FileOutput` is a file-like object returned from the `replicate.run()` method that makes it easier to work with models
-that output files. It implements `Iterator` and `AsyncIterator` for reading the file data in chunks as well
-as `read()` and `aread()` to read the entire file into memory.
-
-Lastly, the underlying data source is available on the `url` attribute.
+`FileOutput` is a [file-like](https://docs.python.org/3/glossary.html#term-file-object) object returned from the `replicate.run()` method that makes it easier to work with models that output files. It implements `Iterator` and `AsyncIterator` for reading the file data in chunks as well as `read()` and `aread()` to read the entire file into memory.

> [!NOTE]
-> The `url` attribute can vary between a remote URL and a data-uri depending on whether the server has
-> optimized the request. For small files (<5MB) using the synchronous API, data-uris will be returned to
-> remove the need to make subsequent requests for the file data. To disable this pass `wait=False`
-> to the `replicate.run()` function.
+> Note that at this time `read()` and `aread()` do not accept a `size` argument to read up to `size` bytes.
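+>
+> Until then, a bounded read can be approximated through the iterator protocol. A minimal sketch (assuming `output` is a `FileOutput`; chunk sizes are chosen by the client, so the cap is approximate):
+>
+> ```python
+> max_bytes = 1024 * 1024  # stop after roughly 1 MiB
+> buffer = bytearray()
+> for chunk in output:
+>     buffer.extend(chunk)
+>     if len(buffer) >= max_bytes:
+>         break
+> ```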
-To access the file URL:
+Lastly, the URL of the underlying data source is available on the `url` attribute, though we recommend you use the object as an iterator or use its `read()` or `aread()` methods, as the `url` property may not always return HTTP URLs in the future.

```python
print(output.url) #=> "data:image/png;base64,xyz123..." or "https://delivery.replicate.com/..."
```
@@ -439,13 +449,9 @@ Here's how to list of all the available hardware for running models on Replicate

## Fine-tune a model

-Use the [training API](https://replicate.com/docs/fine-tuning)
-to fine-tune models to make them better at a particular task.
-To see what **language models** currently support fine-tuning,
-check out Replicate's [collection of trainable language models](https://replicate.com/collections/trainable-language-models).
+Use the [training API](https://replicate.com/docs/fine-tuning) to fine-tune models to make them better at a particular task. To see what **language models** currently support fine-tuning, check out Replicate's [collection of trainable language models](https://replicate.com/collections/trainable-language-models).

-If you're looking to fine-tune **image models**,
-check out Replicate's [guide to fine-tuning image models](https://replicate.com/docs/guides/fine-tune-an-image-model).
+If you're looking to fine-tune **image models**, check out Replicate's [guide to fine-tuning image models](https://replicate.com/docs/guides/fine-tune-an-image-model).

Here's how to fine-tune a model on Replicate:
@@ -467,24 +473,19 @@ training = replicate.trainings.create(

## Customize client behavior

-The `replicate` package exports a default shared client.
-This client is initialized with an API token
-set by the `REPLICATE_API_TOKEN` environment variable.
+The `replicate` package exports a default shared client. This client is initialized with an API token set by the `REPLICATE_API_TOKEN` environment variable.

-You can create your own client instance to
-pass a different API token value,
-add custom headers to requests,
-or control the behavior of the underlying [HTTPX client](https://www.python-httpx.org/api/#client):
+You can create your own client instance to pass a different API token value, add custom headers to requests, or control the behavior of the underlying [HTTPX client](https://www.python-httpx.org/api/#client):

```python
import os
from replicate.client import Client

replicate = Client(
-    api_token=os.environ["SOME_OTHER_REPLICATE_API_TOKEN"]
-    headers={
-        "User-Agent": "my-app/1.0"
-    }
+    api_token=os.environ["SOME_OTHER_REPLICATE_API_TOKEN"],
+    headers={
+        "User-Agent": "my-app/1.0"
+    }
)
```

From 07c8fbbb21c14b360e825208f4b1632154d8a458 Mon Sep 17 00:00:00 2001
From: Aron Carroll
Date: Fri, 15 Nov 2024 13:41:31 +0000
Subject: [PATCH 3/4] Fix a couple of bugs in the base64 file_encoding_strategy (#398)

This commit adds tests for the `file_encoding_strategy` argument for
`replicate.run()` and fixes two bugs that surfaced:

1. `replicate.run()` would convert the file provided into base64 encoded
   data but not a valid data URL. We now use the `base64_encode_file`
   function used for outputs.
2. `replicate.async_run()` accepted but did not use the
   `file_encoding_strategy` flag at all. This is fixed, though it is worth
   noting that `base64_encode_file` is not optimized for async workflows
   and will block. This might be okay as the file sizes expected for data
   URL payloads should be very small.
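To illustrate the first bug: a bare base64 payload is not a usable file input,
because a data URL also needs the `data:` scheme and a media type. A minimal
sketch of the difference (the `to_data_url` helper below is hypothetical,
standing in for the library's `base64_encode_file`):

```python
import base64
import io

def to_data_url(file: io.IOBase, mime_type: str = "application/octet-stream") -> str:
    # Wrap the base64 payload in a data URL with a scheme and media type.
    encoded = base64.b64encode(file.read()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"

buf = io.BytesIO(b"hello world")
print(base64.b64encode(buf.read()).decode("utf-8"))
#=> "aGVsbG8gd29ybGQ=" (base64 alone -- what the old code produced)
buf.seek(0)
print(to_data_url(buf))
#=> "data:application/octet-stream;base64,aGVsbG8gd29ybGQ="
```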
--- replicate/helpers.py | 10 +++- tests/test_run.py | 129 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 3 deletions(-) diff --git a/replicate/helpers.py b/replicate/helpers.py index e0bada5d..c6ac9072 100644 --- a/replicate/helpers.py +++ b/replicate/helpers.py @@ -43,7 +43,7 @@ def encode_json( return encode_json(file, client, file_encoding_strategy) if isinstance(obj, io.IOBase): if file_encoding_strategy == "base64": - return base64.b64encode(obj.read()).decode("utf-8") + return base64_encode_file(obj) else: return client.files.create(obj).urls["get"] if HAS_NUMPY: @@ -77,9 +77,13 @@ async def async_encode_json( ] if isinstance(obj, Path): with obj.open("rb") as file: - return encode_json(file, client, file_encoding_strategy) + return await async_encode_json(file, client, file_encoding_strategy) if isinstance(obj, io.IOBase): - return (await client.files.async_create(obj)).urls["get"] + if file_encoding_strategy == "base64": + # TODO: This should ideally use an async based file reader path. + return base64_encode_file(obj) + else: + return (await client.files.async_create(obj)).urls["get"] if HAS_NUMPY: if isinstance(obj, np.integer): # type: ignore return int(obj) diff --git a/tests/test_run.py b/tests/test_run.py index beb7f6e2..93f7248b 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,5 +1,10 @@ import asyncio +import io +import json import sys +from email.message import EmailMessage +from email.parser import BytesParser +from email.policy import HTTP from typing import AsyncIterator, Iterator, Optional, cast import httpx @@ -581,6 +586,130 @@ async def test_run_with_model_error(mock_replicate_api_token): assert excinfo.value.prediction.status == "failed" +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_run_with_file_input_files_api(async_flag, mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + mock_predictions_create = router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status("processing"), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 200, + json=_version_with_schema(), + ) + ) + mock_files_create = router.route( + method="POST", + path="/files", + ).mock( + return_value=httpx.Response( + 200, + json={ + "id": "file1", + "name": "file.png", + "content_type": "image/png", + "size": 10, + "etag": "123", + "checksums": {}, + "metadata": {}, + "created_at": "", + "expires_at": "", + "urls": {"get": "https://api.replicate.com/files/file.txt"}, + }, + ) + ) + router.route(host="api.replicate.com").pass_through() + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + if async_flag: + await client.async_run( + "test/example:v1", + input={"file": io.BytesIO(initial_bytes=b"hello world")}, + ) + else: + client.run( + "test/example:v1", + input={"file": io.BytesIO(initial_bytes=b"hello world")}, + ) + + assert mock_predictions_create.called + prediction_payload = json.loads(mock_predictions_create.calls[0].request.content) + assert ( + prediction_payload.get("input", {}).get("file") + == "https://api.replicate.com/files/file.txt" + ) + + # Validate the Files API request + req = mock_files_create.calls[0].request + body = req.content + content_type = req.headers["Content-Type"] + + # Parse the multipart data + parser = BytesParser(EmailMessage, policy=HTTP) + headers = 
f"Content-Type: {content_type}\n\n".encode() + parsed_message_generator = parser.parsebytes(headers + body).walk() + next(parsed_message_generator) # wrapper + input_file = next(parsed_message_generator) + assert mock_files_create.called + assert input_file.get_content() == b"hello world" + assert input_file.get_content_type() == "application/octet-stream" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_run_with_file_input_data_url(async_flag, mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + mock_predictions_create = router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status("processing"), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 200, + json=_version_with_schema(), + ) + ) + router.route(host="api.replicate.com").pass_through() + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + + if async_flag: + await client.async_run( + "test/example:v1", + input={"file": io.BytesIO(initial_bytes=b"hello world")}, + file_encoding_strategy="base64", + ) + else: + client.run( + "test/example:v1", + input={"file": io.BytesIO(initial_bytes=b"hello world")}, + file_encoding_strategy="base64", + ) + + assert mock_predictions_create.called + prediction_payload = json.loads(mock_predictions_create.calls[0].request.content) + assert ( + prediction_payload.get("input", {}).get("file") + == "data:application/octet-stream;base64,aGVsbG8gd29ybGQ=" + ) + + @pytest.mark.asyncio async def test_run_with_file_output(mock_replicate_api_token): router = respx.Router(base_url="https://api.replicate.com/v1") From 461ec70f566b0b993a94966ea73aa80a03eb5d60 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Mon, 25 Nov 2024 12:35:11 +0000 Subject: [PATCH 4/4] Bump version to 1.0.4 (#401) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c5c7876a..975f7d92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "replicate" -version = "1.0.3" +version = "1.0.4" description = "Python client for Replicate" readme = "README.md" license = { file = "LICENSE" }