# Source code for magpie.task.queue
import time
import msgspec
from celery import Celery
from celery.signals import after_task_publish, task_postrun, task_prerun
from loguru import logger
# import tasks from these modules so they are registered in our Celery app
import magpie.fetch.cache # noqa: F401
import magpie.fetch.retriever # noqa: F401
import magpie.network # noqa: F401
from magpie import config, logging
from magpie.task import monitor, serializer
app = Celery('magpie.task.queue', broker=config.BROKER, backend=config.BACKEND)
"""Main Celery app instance"""
# add serializer for msgspec structs on our app
serializer.add_msgspec_serializer(app)
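# For reference, a registration along these lines (a sketch only; the actual
# implementation lives in magpie.task.serializer and may differ) would go
# through kombu's serializer registry:
#
#     from kombu.serialization import register
#     register(
#         'msgspec',
#         msgspec.json.encode,
#         msgspec.json.decode,
#         content_type='application/x-msgspec',
#         content_encoding='utf-8',
#     )
#     app.conf.accept_content = ['msgspec']
#     app.conf.task_serializer = 'msgspec'
#     app.conf.result_serializer = 'msgspec'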
def ping(timeout=0):
"""Ping the workers and return their responses.
Args:
timeout: seconds to wait for a reply. Useful when called just after starting a cluster
Returns:
a list of workers/responses available
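
    Example (illustrative)::

        workers = ping(timeout=10)  # poll for up to ~10 seconds
        if not workers:
            raise RuntimeError('no Celery workers responded')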
"""
workers = app.control.ping()
while not workers and timeout:
# Note: use `time.sleep` here as `app.control.ping(timeout)` seems to return an empty list
# even though some workers are ready in the cluster by the end of the timeout.
# Maybe because they weren't connected yet when we first grabbed the handle to `app.control`?
time.sleep(min(1, timeout))
timeout = max(timeout - 1, 0)
workers = app.control.ping()
return workers
# Use this hook to keep a list of all tasks we send to the queue so they can
# be monitored from within the main app
@after_task_publish.connect
def task_sent_handler(sender=None, headers=None, body=None, **kwargs):
    # information about the task is located in the headers for task messages
    # using task protocol version 2
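    # An illustrative header payload (field names from protocol v2; the
    # values here are made up):
    #   {'id': 'b5f7...', 'task': 'magpie.fetch.retriever.fetch',
    #    'argsrepr': "('https://example.com',)", 'kwargsrepr': '{}', ...}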
task_id = headers['id']
sig = headers['task'] + headers['argsrepr']
if (kw := headers['kwargsrepr']) != '{}':
sig += f' with kwargs = {kw}'
logger.debug(f'sent task {task_id}: {sig}')
monitor.add_task(task_id, None)
# These functions run on the worker process, so we can't use them to call back
# into the main application

# set the task_id globally to make it available to the loggers
@task_prerun.connect
def task_prerun_handler(sender=None, task_id=None, task=None, *args, **kwargs):
# logger.debug(f'{task.request}')
logging.task_id = task_id
if config.WORKER_LOG_TO_FILE:
logging.task_log_file_handle = logger.add(
logging.logfile_for_task(task_id),
format=logging.custom_formatter,
level=0, # by default, log everything to file
)
# unset the task_id once a task is finished
@task_postrun.connect
def task_postrun_handler(*args, **kwargs):
if config.WORKER_LOG_TO_FILE:
logger.remove(logging.task_log_file_handle)
logging.task_log_file_handle = None
logging.task_id = None
# @task_success.connect
# def task_success_handler(result=None, **kwargs):
# logger.warning(f'task successfully completed with result: {short_str(result)}')
#
# @task_failure.connect
# def task_failure_handler(**kwargs):
# logger.warning(f'task failed with args: {kwargs}')
def run_task(task, *args, callback=None, **kwargs):
"""Run a task with an optional success callback.
The callback will receive the result of the task as its only argument.
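
    Example (illustrative; `square_point` is the test-only task defined
    further below and is not available in production)::

        run_task(square_point, Point(x=2, y=3), callback=print)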
"""
task_result = task.delay(*args, **kwargs)
monitor.add_task(task_result.id, callback)
# only define and register this task on the workers when running tests
# FIXME: find a way to move this into a test file (and have the worker properly import it)
if config.MAGPIE_TESTS:
class Point(msgspec.Struct):
x: int
y: int
@app.task
def square_point(p: Point) -> Point:
result = Point(x=p.x * p.x, y=p.y * p.y)
log_result(p, result)
return result
# note: this is a separate function to ensure that the task_id is properly bound to
# the logger and propagated to called functions (and not only in the task body)
# note2: actually, to be really effective this should be in a different module so they
# don't share the same logger instance (which is per-module)
def log_result(p, result):
logger.debug(f'square of {p} is {result}')
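
# Example (tests only; requires a worker started with config.MAGPIE_TESTS
# enabled):
#
#     run_task(square_point, Point(x=2, y=3), callback=print)
#     # the callback eventually receives Point(x=4, y=9)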