# %%
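"""Minimal retrieval-augmented QA over a FAISS index.

A question is embedded with a SentenceTransformer, the nearest stored
contexts are retrieved from a FAISS index, and a causal LM reader
generates a short answer grounded in those contexts.
"""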

import os
import json


import torch
import faiss
from sentence_transformers import SentenceTransformer
from transformers import (
    pipeline,
    TextGenerationPipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
)


# Hugging Face access token; required for gated models such as Llama.
HF_TOKEN = os.environ["hf_token"]


SYSTEM_PROMPT = """You are a helpful question answering assistant. You will be given a context and a question. You need to provide the answer to the question based on the context. Answer briefly, based on the context. Only output the answer, and nothing else. Here is an example:

>> Context
Fascin is an actin-bundling protein that induces membrane protrusions and cell motility after the formation of lamellipodia or filopodia. Fascin expression has been associated with progression or prognosis in various neoplasms; however, its role in intrahepatic cholangiocarcinoma is unknown.

>> Question
What type of protein is fascin?

>> Answer
Actin-bundling protein

Now answer the user's question based on the given context.
"""

USER_PROMPT = """
>> Context
{context}

>> Question
{question}

>> Answer
"""



def load_embedder(model_path: str, device: str) -> SentenceTransformer:
    embedder = SentenceTransformer(model_path)
    embedder.to(device)
    return embedder



def load_contexts(context_file: str) -> list[str]:
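    """Read a JSONL file whose lines are objects with a "context" field,
    e.g. {"context": "Fascin is an actin-bundling protein ..."}.
    """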
    contexts = []
    with open(context_file, "r") as f_in:
        for line in f_in:
            context = json.loads(line)
            contexts.append(context["context"])

    return contexts



def load_index(index_file: str) -> faiss.Index:
    return faiss.read_index(index_file)
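

# Hypothetical sketch of how an index like the one loaded above might have
# been built. HNSW is inferred from the example index filename at the bottom
# of this file ("...__float32_hnsw.index"); build_index and its parameters
# are illustrative assumptions, not the original build code.
def build_index(
    contexts: list[str], embedder: SentenceTransformer, index_file: str
) -> None:
    # Normalized embeddings make L2 distance rank the same as cosine similarity.
    embeddings = embedder.encode(contexts, normalize_embeddings=True)
    index = faiss.IndexHNSWFlat(embeddings.shape[1], 32)  # 32 links per HNSW node
    index.add(embeddings)
    faiss.write_index(index, index_file)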



def load_reader(model_path: str, device: str) -> TextGenerationPipeline:
    """Load the reader LM and wrap it in a text-generation pipeline."""
    # Set the dtype at load time: an already-instantiated model passed to
    # pipeline() is used as-is, so torch_dtype/token there would be ignored.
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, token=HF_TOKEN
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
    tokenizer.pad_token = tokenizer.eos_token

    reader = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    return reader



def construct_prompt(contexts: list[str], question: str) -> list[dict]:
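    """Assemble chat messages for the reader; the text-generation pipeline
    applies the tokenizer's chat template to messages in this format.
    """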
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": USER_PROMPT.format(
                context="\n".join(contexts), question=question
            ),
        },
    ]



def load_all(
    embedder_path: str,
    context_file: str,
    index_file: str,
    reader_path: str,
) -> tuple[SentenceTransformer, list[str], faiss.Index, TextGenerationPipeline]:
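    """Load all pipeline components; the reader uses the GPU when available."""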
    embedder = load_embedder(embedder_path, "cpu")
    contexts = load_contexts(context_file)
    index = load_index(index_file)
    reader_device = "cuda" if torch.cuda.is_available() else "cpu"
    reader = load_reader(reader_path, reader_device)

    return embedder, contexts, index, reader



def run_query(
    question: str,
    embedder: SentenceTransformer,
    index: faiss.Index,
    contexts: list[str],
    reader: TextGenerationPipeline,
    top_k: int = 3,
) -> tuple[list[int], list[str], str]:
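    """Embed the question, retrieve the top_k nearest contexts, and have the
    reader generate an answer. Returns (context ids, contexts, answer text).
    """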
    query_embedding = embedder.encode([question], normalize_embeddings=True)
    # index.search already returns numpy arrays of shape (1, top_k)
    _, retrieved_context_ids = index.search(query_embedding, top_k)

    retrieved_contexts = []
    for row in retrieved_context_ids:
        retrieved_contexts.append(
            [contexts[i] if contexts[i] is not None else "" for i in row]
        )

    # The code below is for a single question.
    prompt = construct_prompt(retrieved_contexts[0], question)
    answer = reader(prompt, max_new_tokens=128, return_full_text=False)
    print(answer)
    answer_text = answer[0]["generated_text"]
    if ">> Answer" in answer_text:
        answer_text = answer_text.split(">> Answer")[1].strip()

    return retrieved_context_ids[0].tolist(), retrieved_contexts[0], answer_text


# %%
# embedder_path = "Snowflake/snowflake-arctic-embed-l"
# reader_path = "meta-llama/Llama-3.2-1B-Instruct"
# context_file = "../data/bioasq_contexts.jsonl"
# index_file = "../data/bioasq_contexts__snowflake-arctic-embed-l__float32_hnsw.index"

# embedder, contexts, index, reader = load_all(
#     embedder_path, context_file, index_file, reader_path
# )

# query = "What cellular structures does fascin induce?"

# retrieved_context_ids, retrieved_contexts, answer_text = run_query(
#     query, embedder, index, contexts, reader
# )


# %%