from itertools import batched
from celery import shared_task
from loguru import logger
from magpie.config import EMBED_BATCH_COUNT
from magpie.datamodel import DataFetcher, Folder, SemanticEmbedding, SemanticInformation, SemanticModel, Url
from magpie.fetchers import all_fetchers, get_fetcher
try:
from magpie.semantic.embedder import Embedder
except ModuleNotFoundError:
pass
from magpie.task.monitor import wait_for_tasks_completion
from magpie.util import short_str
@shared_task(rate_limit='100/m')
def fetch_additional_info(url: Url):
    """Fetch additional content for an already-identified URL (task-queue task).

    The fetcher registered under ``url.url_type`` is looked up with
    ``get_fetcher()`` and its ``fetch_additional_info(url)`` result is returned.

    Args:
        url: the URL to fetch content for; ``url.url_type`` must be set.

    Returns:
        The fetcher's result, or ``None`` when the fetcher returned a falsy value.

    Raises:
        ValueError: if ``url.url_type`` is None (the URL was not identified yet).
    """
    fetcher = url.url_type
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would let `get_fetcher(None)` fail later with a less helpful error.
    if fetcher is None:
        raise ValueError('Can only fetch additional info for identified URLs')

    logger.debug(f'Using {fetcher} to fetch info for: {url}')
    result = get_fetcher(fetcher).fetch_additional_info(url)
    if not result:
        logger.warning(f'Got no result for url {url} using fetcher: {fetcher}')
        return None

    logger.debug(f'Fetched content for url {url}: {short_str(result)}')
    return result
@shared_task(rate_limit='100/m')
def embed_url_sync(url: Url) -> tuple[SemanticModel, list]:
    """
    Embed the downloaded content of a single URL, synchronously.

    The URL's content snapshot is passed to ``Embedder.encode``.

    Returns:
        A ``(model, embedding)`` pair: the model object returned by the
        embedder, and the embedding converted element-wise to ``float``.
    """
    content_snapshot = url.content.snapshot()
    logger.debug(f'Computing embedding for url {url}')
    # TODO: Embedder to be init before hand and CUDA/GPU context to be reused correctly
    # MAG-50
    model, raw_embedding = Embedder().encode(content_snapshot)
    # TODO: ATM we convert embedding tensors into list[float] because MAG-57 is not done
    return model, list(map(float, raw_embedding))
@shared_task(rate_limit='100/m')
def embed_url_batch_sync(urls: list[Url]):
    """
    Embed the downloaded content of several URLs in one call, synchronously.

    Returns:
        A ``(model, embeddings)`` pair: the model object returned by
        ``Embedder.encode_list`` and one list of floats per row of the
        returned embedding matrix (rows presumably follow the order of ``urls``).
    """
    content_snapshots = [u.content.snapshot() for u in urls]
    logger.debug(f'Computing embedding for urls {urls}')
    # TODO: Embedder to be init before hand and CUDA/GPU context to be reused correctly
    # MAG-50
    embedder = Embedder()
    model, matrix = embedder.encode_list(content_snapshots)
    logger.debug(f"computed embeddings shape {matrix.shape}")
    # TODO: ATM we convert embedding tensors into list[float] because MAG-57 is not done
    # Iterating the matrix walks its first axis, one embedding row at a time.
    return model, [list(map(float, row)) for row in matrix]
