# %%
import os
import json
import torch
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import (
    pipeline,
    TextGenerationPipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
)
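
# Assumes a Hugging Face access token is exposed via an `hf_token`
# environment variable (e.g. as a Space secret); adjust the key if yours differs.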
HF_TOKEN = os.environ["hf_token"]
SYSTEM_PROMPT = """You are a helpful question-answering assistant. You will be given a context and a question, and you must answer the question based on the context. Answer briefly, output only the answer, and nothing else. Here is an example:
>> Context
Fascin is an actin-bundling protein that induces membrane protrusions and cell motility after the formation of lamellipodia or filopodia. Fascin expression has been associated with progression or prognosis in various neoplasms; however, its role in intrahepatic cholangiocarcinoma is unknown.
>> Question
What type of protein is fascin?
>> Answer
Actin-bundling protein
Now answer the user's question based on the user's given context.
"""
USER_PROMPT = """
>> Context
{context}
>> Question
{question}
>> Answer
"""

def load_embedder(model_path: str, device: str) -> SentenceTransformer:
    """Load a SentenceTransformer embedding model onto the given device."""
    embedder = SentenceTransformer(model_path)
    embedder.to(device)
    return embedder

def load_contexts(context_file: str) -> list[str]:
    """Read one JSON object per line and collect its "context" field."""
    contexts = []
    with open(context_file, "r") as f_in:
        for line in f_in:
            context = json.loads(line)
            contexts.append(context["context"])
    return contexts

def load_index(index_file: str) -> faiss.Index:
    """Load a prebuilt FAISS index from disk."""
    return faiss.read_index(index_file)
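
# A compatible index could plausibly be built like this (a sketch, not the
# exact recipe behind the prebuilt index file referenced below; the output
# path here is illustrative):
#
# embedder = load_embedder("Snowflake/snowflake-arctic-embed-l", "cpu")
# contexts = load_contexts("../data/bioasq_contexts.jsonl")
# embeddings = embedder.encode(contexts, normalize_embeddings=True)
# index = faiss.IndexHNSWFlat(embeddings.shape[1], 32)  # HNSW, 32 links per node
# index.add(embeddings)
# faiss.write_index(index, "../data/bioasq_contexts.index")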

def load_reader(model_path: str, device: str) -> TextGenerationPipeline:
    """Load a causal LM and wrap it in a text-generation pipeline."""
    model = AutoModelForCausalLM.from_pretrained(
        model_path, token=HF_TOKEN, torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
    # Llama-style tokenizers ship without a pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    reader = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    return reader

def construct_prompt(contexts: list[str], question: str) -> list[dict]:
    """Build a chat-style message list from the retrieved contexts and question."""
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": USER_PROMPT.format(
                context="\n".join(contexts), question=question
            ),
        },
    ]
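
# construct_prompt returns a two-message chat (system + user); recent
# transformers text-generation pipelines accept such message lists directly
# and apply the model's chat template before generating.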

def load_all(
    embedder_path: str,
    context_file: str,
    index_file: str,
    reader_path: str,
) -> tuple[SentenceTransformer, list[str], faiss.Index, TextGenerationPipeline]:
    """Load every pipeline component; the reader goes on GPU when available."""
    embedder = load_embedder(embedder_path, "cpu")
    contexts = load_contexts(context_file)
    index = load_index(index_file)
    reader_device = "cuda" if torch.cuda.is_available() else "cpu"
    reader = load_reader(reader_path, reader_device)
    return embedder, contexts, index, reader

def run_query(
    question: str,
    embedder: SentenceTransformer,
    index: faiss.Index,
    contexts: list[str],
    reader: TextGenerationPipeline,
    top_k: int = 3,
) -> tuple[list[int], list[str], str]:
    """Embed the question, retrieve the top_k contexts, and generate an answer."""
    query_embedding = embedder.encode([question], normalize_embeddings=True)
    _, retrieved_context_ids = index.search(query_embedding, top_k)
    retrieved_context_ids = np.array(retrieved_context_ids)  # shape: (1, top_k)
    retrieved_contexts = []
    for row in retrieved_context_ids:
        retrieved_contexts.append(
            [contexts[i] if contexts[i] is not None else "" for i in row]
        )
    # The code below handles a single question.
    prompt = construct_prompt(retrieved_contexts[0], question)
    answer = reader(prompt, max_new_tokens=128, return_full_text=False)
    print(answer)  # Debug: raw pipeline output.
    answer_text = answer[0]["generated_text"]
    # Keep only the text after the ">> Answer" marker if the model echoed it.
    if ">> Answer" in answer_text:
        answer_text = answer_text.split(">> Answer")[1].strip()
    return retrieved_context_ids[0].tolist(), retrieved_contexts[0], answer_text
# %%
# embedder_path = "Snowflake/snowflake-arctic-embed-l"
# reader_path = "meta-llama/Llama-3.2-1B-Instruct"
# context_file = "../data/bioasq_contexts.jsonl"
# index_file = "../data/bioasq_contexts__snowflake-arctic-embed-l__float32_hnsw.index"
# embedder, contexts, index, reader = load_all(
#     embedder_path, context_file, index_file, reader_path
# )
# query = "What cellular structures does fascin induce?"
# retrieved_context_ids, retrieved_contexts, answer_text = run_query(
#     query, embedder, index, contexts, reader
# )
# %%
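# A minimal usage sketch for answering several questions in a loop (assumes
# the components above are loaded; the question list is illustrative):
#
# questions = [
#     "What type of protein is fascin?",
#     "What cellular structures does fascin induce?",
# ]
# for q in questions:
#     _, _, answer_text = run_query(q, embedder, index, contexts, reader)
#     print(f"{q} -> {answer_text}")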