File size: 4,697 Bytes
7a907a1 fbe007f 7a907a1 fbe007f 7a907a1 6960eaf 7a907a1 0cc52cd 7a907a1 664ce80 7a907a1 fbe007f 830f001 6960eaf 7a907a1 664ce80 0cc52cd c8255aa 0cc52cd 6960eaf fbe007f 0cc52cd 664ce80 6960eaf c8255aa fbe007f 0cc52cd 664ce80 c8255aa 664ce80 7a907a1 0cc52cd 1f0b244 2642a9d 1f0b244 2642a9d 830f001 664ce80 7a907a1 fbe007f 7a907a1 0cc52cd 7a907a1 fbe007f 7a907a1 2642a9d 7a907a1 6960eaf fbe007f 7a907a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import requests
import os
from chromadb.config import Settings
import chromadb
from transformers import AutoTokenizer, AutoModel
import torch
from typing import List
# Optional NCBI API key, read from the PUBMED_API_KEY environment variable.
# When the variable is not set, the literal "<YOUR_NCBI_API_KEY>" placeholder is used.
PUBMED_API_KEY = os.getenv("PUBMED_API_KEY", "<YOUR_NCBI_API_KEY>")
#############################################
# 1) FETCHING ABSTRACTS FROM PUBMED
#############################################
def fetch_pubmed_abstracts(query: str, max_results: int = 5) -> List[str]:
    """
    Fetch PubMed abstracts for ``query`` using NCBI's E-utilities.

    One ``esearch`` request collects up to ``max_results`` PMIDs. Then one
    ``efetch`` request per PMID downloads that record's abstract as plain text.

    Args:
        query: PubMed search term.
        max_results: Upper bound on the number of PMIDs requested (``retmax``).

    Returns:
        Non-empty, whitespace-stripped abstract texts, in the order the PMIDs
        were returned by ``esearch``.

    Raises:
        requests.HTTPError: If either endpoint answers with an error status.
        requests.Timeout: If a request takes longer than the timeout.
        KeyError: If the esearch JSON has no ``esearchresult`` object.
    """
    import time

    search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    timeout_s = 30  # without a timeout, a stalled connection blocks forever

    # Send api_key only when a real key is configured. The module-level default
    # is a placeholder string, and NCBI rejects requests carrying an invalid key.
    has_api_key = bool(PUBMED_API_KEY) and PUBMED_API_KEY != "<YOUR_NCBI_API_KEY>"
    key_params = {"api_key": PUBMED_API_KEY} if has_api_key else {}
    # NCBI allows about 3 requests/s without a key and 10/s with one. Pacing the
    # per-PMID efetch calls keeps the burst under that limit (HTTP 429 otherwise).
    min_interval_s = 0.11 if has_api_key else 0.34

    abstracts: List[str] = []
    # One session reuses the HTTPS connection across all E-utilities calls.
    with requests.Session() as session:
        search_resp = session.get(
            search_url,
            params={
                "db": "pubmed",
                "term": query,
                "retmax": max_results,
                "retmode": "json",
                **key_params,
            },
            timeout=timeout_s,
        )
        search_resp.raise_for_status()
        pmid_list = search_resp.json()["esearchresult"].get("idlist", [])

        for pmid in pmid_list:
            time.sleep(min_interval_s)  # the preceding request counts toward the limit
            fetch_resp = session.get(
                fetch_url,
                params={
                    "db": "pubmed",
                    "id": pmid,
                    "rettype": "abstract",
                    "retmode": "text",
                    **key_params,
                },
                timeout=timeout_s,
            )
            fetch_resp.raise_for_status()
            abstract_text = fetch_resp.text.strip()
            if abstract_text:
                abstracts.append(abstract_text)
    return abstracts