class DataRetriever:
    """The `DataRetriever` is the main class that is used to fetch additional
    data/content for a given twig or folder of twigs.

    It contains 3 main methods, which will gather the 3 types of information a
    `Url` can have:

    - `identify()`: parses the URL and identifies a fetcher for it, as well as
      extracting the information it can from the URL only. This yields a subclass
      of `UrlInformation`. This operation is synchronous and runs on the main process
      as it should be fast.
    - `fetch()`: fetches additional content (web page download, additional resources)
      for a given URL whose type has been previously identified. This yields a subclass
      of `ContentInformation`. This is asynchronous and sent to the task queue.
    - `get_semantic_info()`: runs LLMs or other AI models to extract semantic information.
      This yields a `SemanticInformation` instance. This is asynchronous and sent to the
      task queue.

    For all asynchronous tasks, you can still decide to wait on their completion by
    calling the `wait_for_tasks_completion()` method.
    """

    def get_fetcher_for(self, url: Url) -> DataFetcher | None:
        """Return the fetcher to use for `url`, or None if none applies.

        If `url.url_type` is already set, the fetcher registered under that name
        is returned. Otherwise the first fetcher from `all_fetchers()` whose
        `match(url)` is truthy is returned.
        """
        if url.url_type:
            return get_fetcher(url.url_type)
        for fetcher in all_fetchers():
            if fetcher.match(url):
                return fetcher
        return None

    def identify_url(self, url: Url):
        """Find a fetcher for `url` and attach the info it can extract.

        Sets `url.url_type` to the fetcher's name and fills `url.info` from
        `fetcher.extract_info(url)` unless `url.info` is already truthy.
        Does nothing if no fetcher applies.
        """
        fetcher = self.get_fetcher_for(url)
        if fetcher is None:
            return
        url.url_type = fetcher.name()
        # Keep previously extracted info; only compute it when missing.
        url.info = url.info or fetcher.extract_info(url)
        logger.debug(f'Got information for url {url}: {url.info}')

    def identify(self, folder: Folder):
        """Take a folder as input and try to identify the types of URLs in
        all the Twigs in that folder (and subfolders).
        """
        for url in folder.iter_urls():
            self.identify_url(url)

    def fetch_url(self, url: Url, callback=None):
        """Schedule a task-queue job that fetches additional content for `url`.

        When the job returns a non-None result it is stored in `url.content`.

        Args:
            url: the URL to fetch content for.
            callback: optional `callback(url)`, invoked once `url` is done. This
                includes the cases where nothing needs to be fetched (no fetcher
                applies, or `url.content` is already set), so that progress
                counters driven by this callback always reach their total.
        """
        from magpie.task.queue import run_task

        def notify():
            if callback is not None:
                callback(url)

        fetcher = self.get_fetcher_for(url)
        if fetcher is None or url.content:
            # Nothing to download; still report completion, otherwise the
            # progress tracked by `apply_with_callback` never reaches `total`.
            notify()
            return

        def cb(result):
            if result is not None:
                logger.debug(f'for {url} got {type(result)}:\n{result.as_text()}')
                url.content = result
            else:
                logger.warning(f'for {url} got no result! (result = None)')
            notify()

        run_task(fetch_additional_info, url, callback=cb)

    def fetch(self, folder: Folder, callback=None):
        """Take a folder as input and fetch the content of the URLs for
        all the Twigs in that folder (and subfolders).

        If provided, the callback must have the following signature:
        `def callback(int, int, str)` where the args are respectively:
        (completed, total, msg)

        Note: this must be done after `identify()` has been called. Based on the
        identified URLs, the appropriate fetcher plugin will be used.
        """
        urls = [url for url in folder.iter_urls() if url.url_type is not None]
        self.apply_with_callback(self.fetch_url, callback, urls, 'Downloaded content for {}')

    def apply_with_callback(self, method, callback, urls, msg):
        """Call `method(url, callback=...)` for every url and report progress.

        Args:
            method: called as `method(url, callback=cb)`; it must call `cb(url)`
                once the url has been processed.
            callback: optional `callback(completed, total, message)` progress hook.
            urls: sized collection of urls to process.
            msg: format string with one `{}` placeholder, filled with `url.value`
                to build the progress message passed to `callback`.
        """
        finished, total = 0, len(urls)

        def cb(url):
            nonlocal finished
            finished += 1
            if callback is not None:
                callback(finished, total, msg.format(url.value))

        for url in urls:
            method(url, callback=cb)

    def wait_for_tasks_completion(self, timeout=None):
        """Wait for queued tasks; delegates to `magpie.task.monitor.wait_for_tasks_completion`."""
        wait_for_tasks_completion(timeout=timeout)

    def expand_data(self, folder: Folder):
        """Take a folder as input and passes it to all registered Fetchers so
        they can call their own `DataFetcher.expand_data(folder)` method on it.

        This allows the fetchers to not only find more information for
        a single URL, but to manipulate the whole database in order to
        reorganize it in case they need to.

        It is a good idea to call `DataRetriever.identify(folder)` and
        `DataRetriever.fetch(folder)` after this in order to get content for
        additional twigs that might have been created.
        """
        for fetcher in all_fetchers():
            fetcher.expand_data(folder)

    def get_semantic_info(self, folder: Folder, callback=None):
        """
        Compute URL embeddings from the downloaded content of all URLs in the folder.

        URLs whose `content` is None are skipped. When `EMBED_BATCH_COUNT` is set,
        URLs are embedded in batches of that size; otherwise one task per URL.

        If provided, the callback must have the following signature:
        `def callback(int, int, str)` where the args are respectively:
        (completed, total, msg)

        Note: this must be done after `fetch()` has been called.
        """
        urls = [url for url in folder.iter_urls() if url.content is not None]
        logger.debug("get_semantic_info called")
        if EMBED_BATCH_COUNT is None:
            self.apply_with_callback(self.embed_url, callback, urls, 'Embedded content for {}')
            return

        finished, total = 0, len(urls)

        def cb(batch_urls):
            nonlocal finished
            finished += len(batch_urls)
            if callback is not None:
                names = ', '.join(str(url.value) for url in batch_urls)
                callback(finished, total, f'Embedded content for {names}')

        logger.info("Embedding by batches")
        for batch in batched(urls, EMBED_BATCH_COUNT):
            # `batched` yields tuples; pass a list to match `embed_url_batch`'s signature.
            self.embed_url_batch(list(batch), callback=cb)

    def embed_url(self, url: Url, callback=None):
        """
        Compute the embedding of one URL's downloaded content.

        Launches the `embed_url_sync` task; on a non-None result, stores it as
        `url.semantic`. Then calls `callback(url)` if provided (also on failure).
        """
        from magpie.task.queue import run_task

        def cb(result):
            logger.debug("embed_url cb called")
            if result is not None:
                model, embedding = result
                url.semantic = SemanticInformation(
                    tags=None,
                    summary=None,
                    embedding=SemanticEmbedding(model=model, content=embedding),
                )
            else:
                logger.warning(f'embedding failed for url {url}')
            if callback is not None:
                callback(url)

        logger.debug("embed_url")
        run_task(embed_url_sync, url, callback=cb)

    def embed_url_batch(self, urls: list[Url], callback=None):
        """
        Compute embeddings for a batch of URLs' downloaded content.

        Launches the `embed_url_batch_sync` task; on a non-None result, assigns
        the i-th embedding to the i-th URL's `semantic`. Then calls
        `callback(urls)` if provided (also on failure).
        """
        from magpie.task.queue import run_task

        def cb(result):
            logger.debug("embed_url_batch cb called")
            if result is not None:
                model, embeddings = result
                if len(embeddings) != len(urls):
                    # zip() below silently truncates; make the mismatch visible.
                    logger.warning(f'got {len(embeddings)} embeddings for {len(urls)} urls')
                # TODO: iterate on array first dim when MAG-57 implemented
                for url, emb in zip(urls, embeddings):
                    url.semantic = SemanticInformation(
                        tags=None,
                        summary=None,
                        embedding=SemanticEmbedding(model=model, content=emb),
                    )
            else:
                logger.warning(f'embedding failed for urls {urls}')
            if callback is not None:
                callback(urls)

        logger.debug("embed_url_batch")
        run_task(embed_url_batch_sync, urls, callback=cb)