|
import requests |
|
import os |
|
from chromadb.config import Settings |
|
import chromadb |
|
from transformers import AutoTokenizer, AutoModel |
|
import torch |
|
from typing import List |
|
|
|
|
|
# NCBI E-utilities API key, read from the PUBMED_API_KEY environment variable.
# When the variable is unset or empty this is None rather than a placeholder:
# requests omits query parameters whose value is None, so no api_key is sent
# and NCBI serves the request under its keyless rate limit. Sending a
# placeholder string would instead make NCBI reject the request as an
# invalid key.
PUBMED_API_KEY = os.environ.get("PUBMED_API_KEY") or None
|
|
|
|
|
|
|
|
|
def fetch_pubmed_abstracts(query: str, max_results: int = 5, timeout: float = 10.0) -> List[str]:
    """
    Fetch PubMed abstracts matching *query* via NCBI's E-utilities.

    Issues one ESearch request to collect up to ``max_results`` PMIDs, then one
    EFetch request per PMID asking for the plain-text "abstract" rendering.

    Args:
        query: PubMed search term, passed verbatim as ESearch ``term``.
        max_results: Maximum number of PMIDs to request (ESearch ``retmax``).
            Non-positive values return ``[]`` without any network call.
        timeout: Seconds each HTTP request may take before giving up.

    Returns:
        The non-empty, whitespace-stripped EFetch texts, in ESearch order.

    Raises:
        requests.HTTPError: if any E-utilities response has an error status.
        requests.Timeout: if a request exceeds ``timeout``.
        KeyError: if the ESearch JSON lacks an ``esearchresult`` object.
    """
    if max_results <= 0:
        return []

    search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    search_params = {
        "db": "pubmed",
        "term": query,
        "retmax": max_results,
        # requests drops None-valued params, so a None key is simply not sent.
        "api_key": PUBMED_API_KEY,
        "retmode": "json",
    }
    # Without a timeout, requests would wait indefinitely on a stalled socket.
    search_resp = requests.get(search_url, params=search_params, timeout=timeout)
    search_resp.raise_for_status()
    pmid_list = search_resp.json()["esearchresult"].get("idlist", [])

    fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    abstracts: List[str] = []
    # NOTE(review): one EFetch per PMID; NCBI throttles clients that exceed its
    # per-second request limit (lower without an API key) — consider batching.
    for pmid in pmid_list:
        fetch_params = {
            "db": "pubmed",
            "id": pmid,
            "rettype": "abstract",
            "retmode": "text",
            "api_key": PUBMED_API_KEY,
        }
        fetch_resp = requests.get(fetch_url, params=fetch_params, timeout=timeout)
        fetch_resp.raise_for_status()
        abstract_text = fetch_resp.text.strip()
        if abstract_text:
            abstracts.append(abstract_text)

    return abstracts
|
|
|
|
|
|
|
|
|
# Hugging Face Hub identifier of the encoder that EmbedFunction is built from.
EMBED_MODEL_NAME: str = "sentence-transformers/all-MiniLM-L6-v2"
|
|
|
class EmbedFunction:
    """
    Callable that embeds strings with a Hugging Face encoder model.

    Provides the call signature Chroma expects from an embedding function:
    ``__call__(self, input: List[str]) -> List[List[float]]``.

    Each text's embedding is the mean of the encoder's final-layer token
    states taken over its real (non-padding) tokens, L2-normalised. This is
    the pooling documented for sentence-transformers models such as
    all-MiniLM-L6-v2. Masking matters because a batch is padded to a common
    length. Without it, shorter texts would be averaged together with
    padding-token states, so their embeddings would depend on what else is
    in the batch.
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and model for *model_name* and put the model in eval mode."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Inference only: eval() turns off training-time behaviour such as dropout.
        self.model.eval()

    def __call__(self, input: List[str]) -> List[List[float]]:
        """
        Embed a batch of strings.

        Args:
            input: Texts to embed (parameter name fixed by Chroma's interface).
                Each text is truncated to at most 512 tokens.

        Returns:
            One unit-length embedding (a list of floats) per input string, in
            input order; ``[]`` when *input* is empty.
        """
        if not input:
            return []

        tokenized = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            outputs = self.model(**tokenized)

        # Final-layer token states, (batch, seq_len, hidden) for HF encoders.
        token_states = outputs.last_hidden_state
        # attention_mask is 1 for real tokens and 0 for padding; add a trailing
        # axis so it broadcasts across the hidden dimension.
        mask = tokenized["attention_mask"].unsqueeze(-1).to(token_states.dtype)
        summed = (token_states * mask).sum(dim=1)
        # clamp avoids division by zero should a row contain no real tokens.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        pooled = torch.nn.functional.normalize(summed / counts, p=2, dim=1)
        return pooled.cpu().tolist()
|
|
|
|
|
# Shared embedding function. Constructing it downloads/loads the model, so
# this happens once, when the module is imported.
embed_function = EmbedFunction(EMBED_MODEL_NAME)

# Chroma client backed by an on-disk store in ./chromadb_data.
# Passing persist_directory to chromadb.Client() alone does not enable
# persistence in Chroma >= 0.4 (Settings.is_persistent defaults to False), so
# the data would silently live only in memory; PersistentClient enables it.
client = chromadb.PersistentClient(
    path="chromadb_data",
    settings=Settings(anonymized_telemetry=False),
)

# Collection of fetched PubMed abstracts, embedded with embed_function.
collection = client.get_or_create_collection(
    name="pubmed_abstracts",
    embedding_function=embed_function,
)
|
|
|
def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
    """
    Store the non-blank entries of *docs* in the Chroma collection.

    Each document is stored under the ID ``f"{prefix}-{i}"``, where ``i`` is
    its position in *docs*. Blank documents are skipped but still consume
    their index. Upsert is used because ``add`` does not overwrite an ID that
    already exists. Upserting means that indexing again with the same prefix
    replaces the stored text and embedding instead of keeping a stale entry.

    Args:
        docs: Documents to store.
        prefix: ID prefix (get_relevant_pubmed_docs passes the user query).
    """
    for i, doc in enumerate(docs):
        if not doc.strip():
            continue
        collection.upsert(documents=[doc], ids=[f"{prefix}-{i}"])
|
|
|
def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
    """
    Return up to *top_k* stored documents closest to *query* by embedding distance.

    The search covers the whole collection. ``top_k`` is clamped to the number
    of stored items because some Chroma versions raise when ``n_results``
    exceeds the collection size. An empty collection or a non-positive
    ``top_k`` returns ``[]`` without querying.
    """
    n_results = min(top_k, collection.count())
    if n_results <= 0:
        return []
    results = collection.query(query_texts=[query], n_results=n_results)
    # query() returns one result list per query text; we sent exactly one.
    return results["documents"][0] if results and results["documents"] else []
|
|
|
|
|
|
|
|
|
def get_relevant_pubmed_docs(user_query: str) -> List[str]:
    """
    Run the fetch → index → retrieve pipeline for *user_query*.

    Fetches up to 5 PubMed abstracts for the query and indexes them into the
    Chroma collection, using the query itself as the ID prefix. It then
    returns the 3 stored documents most similar to the query, drawn from the
    whole collection. If PubMed returns no abstracts, nothing is indexed or
    queried and ``[]`` is returned.
    """
    fetched = fetch_pubmed_abstracts(user_query, max_results=5)
    if fetched:
        index_pubmed_docs(fetched, prefix=user_query)
        return query_similar_docs(user_query, top_k=3)
    return []
|
|