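"""ClearML pipeline component that collects raw author documents from the NoSQL data warehouse."""
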
from concurrent.futures import ThreadPoolExecutor, as_completed

from loguru import logger
from typing_extensions import Annotated
from clearml import PipelineDecorator

from llm_engineering.application import utils
from llm_engineering.domain.base.nosql import NoSQLBaseDocument
from llm_engineering.domain.documents import ArticleDocument, Document, PostDocument, RepositoryDocument, UserDocument


@PipelineDecorator.component(name="query_data_warehouse")
def query_data_warehouse(
    author_full_names: list[str],
) -> Annotated[list, "raw_documents"]:
    """Fetch all raw articles, posts, and repositories authored by the given users.

    Returns the combined documents as the component's "raw_documents" output.
    """

    def fetch_all_data(user: UserDocument) -> dict[str, list[NoSQLBaseDocument]]:
        """Query the user's articles, posts, and repositories concurrently."""
        user_id = str(user.id)
        with ThreadPoolExecutor() as executor:
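            # Submit one query per collection, mapping each future to its label.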
            future_to_query = {
                executor.submit(__fetch_articles, user_id): "articles",
                executor.submit(__fetch_posts, user_id): "posts",
                executor.submit(__fetch_repositories, user_id): "repositories",
            }

            results = {}
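            # Gather results as the queries complete; a failure is logged and yields [].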
            for future in as_completed(future_to_query):
                query_name = future_to_query[future]
                try:
                    results[query_name] = future.result()
                except Exception:
                    logger.exception(f"'{query_name}' request failed.")
                    results[query_name] = []

        return results

    def __fetch_articles(user_id: str) -> list[NoSQLBaseDocument]:
        return ArticleDocument.bulk_find(author_id=user_id)

    def __fetch_posts(user_id: str) -> list[NoSQLBaseDocument]:
        return PostDocument.bulk_find(author_id=user_id)

    def __fetch_repositories(user_id: str) -> list[NoSQLBaseDocument]:
        return RepositoryDocument.bulk_find(author_id=user_id)

    def _get_metadata(documents: list[Document]) -> dict:
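        """Summarize fetched documents: total count plus per-collection counts and authors."""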
        metadata = {
            "num_documents": len(documents),
        }
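        # Tally per-collection document counts and collect the authors seen.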
        for document in documents:
            collection = document.get_collection_name()
            collection_metadata = metadata.setdefault(collection, {"num_documents": 0, "authors": []})
            collection_metadata["num_documents"] += 1
            collection_metadata["authors"].append(document.author_full_name)

        # Deduplicate the author names recorded for each collection.
        for value in metadata.values():
            if isinstance(value, dict) and "authors" in value:
                value["authors"] = list(set(value["authors"]))

        return metadata

    documents = []
    authors = []
    author_full_names = author_full_names if author_full_names is not None else []
    for author_full_name in author_full_names:
        logger.info(f"Querying data warehouse for user: {author_full_name}")

        first_name, last_name = utils.split_user_full_name(author_full_name)
        logger.info(f"First name: {first_name}, Last name: {last_name}")
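        # Resolve the author to a user record, creating it on first sight.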
        user = UserDocument.get_or_create(first_name=first_name, last_name=last_name)
        authors.append(user)

        results = fetch_all_data(user)
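        # Flatten the per-collection results into a single list of documents.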
        user_documents = [doc for query_result in results.values() for doc in query_result]
        documents.extend(user_documents)

    # get_step_context() is ZenML's API and is not available inside a ClearML
    # component, so report the metadata summary through the logger instead.
    logger.info(f"Fetched {len(documents)} raw documents: {_get_metadata(documents)}")

    return documents
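

# A minimal sketch of wiring this component into a ClearML pipeline. The pipeline
# name, project, version, and example author below are illustrative assumptions,
# not values from the original project.
@PipelineDecorator.pipeline(name="etl", project="llm_engineering", version="0.1")
def etl(author_full_names: list[str]) -> list:
    # Calling the component inside a pipeline function schedules it as a step.
    return query_data_warehouse(author_full_names)


if __name__ == "__main__":
    # run_locally() executes the pipeline logic and its components on this
    # machine, which is convenient for debugging before enqueuing remote runs.
    PipelineDecorator.run_locally()
    etl(author_full_names=["Jane Doe"])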