import sys
import time
from dataclasses import dataclass

from base import JobInput
from db import get_db_cursor
from ml import (
    DefaultUrlProcessor,
    HfTransformersSummarizer,
    HfTransformersTagger,
    MlRegistry,
    RawTextProcessor,
)

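# seconds to sleep between polls when no jobs are pending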
SLEEP_INTERVAL = 5


def check_pending_jobs() -> list[JobInput]:
    """Check DB for pending jobs"""
    with get_db_cursor() as cursor:
        # fetch pending jobs, joining author and source from the entries table
        query = """
        SELECT j.entry_id, e.author, e.source
        FROM jobs j
        JOIN entries e
        ON j.entry_id = e.id
        WHERE j.status = 'pending'
        """
        res = list(cursor.execute(query))
    return [
        JobInput(id=_id, author=author, content=content) for _id, author, content in res
    ]


@dataclass
class JobOutput:
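    """Result of processing one job, including which components produced it."""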
    summary: str
    tags: list[str]
    processor_name: str
    summarizer_name: str
    tagger_name: str


def _process_job(job: JobInput, registry: MlRegistry) -> JobOutput:
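    """Run a job through the pipeline: preprocess, then tag and summarize."""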
    processor = registry.get_processor(job)
    processor_name = processor.get_name()
    processed = processor(job)

    tagger = registry.get_tagger()
    tagger_name = tagger.get_name()
    tags = tagger(processed)

    summarizer = registry.get_summarizer()
    summarizer_name = summarizer.get_name()
    summary = summarizer(processed)

    return JobOutput(
        summary=summary,
        tags=tags,
        processor_name=processor_name,
        summarizer_name=summarizer_name,
        tagger_name=tagger_name,
    )


def store(job: JobInput, output: JobOutput) -> None:
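    """Persist a job's summary and tags to the database."""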
    with get_db_cursor() as cursor:
        # write to the summaries and tags tables
        cursor.execute(
            (
                "INSERT INTO summaries (entry_id, summary, summarizer_name)"
                " VALUES (?, ?, ?)"
            ),
            (job.id, output.summary, output.summarizer_name),
        )
        cursor.executemany(
            "INSERT INTO tags (entry_id, tag, tagger_name) VALUES (?, ?, ?)",
            [(job.id, tag, output.tagger_name) for tag in output.tags],
        )


def process_job(job: JobInput, registry: MlRegistry) -> None:
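    """Process a single job and record its final status in the jobs table."""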
    tic = time.perf_counter()
    print(f"Processing job for (id={job.id[:8]})")

    # note: acquire the cursor (which takes a lock) as late as possible; the
    # processing can take a while and we don't want to block other workers
    # during that time
    try:
        output = _process_job(job, registry)
        store(job, output)
        # update job status to done
        with get_db_cursor() as cursor:
            cursor.execute(
                "UPDATE jobs SET status = 'done' WHERE entry_id = ?", (job.id,)
            )
    except Exception as e:
        # update job status to failed
        with get_db_cursor() as cursor:
            cursor.execute(
                "UPDATE jobs SET status = 'failed' WHERE entry_id = ?", (job.id,)
            )
        print(f"Failed to process job for (id={job.id[:8]}): {e}")

    toc = time.perf_counter()
    print(f"Finished processing job (id={job.id[:8]}) in {toc - tic:0.3f} seconds")


def load_mlregistry(model_name: str) -> MlRegistry:
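    """Load the model and tokenizer once and wire up all registry components."""
    # deferred import: transformers is only loaded when the registry is built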
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    config_summarizer = GenerationConfig.from_pretrained(model_name)
    config_summarizer.max_new_tokens = 200
    config_summarizer.min_new_tokens = 100
    # top_k is only applied when sampling is enabled
    config_summarizer.do_sample = True
    config_summarizer.top_k = 5
    config_summarizer.repetition_penalty = 1.5

    config_tagger = GenerationConfig.from_pretrained(model_name)
    config_tagger.max_new_tokens = 50
    config_tagger.min_new_tokens = 25
    # raise the temperature to make the tags more varied; like top_k, it is
    # only applied when sampling is enabled
    config_tagger.do_sample = True
    config_tagger.temperature = 1.5

    summarizer = HfTransformersSummarizer(
        model_name, model, tokenizer, config_summarizer
    )
    tagger = HfTransformersTagger(model_name, model, tokenizer, config_tagger)

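    # register both input processors; get_processor later picks one per job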
    registry = MlRegistry()
    registry.register_processor(DefaultUrlProcessor())
    registry.register_processor(RawTextProcessor())
    registry.register_summarizer(summarizer)
    registry.register_tagger(tagger)

    return registry


def main() -> None:
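    """Load the models, then poll the DB for pending jobs forever."""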
    model_name = "google/flan-t5-large"
    registry = load_mlregistry(model_name)

    while True:
        jobs = check_pending_jobs()
        if not jobs:
            print("No pending jobs found, sleeping...")
            time.sleep(SLEEP_INTERVAL)
            continue

        print(f"Found {len(jobs)} pending job(s), processing...")
        for job in jobs:
            process_job(job, registry)


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("Shutting down...")
        sys.exit(0)