Source code for magpie.task.queue

import time

import msgspec
from celery import Celery
from celery.signals import after_task_publish, task_postrun, task_prerun
from loguru import logger

# import tasks from these modules so they are registered in our Celery app
import magpie.fetch.cache  # noqa: F401
import magpie.fetch.retriever  # noqa: F401
import magpie.network  # noqa: F401
from magpie import config, logging
from magpie.task import monitor, serializer

app = Celery('magpie.task.queue', broker=config.BROKER, backend=config.BACKEND)
"""Main Celery app instance"""

# add serializer for msgspec structs on our app
serializer.add_msgspec_serializer(app)
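
# For reference, a custom Celery serializer is typically registered through
# `kombu.serialization.register` plus the app serializer settings. A minimal sketch of
# what `add_msgspec_serializer` might do (hypothetical; the actual implementation lives
# in `magpie.task.serializer`, and presumably also restores struct types on decode):
#
#     from kombu.serialization import register
#
#     def add_msgspec_serializer(app):
#         register(
#             'msgspec',
#             msgspec.json.encode,  # struct -> bytes
#             msgspec.json.decode,  # bytes -> python objects
#             content_type='application/x-msgspec',
#             content_encoding='binary',
#         )
#         app.conf.task_serializer = 'msgspec'
#         app.conf.result_serializer = 'msgspec'
#         app.conf.accept_content = ['msgspec', 'json']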


def ping(timeout=0):
    """Ping the workers and return their responses.

    Args:
        timeout: seconds to wait for a reply. Useful when called just after starting a cluster.

    Returns:
        a list of workers/responses available
    """
    workers = app.control.ping()
    while not workers and timeout:
        # Note: use `time.sleep` here as `app.control.ping(timeout)` seems to return an empty list
        # even though some workers are ready in the cluster by the end of the timeout.
        # Maybe because they weren't connected yet when we first grabbed the handle to `app.control`?
        time.sleep(min(1, timeout))
        timeout = max(timeout - 1, 0)
        workers = app.control.ping()
    return workers
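
# Usage sketch (hypothetical timeout; assumes the cluster was just started):
#
#     workers = ping(timeout=10)
#     if not workers:
#         raise RuntimeError('no workers responded within 10 seconds')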


# Use this hook to keep a list of all tasks we send to the queue, so they can be
# monitored from within the main app.
@after_task_publish.connect
def task_sent_handler(sender=None, headers=None, body=None, **kwargs):
    # information about the task is located in the headers for task messages
    # using task protocol version 2
    task_id = headers['id']
    sig = headers['task'] + headers['argsrepr']
    if (kw := headers['kwargsrepr']) != '{}':
        sig += f' with kwargs = {kw}'
    logger.debug(f'sent task {task_id}: {sig}')
    monitor.add_task(task_id, None)
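
# For illustration, the headers of a protocol-2 task message look roughly like this
# (hypothetical values; only the fields used above are shown):
#
#     {'id': '4f8b...', 'task': 'magpie.task.queue.square_point',
#      'argsrepr': '(Point(x=2, y=3),)', 'kwargsrepr': '{}', ...}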


# The handlers below run on the worker process, so we can't use them to call back
# into the main application.


# set the task_id globally to make it available to the loggers
@task_prerun.connect
def task_prerun_handler(sender=None, task_id=None, task=None, *args, **kwargs):
    # logger.debug(f'{task.request}')
    logging.task_id = task_id
    if config.WORKER_LOG_TO_FILE:
        logging.task_log_file_handle = logger.add(
            logging.logfile_for_task(task_id),
            format=logging.custom_formatter,
            level=0,  # by default, log everything to file
        )


# unset the task_id once a task is finished
@task_postrun.connect
def task_postrun_handler(*args, **kwargs):
    if config.WORKER_LOG_TO_FILE:
        logger.remove(logging.task_log_file_handle)
        logging.task_log_file_handle = None
    logging.task_id = None


# @task_success.connect
# def task_success_handler(result=None, **kwargs):
#     logger.warning(f'task successfully completed with result: {short_str(result)}')
#
#
# @task_failure.connect
# def task_failure_handler(**kwargs):
#     logger.warning(f'task failed with args: {kwargs}')


def run_task(task, *args, callback=None, **kwargs):
    """Run a task with an optional success callback.

    The callback will receive the result of the task as its only argument.
    """
    task_result = task.delay(*args, **kwargs)
    monitor.add_task(task_result.id, callback)
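
# Usage sketch (hypothetical callback; uses the test-only `square_point` task defined
# below):
#
#     run_task(square_point, Point(x=2, y=3),
#              callback=lambda result: logger.info(f'squared: {result}'))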


# only define and register this task on the workers when running tests
# FIXME: find a way to move this into a test file (and have the worker properly import it)
if config.MAGPIE_TESTS:

    class Point(msgspec.Struct):
        x: int
        y: int

    @app.task
    def square_point(p: Point) -> Point:
        result = Point(x=p.x * p.x, y=p.y * p.y)
        log_result(p, result)
        return result

    # note: this is a separate function to ensure that the task_id is properly bound to
    # the logger and propagated to called functions (not only in the task body)
    # note2: actually, to be really effective this should live in a different module, so they
    # don't share the same logger instance (which is per-module)
    def log_result(p, result):
        logger.debug(f'square of {p} is {result}')
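
# Test-side usage sketch (hypothetical; exercises the msgspec round-trip through the
# queue using the standard Celery result API):
#
#     result = square_point.delay(Point(x=2, y=3)).get(timeout=30)
#     assert result == Point(x=4, y=9)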