diff --git a/replicate/prediction.py b/replicate/prediction.py index b4ff047..d83d926 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -55,7 +55,7 @@ class Prediction(Resource): version: str """An identifier for the version of the model used to create the prediction.""" - status: Literal["starting", "processing", "succeeded", "failed", "canceled"] + status: Literal["starting", "processing", "succeeded", "failed", "canceled", "aborted"] """The status of the prediction.""" input: Optional[Dict[str, Any]] @@ -141,7 +141,7 @@ def wait(self) -> None: Wait for prediction to finish. """ - while self.status not in ["succeeded", "failed", "canceled"]: + while self.status not in ["succeeded", "failed", "canceled", "aborted"]: time.sleep(self._client.poll_interval) self.reload() @@ -150,7 +150,7 @@ async def async_wait(self) -> None: Wait for prediction to finish asynchronously. """ - while self.status not in ["succeeded", "failed", "canceled"]: + while self.status not in ["succeeded", "failed", "canceled", "aborted"]: await asyncio.sleep(self._client.poll_interval) await self.async_reload() @@ -251,7 +251,7 @@ def output_iterator(self) -> Iterator[Any]: # TODO: check output is list previous_output = self.output or [] - while self.status not in ["succeeded", "failed", "canceled"]: + while self.status not in ["succeeded", "failed", "canceled", "aborted"]: output = self.output or [] new_output = output[len(previous_output) :] yield from new_output @@ -259,7 +259,7 @@ def output_iterator(self) -> Iterator[Any]: time.sleep(self._client.poll_interval) # pylint: disable=no-member self.reload() - if self.status == "failed": + if self.status in ("failed", "aborted"): raise ModelError(self) output = self.output or [] @@ -273,7 +273,7 @@ async def async_output_iterator(self) -> AsyncIterator[Any]: # TODO: check output is list previous_output = self.output or [] - while self.status not in ["succeeded", "failed", "canceled"]: + while self.status not in ["succeeded", "failed", "canceled", "aborted"]: output = self.output or [] new_output = output[len(previous_output) :] for item in new_output: @@ -282,7 +282,7 @@ async def async_output_iterator(self) -> AsyncIterator[Any]: await asyncio.sleep(self._client.poll_interval) # pylint: disable=no-member await self.async_reload() - if self.status == "failed": + if self.status in ("failed", "aborted"): raise ModelError(self) output = self.output or [] diff --git a/replicate/run.py b/replicate/run.py index e82ffb4..89a3fc4 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -77,7 +77,7 @@ def run( prediction.wait() - if prediction.status == "failed": + if prediction.status in ("failed", "aborted"): raise ModelError(prediction) # Return an iterator for the completed prediction when needed. @@ -147,7 +147,7 @@ async def async_run( await prediction.async_wait() - if prediction.status == "failed": + if prediction.status in ("failed", "aborted"): raise ModelError(prediction) # Return an iterator for completed output if the model has an output iterator array type. diff --git a/tests/test_run.py b/tests/test_run.py index 93f7248..6db5c96 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1076,3 +1076,88 @@ def _version_with_schema(id: str = "v1", output_schema: Optional[object] = None) }, }, } + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_run_raises_on_aborted_prediction(async_flag, mock_replicate_api_token): + """ + Regression test: an 'aborted' prediction (server-side termination) must surface as + ModelError and must NOT cause wait() / async_wait() to poll forever. + + Before the fix, 'aborted' was not in the terminal-state list, so wait() looped + until the test timed out (issue #431) and run() silently returned None output. + """ + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status("starting"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + return_value=httpx.Response( + 200, + json={**_prediction_with_status("aborted"), "error": "Prediction was aborted"}, + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema(), + ) + ) + router.route(host="api.replicate.com").pass_through() + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + with pytest.raises(ModelError) as excinfo: + if async_flag: + await client.async_run("test/example:v1", input={"text": "Hello, world!"}) + else: + client.run("test/example:v1", input={"text": "Hello, world!"}) + + assert excinfo.value.prediction.status == "aborted" + + +@pytest.mark.asyncio +async def test_prediction_wait_terminates_on_aborted(mock_replicate_api_token): + """ + Regression test: Prediction.wait() and async_wait() must exit immediately when + a prediction transitions to 'aborted', not loop forever. + """ + import replicate + from replicate.prediction import Prediction + + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="GET", path="/predictions/p1").mock( + return_value=httpx.Response( + 200, + json={**_prediction_with_status("aborted"), "error": "aborted by server"}, + ) + ) + router.route(host="api.replicate.com").pass_through() + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + prediction = Prediction(**_prediction_with_status("processing")) + prediction._client = client + + # wait() must return (not loop forever) when status flips to "aborted" + prediction.wait() + assert prediction.status == "aborted" + + # Reset and verify async_wait() also exits + prediction2 = Prediction(**_prediction_with_status("processing")) + prediction2._client = client + await prediction2.async_wait() + assert prediction2.status == "aborted"