# med / retrieval.py
# mgbam's picture
# Update retrieval.py
# c8255aa verified
import requests
import os
from chromadb.config import Settings
import chromadb
from transformers import AutoTokenizer, AutoModel
import torch
from typing import List
# Optional NCBI API key, read from the PUBMED_API_KEY environment variable.
# It is None when unset. A placeholder default must not be used: NCBI rejects
# unrecognised API keys, so the key is only sent when one is configured.
PUBMED_API_KEY = os.environ.get("PUBMED_API_KEY")
#############################################
# 1) FETCHING ABSTRACTS FROM PUBMED
#############################################
def fetch_pubmed_abstracts(query: str, max_results: int = 5, timeout: float = 10.0) -> List[str]:
    """
    Fetch PubMed abstracts for ``query`` using NCBI's E-utilities.

    One ESearch request retrieves up to ``max_results`` PMIDs. Then one
    EFetch request per PMID retrieves its record in plain-text
    ``abstract`` format.

    Args:
        query: PubMed search term.
        max_results: Maximum number of PMIDs to request. Values <= 0
            return an empty list without any network call.
        timeout: Per-request timeout in seconds, passed to ``requests``.

    Returns:
        The whitespace-stripped, non-empty EFetch text bodies, in ESearch order.

    Raises:
        requests.HTTPError: if any E-utilities request returns an error status.
        requests.Timeout: if a request exceeds ``timeout``.
        KeyError: if the ESearch JSON has no ``"esearchresult"`` section.
    """
    import time  # used only to pace requests under NCBI's rate limit

    if max_results <= 0:
        return []

    search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"

    # Parameters shared by every request. api_key is only sent when configured.
    base_params = {"db": "pubmed"}
    if PUBMED_API_KEY:
        base_params["api_key"] = PUBMED_API_KEY
    # NCBI allows 3 requests/second per IP without an API key and 10 with one.
    # Requests above that limit get an error response.
    min_interval = 0.1 if PUBMED_API_KEY else 0.34

    abstracts: List[str] = []
    # One session reuses the connection across requests and is closed on exit.
    with requests.Session() as session:
        search_resp = session.get(
            search_url,
            params={**base_params, "term": query, "retmax": max_results, "retmode": "json"},
            timeout=timeout,
        )
        search_resp.raise_for_status()
        pmid_list = search_resp.json()["esearchresult"].get("idlist", [])

        for pmid in pmid_list:
            time.sleep(min_interval)
            fetch_resp = session.get(
                fetch_url,
                params={**base_params, "id": pmid, "rettype": "abstract", "retmode": "text"},
                timeout=timeout,
            )
            fetch_resp.raise_for_status()
            abstract_text = fetch_resp.text.strip()
            if abstract_text:
                abstracts.append(abstract_text)
    return abstracts
#############################################
# 2) CHROMA VECTOR STORE SETUP
#############################################
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"


class EmbedFunction:
    """
    Chroma-compatible embedding function backed by a Hugging Face encoder.

    Calling an instance with a list of strings returns one embedding (a list
    of floats) per string. Chroma's embedding-function interface requires
    the argument to be named ``input``, which is why it shadows the builtin.

    Each embedding is the attention-mask-weighted mean of the final hidden
    states, so padding tokens added for batching do not affect it. By default
    it is then L2-normalised, which matches the pooling documented for
    sentence-transformers/all-MiniLM-L6-v2.
    """

    def __init__(self, model_name: str, normalize: bool = True):
        """
        Load the tokenizer and model for ``model_name`` and put the model in eval mode.

        Args:
            model_name: Hugging Face model id or local path.
            normalize: If True, scale every embedding to unit L2 norm.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()
        self.normalize = normalize

    @staticmethod
    def _mean_pool(last_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Average token vectors over non-padding positions only.

        Args:
            last_hidden: Token states, shape [batch_size, seq_len, hidden_dim].
            attention_mask: 1 for real tokens and 0 for padding, shape [batch_size, seq_len].

        Returns:
            Pooled vectors, shape [batch_size, hidden_dim].
        """
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)  # [batch, seq_len, 1]
        summed = (last_hidden * mask).sum(dim=1)
        # The clamp avoids division by zero for a row whose mask is all zeros.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def __call__(self, input: List[str]) -> List[List[float]]:
        """
        Embed a batch of strings.

        Returns one embedding per input string as a list of floats, and
        an empty list for empty input.
        """
        if not input:
            return []
        tokenized = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            outputs = self.model(**tokenized)
        pooled = self._mean_pool(outputs.last_hidden_state, tokenized["attention_mask"])
        if self.normalize:
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
        return pooled.cpu().tolist()
