diff --git a/docs/guide/cli.md b/docs/guide/cli.md index b70ad70a..66c8510d 100644 --- a/docs/guide/cli.md +++ b/docs/guide/cli.md @@ -138,6 +138,7 @@ The number of signals before a hard kill can be configured with the `--hardkill- * `--log-level` is used to set a log level (default `INFO`). * `--log-format` is used to set a log format (default `%(asctime)s][%(name)s][%(levelname)-7s][%(processName)s] %(message)s`). * `--max-async-tasks` - maximum number of simultaneously running async tasks. +* `--max-async-tasks-jitter` – Randomly varies the max async task limit between --max-async-tasks and a jittered value, helping prevent simultaneous worker restarts. * `--max-prefetch` - number of tasks to be prefetched before execution. (Useful for systems with high message rates, but brokers should support acknowledgements). * `--max-threadpool-threads` - number of threads for sync function execution. * `--no-propagate-errors` - if this parameter is enabled, exceptions won't be thrown in generator dependencies. diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index c0902371..ea2e86c0 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -533,3 +533,11 @@ def _register_task( if task.broker != self: raise TaskBrokerMismatchError(broker=task.broker) self.local_task_registry[task_name] = task + + async def __aenter__(self) -> None: + """Starts the broker as ctx manager.""" + await self.startup() + + async def __aexit__(self, *args: object, **kwargs: Any) -> None: + """Shuts down the broker as ctx manager.""" + await self.shutdown() diff --git a/taskiq/api/receiver.py b/taskiq/api/receiver.py index 72c6cbcf..03b84ca3 100644 --- a/taskiq/api/receiver.py +++ b/taskiq/api/receiver.py @@ -15,6 +15,7 @@ async def run_receiver_task( sync_workers: int | None = None, validate_params: bool = True, max_async_tasks: int = 100, + max_async_tasks_jitter: int = 0, max_prefetch: int = 0, propagate_exceptions: bool = True, run_startup: bool = False, @@ -43,6 +44,7 @@ async def run_receiver_task( or processes in processpool that runs sync tasks. :param validate_params: whether to validate params or not. :param max_async_tasks: maximum number of simultaneous async tasks. + :param max_async_tasks_jitter: random jitter to add to max_async_tasks. :param max_prefetch: maximum number of tasks to prefetch. :param propagate_exceptions: whether to propagate exceptions in generators or not. :param run_startup: whether to run startup function or not. @@ -79,6 +81,7 @@ def on_exit(_: Receiver) -> None: run_startup=run_startup, validate_params=validate_params, max_async_tasks=max_async_tasks, + max_async_tasks_jitter=max_async_tasks_jitter, max_prefetch=max_prefetch, propagate_exceptions=propagate_exceptions, on_exit=on_exit, diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index b7d1e67e..0a7cc98e 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -127,6 +127,7 @@ def __init__( max_stored_results: int = 100, cast_types: bool = True, max_async_tasks: int = 30, + max_async_tasks_jitter: int = 0, propagate_exceptions: bool = True, await_inplace: bool = False, ) -> None: @@ -140,6 +141,7 @@ def __init__( executor=self.executor, validate_params=cast_types, max_async_tasks=max_async_tasks, + max_async_tasks_jitter=max_async_tasks_jitter, propagate_exceptions=propagate_exceptions, ) self.await_inplace = await_inplace diff --git a/taskiq/cli/worker/args.py b/taskiq/cli/worker/args.py index fa922d7d..df3eb957 100644 --- a/taskiq/cli/worker/args.py +++ b/taskiq/cli/worker/args.py @@ -44,6 +44,7 @@ class WorkerArgs: reload_dirs: list[str] = field(default_factory=list) no_gitignore: bool = False max_async_tasks: int = 100 + max_async_tasks_jitter: int = 0 receiver: str = "taskiq.receiver:Receiver" receiver_arg: list[tuple[str, str]] = field(default_factory=list) max_prefetch: int = 0 @@ -210,6 +211,14 @@ def from_cli( default=100, help="Maximum simultaneous async tasks per worker process. ", ) + parser.add_argument( + "--max-async-tasks-jitter", + type=int, + dest="max_async_tasks_jitter", + default=0, + help="Add random jitter (0 to this value) to max-async-tasks to prevent " + "all workers from closing at the same time. ", + ) parser.add_argument( "--max-prefetch", type=int, diff --git a/taskiq/cli/worker/run.py b/taskiq/cli/worker/run.py index 53cef7c0..24d8f8db 100644 --- a/taskiq/cli/worker/run.py +++ b/taskiq/cli/worker/run.py @@ -165,6 +165,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None: executor=pool, validate_params=not args.no_parse, max_async_tasks=args.max_async_tasks, + max_async_tasks_jitter=args.max_async_tasks_jitter, max_prefetch=args.max_prefetch, propagate_exceptions=not args.no_propagate_errors, ack_type=args.ack_type, diff --git a/taskiq/middlewares/opentelemetry_middleware.py b/taskiq/middlewares/opentelemetry_middleware.py index 6ebbe10e..f90c1299 100644 --- a/taskiq/middlewares/opentelemetry_middleware.py +++ b/taskiq/middlewares/opentelemetry_middleware.py @@ -1,8 +1,11 @@ import logging +from collections.abc import Generator from contextlib import AbstractContextManager +from datetime import datetime, timezone from importlib.metadata import version from typing import Any, TypeVar +import psutil from packaging.version import Version, parse try: @@ -16,7 +19,7 @@ from opentelemetry import context as context_api from opentelemetry import trace -from opentelemetry.metrics import Meter, MeterProvider, get_meter +from opentelemetry.metrics import Meter, MeterProvider, Observation, get_meter from opentelemetry.propagate import extract, inject from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.trace import Span, Tracer, TracerProvider @@ -59,6 +62,9 @@ _TASK_RETRY_REASON_KEY = "taskiq.retry.reason" _TASK_NAME_KEY = "taskiq.task_name" +_TASK_QUEUE_TIME_KEY = "_taskiq_queue_time" +_TASK_RECEIVED_TIME_KEY = "_taskiq_broker_receive_time" + def set_attributes_from_context(span: Span, context: dict[str, Any]) -> None: """Helper to extract meta values from a Taskiq Context.""" @@ -170,6 +176,74 @@ def __init__( if meter is None else meter ) + # Create metrics + # 1- Number of tasks sent. Producer (Counter) + self.n_tasks_sent_counter = self._meter.create_counter( + name="tasks_sent", + unit="1", + description="Number of tasks sent from the producer side", + ) + # 2- Number of errors by task name. consumer (Counter) + self.n_errors_counter = self._meter.create_counter( + name="task_errors", + unit="1", + description="Number of errors raised", + ) + # 3- Number of task successes. consumer (Counter) + self.n_success_counter = self._meter.create_counter( + name="task_success", + unit="1", + description="Number of tasks completed successfully", + ) + # 4- Task execution time. consumer (Histogram) + self.execution_time_hist = self._meter.create_histogram( + "task_execution_time", + unit="s", + description="Time to finish executing tasks", + ) + # 5- Task wait time. both (Histogram) + self.task_wait_time = self._meter.create_histogram( + "task_wait_time", + unit="s", + description="Time the tasks waited before executing", + ) + # current metrics to watch for in workers: CPU and memory utilization + self._process = psutil.Process() + # 6- CPU utilization + self.worker_cpu_utilization = self._meter.create_observable_gauge( + "worker_cpu_utilization", + callbacks=[self._observe_cpu], + unit="%", + description="Worker CPU utilization percentage. Only for worker processes", + ) + # 7- Memory utilization + self.worker_memory_utilization = self._meter.create_observable_gauge( + "worker_memory_utilization", + callbacks=[self._observe_memory], + unit="By", + description="Worker memory utilization in bytes. Only for worker processes", + ) + + # 8- Number of tasks executing + self.number_of_broker_active_tasks = self._meter.create_up_down_counter( + "worker_active_tasks", + unit="1", + description="Number of tasks currently executing in the worker.", + ) + # 9- Number of tasks executing + self.number_of_broker_prefetched_tasks = self._meter.create_up_down_counter( + "worker_prefetched_tasks", + unit="1", + description="Number of tasks currently prefetched in the worker.", + ) + + def _observe_memory(self, options: Any) -> Generator[Observation, None, None]: + if self.broker and self.broker.is_worker_process: + yield Observation(self._process.memory_info().rss) + + def _observe_cpu(self, options: Any) -> Generator[Observation, None, None]: + if self.broker and self.broker.is_worker_process: + yield Observation(self._process.cpu_percent()) def pre_send(self, message: TaskiqMessage) -> TaskiqMessage: """ @@ -193,7 +267,7 @@ def pre_send(self, message: TaskiqMessage) -> TaskiqMessage: activation.__enter__() attach_context(message, span, activation, None, is_publish=True) inject(message.labels) - + message.labels[_TASK_QUEUE_TIME_KEY] = datetime.now(timezone.utc).timestamp() return message def post_send(self, message: TaskiqMessage) -> None: @@ -214,6 +288,7 @@ def post_send(self, message: TaskiqMessage) -> None: activation.__exit__(None, None, None) detach_context(message, is_publish=True) + self.n_tasks_sent_counter.add(1, attributes={"task_name": message.task_name}) def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage: """ @@ -236,6 +311,11 @@ def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage: activation = trace.use_span(span, end_on_exit=True) activation.__enter__() # pylint: disable=E1101 attach_context(message, span, activation, token) + message.labels[_TASK_RECEIVED_TIME_KEY] = datetime.now(timezone.utc).timestamp() + self.number_of_broker_active_tasks.add( + 1, + attributes={"task_name": message.task_name}, + ) return message def post_save( # pylint: disable=R6301 @@ -313,3 +393,65 @@ def on_error( } span.record_exception(exception) span.set_status(Status(**status_kwargs)) # type: ignore[arg-type] + + def post_execute( + self, + message: "TaskiqMessage", + result: "TaskiqResult[Any]", + ) -> None: + """ + This function tracks number of errors and success executions. + + :param message: received message. + :param result: result of the execution. + """ + if result.is_err: + retry_on_error = message.labels.get("retry_on_error") + if isinstance(retry_on_error, str): + retry_on_error = retry_on_error.lower() == "true" + + if retry_on_error is None: + retry_on_error = False + + if retry_on_error: + # Add retry reason metadata to span + self.n_errors_counter.add( + 1, + attributes={"retry_error": True, "task_name": message.task_name}, + ) + else: + self.n_errors_counter.add( + 1, + attributes={"retry_error": False, "task_name": message.task_name}, + ) + else: + self.n_success_counter.add( + 1, + attributes={"task_name": message.task_name}, + ) + self.execution_time_hist.record( + result.execution_time, + attributes={ + "task_name": message.task_name, + }, + ) + task_receive_time = message.labels.get(_TASK_RECEIVED_TIME_KEY) + task_send_time = message.labels.get(_TASK_QUEUE_TIME_KEY) + if task_receive_time is not None and task_send_time is not None: + self.task_wait_time.record( + amount=task_receive_time - task_send_time, + attributes={"task_name": message.task_name}, + ) + + self.number_of_broker_active_tasks.add( + -1, + attributes={"task_name": message.task_name}, + ) + + def on_prefetch_queue_add(self) -> None: + """This hook is called after task is added to the worker prefetch queue.""" + self.number_of_broker_prefetched_tasks.add(1) + + def on_prefetch_queue_remove(self) -> None: + """This hook is called after task is removed from the worker prefetch queue.""" + self.number_of_broker_prefetched_tasks.add(-1) diff --git a/taskiq/middlewares/prometheus_middleware.py b/taskiq/middlewares/prometheus_middleware.py index 56837cf3..01f14867 100644 --- a/taskiq/middlewares/prometheus_middleware.py +++ b/taskiq/middlewares/prometheus_middleware.py @@ -84,22 +84,17 @@ def startup(self) -> None: This function starts prometheus server. It starts it only in case if it's a worker process. """ - from prometheus_client import ( # noqa: PLC0415 - CollectorRegistry, - start_http_server, - ) + from prometheus_client import REGISTRY, start_http_server # noqa: PLC0415 from prometheus_client.multiprocess import ( # noqa: PLC0415 MultiProcessCollector, ) if self.broker.is_worker_process: try: - registry = CollectorRegistry() - MultiProcessCollector(registry) + MultiProcessCollector(REGISTRY) start_http_server( port=self.server_port, addr=self.server_addr, - registry=registry, ) except OSError as exc: logger.debug("Cannot start prometheus server: %s", exc) diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 99298af2..5c2a1468 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -2,9 +2,10 @@ import contextvars import functools import inspect +import random import sys from collections.abc import Callable -from concurrent.futures import Executor +from concurrent.futures import Executor, ProcessPoolExecutor from logging import getLogger from time import time from typing import Any, get_type_hints @@ -28,6 +29,24 @@ QUEUE_DONE = b"-1" +def _execute_sync_task_in_executor( + target: Callable[..., Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> Any: + """Execute a sync task. + + This is a wrapper to ensure we pass the target function directly + to the executor, avoiding issues with pickling bound methods like ctx.run. + + :param target: function to execute + :param args: positional arguments + :param kwargs: keyword arguments + :return: result of the function call + """ + return target(*args, **kwargs) + + class Receiver: """Class that uses as a callback handler.""" @@ -37,6 +56,7 @@ def __init__( executor: Executor | None = None, validate_params: bool = True, max_async_tasks: "int | None" = None, + max_async_tasks_jitter: int = 0, max_prefetch: int = 0, propagate_exceptions: bool = True, run_startup: bool = True, @@ -62,13 +82,22 @@ def __init__( self._prepare_task(task.task_name, task.original_func) self.sem: asyncio.Semaphore | None = None if max_async_tasks is not None and max_async_tasks > 0: - self.sem = asyncio.Semaphore(max_async_tasks) + # Apply jitter to prevent all workers from hitting the limit simultaneously + actual_limit = max_async_tasks + if max_async_tasks_jitter > 0: + # Using standard random for load distribution, not cryptography + actual_limit = max_async_tasks + random.randint( # noqa: S311 + 0, + max_async_tasks_jitter, + ) + self.sem = asyncio.Semaphore(actual_limit) else: logger.warning( "Setting unlimited number of async tasks " "can result in undefined behavior", ) self.sem_prefetch = asyncio.Semaphore(max_prefetch) + self.is_process_pool = isinstance(executor, ProcessPoolExecutor) async def callback( # noqa: C901, PLR0912 self, @@ -245,15 +274,28 @@ async def run_task( # noqa: C901, PLR0912, PLR0915 target_future = target(*message.args, **kwargs) else: is_coroutine = False - # If this is a synchronous function, we - # run it in executor and preserve the context. - ctx = contextvars.copy_context() - func = functools.partial(target, *message.args, **kwargs) - target_future = loop.run_in_executor( - self.executor, - ctx.run, - func, - ) + if self.is_process_pool: + # For ProcessPoolExecutor, we can't use ctx.run because it contains + # a reference to contextvars.Context which cannot be pickled. + # Instead, we call the target function directly in the executor. + # Each worker process starts with its own context, so we don't need + # to preserve the parent context. + target_future = loop.run_in_executor( + self.executor, + _execute_sync_task_in_executor, + target, + tuple(message.args), + kwargs, + ) + else: + # For ThreadPoolExecutor, we can use ctx.run with functools.partial + ctx = contextvars.copy_context() + func = functools.partial(target, *message.args, **kwargs) + target_future = loop.run_in_executor( + self.executor, + ctx.run, + func, + ) timeout = message.labels.get("timeout") if timeout is not None: if not is_coroutine: @@ -383,6 +425,12 @@ async def prefetcher( current_message = asyncio.create_task(iterator.__anext__()) # type: ignore fetched_tasks += 1 await queue.put(message) + # Custom hooks for OTel and any future instrumentations + for middleware in reversed(self.broker.middlewares): + if hasattr(middleware, "on_prefetch_queue_add"): + await maybe_awaitable( + middleware.on_prefetch_queue_add(), # type: ignore + ) except (asyncio.CancelledError, StopAsyncIteration): break # We don't want to fetch new messages if we are shutting down. @@ -434,6 +482,13 @@ def task_cb(task: "asyncio.Task[Any]") -> None: logger.info("No more tasks to wait for. Shutting down.") break + # Custom hooks for OTel and any future instrumentations + for middleware in reversed(self.broker.middlewares): + if hasattr(middleware, "on_prefetch_queue_remove"): + await maybe_awaitable( + middleware.on_prefetch_queue_remove(), # type: ignore + ) + task = asyncio.create_task( self.callback(message=message, raise_err=False), ) diff --git a/taskiq/task.py b/taskiq/task.py index b4a0f52f..550dabee 100644 --- a/taskiq/task.py +++ b/taskiq/task.py @@ -98,9 +98,9 @@ async def wait_result( """ start_time = time() while not await self.is_ready(): - await asyncio.sleep(check_interval) if 0 < timeout < time() - start_time: raise TaskiqResultTimeoutError(timeout=timeout) + await asyncio.sleep(check_interval) return await self.get_result(with_logs=with_logs) async def get_progress(self) -> "TaskProgress[Any] | None": diff --git a/tests/abc/test_broker.py b/tests/abc/test_broker.py index 5c536cfb..636f9576 100644 --- a/tests/abc/test_broker.py +++ b/tests/abc/test_broker.py @@ -1,9 +1,13 @@ from collections.abc import AsyncGenerator from copy import copy +import pytest + from taskiq.abc.broker import AsyncBroker from taskiq.decor import AsyncTaskiqDecoratedTask +from taskiq.events import TaskiqEvents from taskiq.message import BrokerMessage +from taskiq.state import TaskiqState class _TestBroker(AsyncBroker): @@ -76,3 +80,80 @@ async def test_task() -> None: ... assert "another_label" in test_kicker.labels assert test_task.labels == old_labels + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("is_worker_process", "startup", "shutdown"), + [ + (True, TaskiqEvents.WORKER_STARTUP, TaskiqEvents.WORKER_SHUTDOWN), + (False, TaskiqEvents.CLIENT_STARTUP, TaskiqEvents.CLIENT_SHUTDOWN), + ], +) +async def test_async_context_manager_enter( + *, + is_worker_process: bool, + startup: TaskiqEvents, + shutdown: TaskiqEvents, +) -> None: + """Test that `__aenter__` and `__aexit__` calls work.""" + broker = _TestBroker() + broker.is_worker_process = is_worker_process + startup_called = False + shutdown_called = False + + @broker.on_event(startup) + async def track_startup(state: TaskiqState) -> None: + nonlocal startup_called + startup_called = True + + @broker.on_event(shutdown) + async def track_shutdown(state: TaskiqState) -> None: + nonlocal shutdown_called + shutdown_called = True + + async with broker as ctx: + assert ctx is None + assert startup_called is True + assert shutdown_called is False + + assert shutdown_called is True + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("is_worker_process", "startup", "shutdown"), + [ + (True, TaskiqEvents.WORKER_STARTUP, TaskiqEvents.WORKER_SHUTDOWN), + (False, TaskiqEvents.CLIENT_STARTUP, TaskiqEvents.CLIENT_SHUTDOWN), + ], +) +async def test_async_context_manager_exit_on_exception( + *, + is_worker_process: bool, + startup: TaskiqEvents, + shutdown: TaskiqEvents, +) -> None: + """Test that __aexit__ calls shutdown even if exception is raised.""" + broker = _TestBroker() + broker.is_worker_process = is_worker_process + startup_called = False + shutdown_called = False + + @broker.on_event(startup) + async def track_startup(state: TaskiqState) -> None: + nonlocal startup_called + startup_called = True + + @broker.on_event(shutdown) + async def track_shutdown(state: TaskiqState) -> None: + nonlocal shutdown_called + shutdown_called = True + + with pytest.raises(ValueError, match="Test exception"): + async with broker: + assert startup_called is True + assert shutdown_called is False + raise ValueError("Test exception") + + assert shutdown_called is True diff --git a/tests/opentelemetry/taskiq_test_tasks.py b/tests/opentelemetry/taskiq_test_tasks.py index d910313b..af2198f1 100644 --- a/tests/opentelemetry/taskiq_test_tasks.py +++ b/tests/opentelemetry/taskiq_test_tasks.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any from opentelemetry import baggage @@ -26,3 +27,8 @@ async def task_raises() -> None: @broker.task async def task_returns_baggage() -> Any: return dict(baggage.get_all()) + + +@broker.task +async def task_does_processing(wait_time: float) -> None: + await asyncio.sleep(wait_time) diff --git a/tests/opentelemetry/test_metrics.py b/tests/opentelemetry/test_metrics.py new file mode 100644 index 00000000..d12291fc --- /dev/null +++ b/tests/opentelemetry/test_metrics.py @@ -0,0 +1,224 @@ +import asyncio +from typing import Any + +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import InMemoryMetricReader +from opentelemetry.test.test_base import TestBase + +from taskiq.instrumentation import TaskiqInstrumentor +from taskiq.middlewares.opentelemetry_middleware import OpenTelemetryMiddleware + +from .taskiq_test_tasks import ( + broker, + task_add, + task_does_processing, + task_raises, +) + + +class TestTaskiqOTelMetrics(TestBase): + def setUp(self) -> None: + super().setUp() + self.reader = InMemoryMetricReader() + self.meter_provider = MeterProvider(metric_readers=[self.reader]) + TaskiqInstrumentor().instrument_broker( + broker, + meter_provider=self.meter_provider, + ) + + def tearDown(self) -> None: + super().tearDown() + TaskiqInstrumentor().uninstrument_broker(broker) + + def _get_data_points(self, metric_name: str) -> list[Any]: + metrics = self.reader.get_metrics_data() + if metrics is None: + return [] + return [ + point + for rm in metrics.resource_metrics + for sm in rm.scope_metrics + for metric in sm.metrics + if metric.name == metric_name + for point in metric.data.data_points + ] + + def test_metrics_exist(self) -> None: + async def test() -> None: + await task_add.kiq(1, 2) + await task_raises.kiq() + await broker.wait_all() + + asyncio.run(test()) + + metrics = self.reader.get_metrics_data() + self.assertIsNotNone(metrics) + expected = { + "task_errors", + "tasks_sent", + "task_success", + "task_execution_time", + "task_wait_time", + "worker_active_tasks", + } + found = { + metric.name + for rm in metrics.resource_metrics # type: ignore[union-attr] + for sm in rm.scope_metrics + for metric in sm.metrics + } + self.assertSetEqual(found.intersection(expected), expected) + + def test_success_counter(self) -> None: + async def test() -> None: + for _ in range(3): + await task_add.kiq(1, 2) + await broker.wait_all() + + asyncio.run(test()) + + points = self._get_data_points("task_success") + self.assertEqual(len(points), 1) + self.assertEqual(points[0].value, 3) + + def test_error_counter_no_retry(self) -> None: + async def test() -> None: + for _ in range(3): + await task_raises.kiq() + await broker.wait_all() + + asyncio.run(test()) + + points = self._get_data_points("task_errors") + no_retry_points = [ + p for p in points if p.attributes.get("retry_error") is False + ] + self.assertEqual(len(no_retry_points), 1) + self.assertEqual(no_retry_points[0].value, 3) + + def test_error_counter_with_retry(self) -> None: + async def test() -> None: + for _ in range(3): + await task_raises.kicker().with_labels(retry_on_error="true").kiq() + await broker.wait_all() + + asyncio.run(test()) + + points = self._get_data_points("task_errors") + retry_points = [p for p in points if p.attributes.get("retry_error") is True] + self.assertEqual(len(retry_points), 1) + self.assertEqual(retry_points[0].value, 3) + + def test_execution_time_histogram(self) -> None: + async def test() -> None: + for _ in range(3): + await task_does_processing.kiq(0.01) + await broker.wait_all() + + asyncio.run(test()) + + points = self._get_data_points("task_execution_time") + self.assertEqual(len(points), 1) + self.assertEqual(points[0].count, 3) + self.assertGreater(points[0].sum, 0) + + def test_task_wait_time_histogram(self) -> None: + async def test() -> None: + await task_does_processing.kiq(0.01) + await broker.wait_all() + + asyncio.run(test()) + + points = self._get_data_points("task_wait_time") + self.assertEqual(len(points), 1) + self.assertEqual(points[0].count, 1) + self.assertGreaterEqual(points[0].sum, 0) + + def test_queue_time(self) -> None: + async def test() -> None: + for _ in range(3): + await task_add.kiq(1, 2) + await broker.wait_all() + + asyncio.run(test()) + + points = self._get_data_points("task_wait_time") + # all 3 tasks share the same task_name so they aggregate into one data point + self.assertEqual(len(points), 1) + point = points[0] + # 3 tasks recorded + self.assertEqual(point.count, 3) + # queue time must be non-negative — a negative value means timestamps + # were not written/read correctly + self.assertGreaterEqual(point.sum, 0) + self.assertGreaterEqual(point.min, 0) + # task_name attribute must be present and correct + self.assertEqual( + point.attributes.get("task_name"), + "tests.opentelemetry.taskiq_test_tasks:task_add", + ) + + def test_tasks_sent_counter(self) -> None: + async def test() -> None: + for _ in range(3): + await task_add.kiq(1, 2) + await broker.wait_all() + + asyncio.run(test()) + + points = self._get_data_points("tasks_sent") + self.assertEqual(len(points), 1) + self.assertEqual(points[0].value, 3) + + def test_active_tasks_counter(self) -> None: + async def test() -> None: + for _ in range(3): + await task_add.kiq(1, 2) + await broker.wait_all() + + asyncio.run(test()) + + points = self._get_data_points("worker_active_tasks") + # all 3 tasks share the same task_name so they aggregate into one data point + self.assertEqual(len(points), 1) + # net zero: pre_execute incremented, post_execute decremented for each task + self.assertEqual(points[0].value, 0) + self.assertIn("task_name", points[0].attributes) + self.assertEqual( + points[0].attributes.get("task_name"), + "tests.opentelemetry.taskiq_test_tasks:task_add", + ) + + def test_prefetch_queue_counter(self) -> None: + middleware = next( + m for m in broker.middlewares if isinstance(m, OpenTelemetryMiddleware) + ) + middleware.on_prefetch_queue_add() + middleware.on_prefetch_queue_add() + middleware.on_prefetch_queue_add() + middleware.on_prefetch_queue_remove() + + points = self._get_data_points("worker_prefetched_tasks") + self.assertEqual(len(points), 1) + self.assertEqual(points[0].value, 2) + + def test_worker_resource_metrics_when_worker_process(self) -> None: + middleware = next( + m for m in broker.middlewares if isinstance(m, OpenTelemetryMiddleware) + ) + middleware.set_broker(broker) + broker.is_worker_process = True + try: + metrics_data = self.reader.get_metrics_data() + self.assertIsNotNone(metrics_data) + found = { + metric.name + for rm in metrics_data.resource_metrics # type: ignore[union-attr] + for sm in rm.scope_metrics + for metric in sm.metrics + } + self.assertIn("worker_cpu_utilization", found) + self.assertIn("worker_memory_utilization", found) + finally: + broker.is_worker_process = False + middleware.set_broker(None) # type: ignore[arg-type] diff --git a/tests/receiver/test_receiver.py b/tests/receiver/test_receiver.py index 0b0e976a..eeb29c11 100644 --- a/tests/receiver/test_receiver.py +++ b/tests/receiver/test_receiver.py @@ -2,6 +2,7 @@ import contextvars import random import time +import unittest.mock from collections.abc import Generator from concurrent.futures import ThreadPoolExecutor from functools import wraps @@ -24,13 +25,15 @@ def get_receiver( broker: AsyncBroker | None = None, no_parse: bool = False, max_async_tasks: int | None = None, + max_async_tasks_jitter: int = 0, ) -> Receiver: """ Returns receiver with custom broker and args. :param broker: broker, defaults to None :param no_parse: parameter to taskiq_args, defaults to False - :param cli_args: Taskiq worker CLI arguments. + :param max_async_tasks: maximum number of simultaneous async tasks. + :param max_async_tasks_jitter: random jitter to add to max_async_tasks. :return: new receiver. """ if broker is None: @@ -40,6 +43,7 @@ def get_receiver( executor=ThreadPoolExecutor(max_workers=10), validate_params=not no_parse, max_async_tasks=max_async_tasks, + max_async_tasks_jitter=max_async_tasks_jitter, ) @@ -544,3 +548,55 @@ async def task_no_result() -> str: assert resp.return_value == "some value" assert not broker._running_tasks assert wrapper_call is True + + +async def test_jitter_applied_to_semaphore() -> None: + """Test that jitter is correctly applied to max_async_tasks semaphore.""" + max_async_tasks = 100 + max_async_tasks_jitter = 10 + + # Test with jitter value of 0 (minimum) + with unittest.mock.patch("random.randint", return_value=0): + receiver = get_receiver( + max_async_tasks=max_async_tasks, + max_async_tasks_jitter=max_async_tasks_jitter, + ) + assert receiver.sem is not None + assert receiver.sem._value == max_async_tasks + + # Test with jitter value of 5 (middle) + with unittest.mock.patch("random.randint", return_value=5): + receiver = get_receiver( + max_async_tasks=max_async_tasks, + max_async_tasks_jitter=max_async_tasks_jitter, + ) + assert receiver.sem is not None + assert receiver.sem._value == max_async_tasks + 5 + + # Test with jitter value of 10 (maximum) + with unittest.mock.patch("random.randint", return_value=10): + receiver = get_receiver( + max_async_tasks=max_async_tasks, + max_async_tasks_jitter=max_async_tasks_jitter, + ) + assert receiver.sem is not None + assert receiver.sem._value == max_async_tasks + 10 + + +async def test_jitter_zero_no_randomization() -> None: + """Test with zero jitter, semaphore value matches max_async_tasks.""" + max_async_tasks = 50 + + receiver = get_receiver( + max_async_tasks=max_async_tasks, + max_async_tasks_jitter=0, + ) + + assert receiver.sem is not None + assert receiver.sem._value == max_async_tasks + + +async def test_no_semaphore_without_max_async_tasks() -> None: + """Test that semaphore is None when max_async_tasks is not set.""" + receiver = get_receiver(max_async_tasks=None) + assert receiver.sem is None