#############################################
# 2) CHROMA VECTOR STORE SETUP
#############################################
# Hugging Face model ID loaded by the embedding function for documents and queries.
EMBED_MODEL_NAME: str = "sentence-transformers/all-MiniLM-L6-v2"
class EmbedFunction:
    """
    Callable adapter that turns a Hugging Face encoder into a Chroma embedding function.

    Chroma invokes the instance with the signature
        (self, input: List[str]) -> List[List[float]]
    so the parameter must be named ``input``. Each text is embedded by
    averaging the final-layer token vectors over real (non-padding) tokens and
    L2-normalising the result. This matches the pooling documented for
    sentence-transformers models such as all-MiniLM-L6-v2.
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and model for ``model_name`` and put the model in eval mode."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()  # inference only: disables dropout so outputs are deterministic

    @staticmethod
    def _masked_mean_pool(last_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Average token vectors per row, ignoring padding, then L2-normalise each row.

        Args:
            last_hidden: Token vectors, shape [batch_size, seq_len, hidden_dim].
            attention_mask: Shape [batch_size, seq_len], 1 for real tokens and 0 for padding.

        Returns:
            Tensor of shape [batch_size, hidden_dim] with unit-length rows.
        """
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)  # [batch, seq_len, 1]
        summed = (last_hidden * mask).sum(dim=1)  # padding positions contribute zero
        # clamp avoids division by zero for a row with no real tokens
        counts = mask.sum(dim=1).clamp(min=1e-9)  # [batch, 1]
        pooled = summed / counts
        return torch.nn.functional.normalize(pooled, p=2, dim=1)

    def __call__(self, input: List[str]) -> List[List[float]]:
        """
        Embed a batch of strings.

        Args:
            input: Texts to embed; each is truncated to at most 512 tokens.

        Returns:
            One embedding (a list of floats) per input string, in input order.
            An empty input yields an empty list without running the model.
        """
        if not input:
            return []
        tokenized = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,  # pads to the longest text; the mask below excludes the padding
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            outputs = self.model(**tokenized)
        pooled = self._masked_mean_pool(outputs.last_hidden_state, tokenized["attention_mask"])
        return pooled.cpu().tolist()
# Single embedding-function instance shared by the collection. Its __call__
# takes a parameter named `input`, which is the name Chroma passes texts under.
embed_function = EmbedFunction(EMBED_MODEL_NAME)

# Chroma client with telemetry disabled.
# NOTE(review): on chromadb >= 0.4, `persist_directory` is only honoured when
# `is_persistent=True` is also set (or via chromadb.PersistentClient). As
# written, this client may keep data in memory only. Confirm against the
# installed chromadb version.
client = chromadb.Client(
    settings=Settings(
        anonymized_telemetry=False,
        persist_directory="chromadb_data",
    )
)

# Collection holding the indexed abstracts. It is reused if it already exists
# on this client, and it embeds documents and query texts via embed_function.
collection = client.get_or_create_collection(
    embedding_function=embed_function,
    name="pubmed_abstracts",
)
def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
    """
    Store the non-blank documents from ``docs`` in the Chroma ``collection``.

    Each document is stored under the ID ``f"{prefix}-{i}"``, where ``i`` is its
    position in ``docs``. Blank or whitespace-only documents are skipped but
    still consume their index, so IDs are stable for a given input list.

    All documents are written with a single ``upsert`` call. This means they
    are embedded in one batch, and re-indexing with the same prefix (e.g.
    re-running a query whose PubMed results changed) replaces the earlier
    entries instead of colliding with their IDs.

    Args:
        docs: Document texts to index.
        prefix: ID prefix. The retrieval pipeline passes the user query here.
    """
    ids: List[str] = []
    texts: List[str] = []
    for i, doc in enumerate(docs):
        if doc.strip():
            ids.append(f"{prefix}-{i}")
            texts.append(doc)
    if ids:  # nothing to store: avoid calling Chroma with an empty batch
        collection.upsert(documents=texts, ids=ids)
def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
    """
    Return up to ``top_k`` stored documents closest to ``query`` in embedding space.

    The search runs over the whole ``collection``. An empty list is returned
    when Chroma gives back no result object or no ``documents`` entry.
    """
    response = collection.query(query_texts=[query], n_results=top_k)
    if not response or not response["documents"]:
        return []
    # Exactly one query text was sent, so the first row holds its matches.
    return response["documents"][0]
#############################################
# 3) MAIN RETRIEVAL PIPELINE
#############################################
def get_relevant_pubmed_docs(user_query: str) -> List[str]:
    """
    Run the full retrieval pipeline for ``user_query``.

    Steps:
      1. Download up to 5 PubMed abstracts matching the query.
      2. Index them in Chroma, using the query text as the ID prefix.
      3. Return the 3 documents most similar to the query.

    If PubMed yields no abstracts, nothing is indexed and an empty list is
    returned. Step 3 searches the entire collection, so documents indexed by
    earlier queries can also be returned.
    """
    abstracts = fetch_pubmed_abstracts(user_query, max_results=5)
    if abstracts:
        index_pubmed_docs(abstracts, prefix=user_query)
        return query_similar_docs(user_query, top_k=3)
    return []
|