# One shared embedding function. Chroma uses it to embed the documents given to
# collection.add and the query_texts given to collection.query.
embed_function = EmbedFunction(EMBED_MODEL_NAME)

# Chroma client with anonymized telemetry turned off.
# NOTE(review): with chromadb >= 0.4, persist_directory alone may not make
# chromadb.Client persistent (is_persistent=True or chromadb.PersistentClient
# may be required). Confirm against the installed chromadb version.
client = chromadb.Client(
    settings=Settings(
        anonymized_telemetry=False,
        persist_directory="chromadb_data",
    )
)

# The "pubmed_abstracts" collection. It is created on first use and reused
# afterwards, and embeds its content through embed_function.
collection = client.get_or_create_collection(
    embedding_function=embed_function,
    name="pubmed_abstracts",
)
def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
    """
    Add the non-blank entries of ``docs`` to the Chroma collection.

    Each document gets the ID ``f"{prefix}-{i}"``, where ``i`` is its position
    in ``docs``. Blank entries are skipped but still consume an index, so IDs
    for a given input list are stable. All documents go into a single
    ``collection.add`` call, so they are embedded as one batch instead of
    one model pass per document.

    NOTE(review): IDs depend only on ``prefix`` and position, so a later call
    with the same prefix reuses the same IDs. How ``add`` treats IDs that
    already exist depends on the Chroma version. Consider ``collection.upsert``
    if fresh results should replace stored ones.

    Args:
        docs: Document texts to index. ``None`` and whitespace-only entries are ignored.
        prefix: ID prefix for this batch.
    """
    ids: List[str] = []
    texts: List[str] = []
    for i, doc in enumerate(docs):
        if doc and doc.strip():
            ids.append(f"{prefix}-{i}")
            texts.append(doc)
    # Nothing to index. Skip the call, since some Chroma versions reject
    # add() with empty lists.
    if not ids:
        return
    collection.add(documents=texts, ids=ids)
def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
    """
    Return up to ``top_k`` documents from the Chroma collection nearest to ``query``.

    The query text is embedded by the collection's embedding function, and
    results are ordered by embedding distance as returned by Chroma.

    Args:
        query: Free-text query.
        top_k: Maximum number of documents to return. Values <= 0 yield ``[]``.

    Returns:
        The matching document texts. The list is empty when the collection
        is empty or Chroma returns no documents.
    """
    if top_k <= 0:
        return []
    available = collection.count()
    if available == 0:
        return []
    # Some Chroma versions raise when n_results exceeds the number of stored
    # documents, so never ask for more than exist.
    results = collection.query(query_texts=[query], n_results=min(top_k, available))
    documents = results.get("documents") if results else None
    return documents[0] if documents else []
#############################################
# 3) MAIN RETRIEVAL PIPELINE
#############################################
def get_relevant_pubmed_docs(user_query: str, max_results: int = 5, top_k: int = 3) -> List[str]:
    """
    Run end-to-end retrieval for ``user_query``.

    1. Fetch up to ``max_results`` PubMed abstracts for the query.
    2. Index them in the Chroma collection, with IDs prefixed by the query.
    3. Return the ``top_k`` most similar documents from the whole collection.
       The search is unfiltered, so it can include documents indexed by
       earlier queries.

    Args:
        user_query: Search text, used for PubMed, as the ID prefix, and for similarity search.
        max_results: Maximum number of abstracts to fetch from PubMed.
        top_k: Maximum number of documents to return.

    Returns:
        The retrieved document texts. The list is empty when PubMed returns
        no abstracts for the query.
    """
    new_abstracts = fetch_pubmed_abstracts(user_query, max_results=max_results)
    if not new_abstracts:
        return []
    index_pubmed_docs(new_abstracts, prefix=user_query)
    return query_similar_docs(user_query, top_k=top_k)