# %%
"""Retrieval-augmented question answering over a FAISS-indexed context corpus.

A SentenceTransformer embeds the question, a FAISS index returns the ids of the
nearest contexts, and a causal LM wrapped in a ``text-generation`` pipeline
answers the question from those contexts.
"""
import os
import json
from typing import Any, Optional

import torch
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import (
    pipeline,
    TextGenerationPipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
)

# Hugging Face access token, passed to every `from_pretrained` call below.
# Read from the ``hf_token`` env var, falling back to the conventional
# ``HF_TOKEN``. When neither is set, ``None`` is passed, so a missing variable
# no longer crashes the import; gated checkpoints then need some other login.
HF_TOKEN = os.environ.get("hf_token") or os.environ.get("HF_TOKEN")

SYSTEM_PROMPT = """You are a helpful question answering assistant. You will be given a context and a question. You need to provide the answer to the question based on the context. Answer briefly, based on the context. Only output the answer, and nothing else. Here is an example:

>> Context
Fascin is an actin-bundling protein that induces membrane protrusions and cell motility after the formation of lamellipodia or filopodia. Fascin expression has been associated with progression or prognosis in various neoplasms; however, its role in intrahepatic cholangiocarcinoma is unknown.

>> Question
What type of protein is fascin?

>> Answer
Actin-bundling protein

Now answer the user's question based on the user's given context.
"""

USER_PROMPT = """
>> Context
{context}

>> Question
{question}

>> Answer
"""


def load_embedder(model_path: str, device: str) -> SentenceTransformer:
    """Load a SentenceTransformer from ``model_path`` and move it to ``device``."""
    embedder = SentenceTransformer(model_path)
    embedder.to(device)
    return embedder


def load_contexts(context_file: str) -> list[str]:
    """Return the ``"context"`` field of every record in a JSON Lines file.

    ``run_query`` indexes this list with FAISS ids. The order must therefore
    match the order in which vectors were added to the index.
    Blank lines are skipped.
    """
    contexts = []
    with open(context_file, "r", encoding="utf-8") as f_in:
        for line in f_in:
            if not line.strip():
                continue
            contexts.append(json.loads(line)["context"])
    return contexts


def load_index(index_file: str) -> faiss.Index:
    """Read a serialized FAISS index from ``index_file``."""
    return faiss.read_index(index_file)


def load_reader(model_path: str, device: str) -> TextGenerationPipeline:
    """Build a bfloat16 ``text-generation`` pipeline for ``model_path`` on ``device``."""
    # The dtype must be given at load time: `pipeline(torch_dtype=...)` only
    # takes effect when the pipeline loads the model itself, not when it is
    # handed an already-instantiated model.
    model = AutoModelForCausalLM.from_pretrained(
        model_path, token=HF_TOKEN, torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
    # Reuse EOS as the padding token so the pipeline can pad.
    tokenizer.pad_token = tokenizer.eos_token
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        token=HF_TOKEN,
        device=device,
    )


def construct_prompt(contexts: list[str], question: str) -> list[dict]:
    """Build chat messages: the fixed system prompt plus one user turn.

    The user turn contains the newline-joined ``contexts`` and the ``question``.
    """
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": USER_PROMPT.format(
                context="\n".join(contexts), question=question
            ),
        },
    ]


def load_all(
    embedder_path: str,
    context_file: str,
    index_file: str,
    reader_path: str,
    embedder_device: str = "cpu",
    reader_device: Optional[str] = None,
) -> dict[str, Any]:
    """Load every component needed by ``run_query``.

    Args:
        embedder_path: SentenceTransformer model id or path.
        context_file: JSON Lines file with a ``"context"`` field per record.
        index_file: Serialized FAISS index over the embedded contexts.
        reader_path: Causal LM model id or path.
        embedder_device: Device for the embedder (default ``"cpu"``).
        reader_device: Device for the reader. ``None`` picks ``"cuda"`` when
            available, else ``"cpu"``.

    Returns:
        A dict with keys ``embedder``, ``contexts``, ``index`` and ``reader``.
        The keys match ``run_query``'s parameter names, so
        ``run_query(question, **load_all(...))`` works.
    """
    embedder = load_embedder(embedder_path, embedder_device)
    contexts = load_contexts(context_file)
    index = load_index(index_file)
    if reader_device is None:
        reader_device = "cuda" if torch.cuda.is_available() else "cpu"
    reader = load_reader(reader_path, reader_device)
    return {
        "embedder": embedder,
        "contexts": contexts,
        "index": index,
        "reader": reader,
    }


def run_query(
    question: str,
    embedder: SentenceTransformer,
    index: faiss.Index,
    contexts: list[str],
    reader: TextGenerationPipeline,
    top_k: int = 3,
) -> tuple[list[int], list[str], str]:
    """Retrieve the ``top_k`` nearest contexts for ``question`` and answer it.

    Returns:
        ``(context_ids, retrieved_contexts, answer_text)``, where the first two
        lists are aligned. Fewer than ``top_k`` entries are returned when the
        index yields fewer valid neighbours.
    """
    query_embedding = embedder.encode([question], normalize_embeddings=True)
    _, neighbour_ids = index.search(query_embedding, top_k)  # (1, top_k)
    # FAISS fills missing results with -1; `contexts[-1]` would silently
    # return the last context, so drop those slots.
    retrieved_context_ids = [int(i) for i in neighbour_ids[0] if i >= 0]
    retrieved_contexts = [
        contexts[i] if contexts[i] is not None else ""
        for i in retrieved_context_ids
    ]

    prompt = construct_prompt(retrieved_contexts, question)
    output = reader(prompt, max_new_tokens=128, return_full_text=False)
    answer_text = output[0]["generated_text"]
    # In case the model echoes the prompt's answer marker, keep what follows it.
    if ">> Answer" in answer_text:
        answer_text = answer_text.split(">> Answer")[1]
    return retrieved_context_ids, retrieved_contexts, answer_text.strip()


# %%
# embedder_path = "Snowflake/snowflake-arctic-embed-l"
# reader_path = "meta-llama/Llama-3.2-1B-Instruct"
# context_file = "../data/bioasq_contexts.jsonl"
# index_file = "../data/bioasq_contexts__snowflake-arctic-embed-l__float32_hnsw.index"
# components = load_all(
#     embedder_path, context_file, index_file, reader_path,
#     embedder_device="cpu", reader_device="mps",
# )
# query = "What cellular structures does fascin induce?"
# retrieved_context_ids, retrieved_contexts, answer_text = run_query(
#     query, **components
# )

# %%