File size: 4,697 Bytes
7a907a1
 
 
 
 
 
 
 
fbe007f
7a907a1
 
 
 
 
 
 
fbe007f
7a907a1
 
 
 
 
 
 
 
 
 
 
 
 
6960eaf
7a907a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cc52cd
7a907a1
 
 
664ce80
7a907a1
fbe007f
830f001
6960eaf
7a907a1
664ce80
 
 
 
 
0cc52cd
 
 
c8255aa
0cc52cd
 
 
6960eaf
fbe007f
 
 
 
0cc52cd
 
664ce80
6960eaf
c8255aa
 
fbe007f
0cc52cd
664ce80
c8255aa
664ce80
7a907a1
0cc52cd
1f0b244
 
2642a9d
 
1f0b244
 
2642a9d
830f001
664ce80
 
 
 
7a907a1
 
 
fbe007f
7a907a1
 
0cc52cd
 
 
7a907a1
 
 
fbe007f
7a907a1
 
 
 
 
2642a9d
7a907a1
 
 
6960eaf
fbe007f
 
 
7a907a1
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import requests
import os
from chromadb.config import Settings
import chromadb
from transformers import AutoTokenizer, AutoModel
import torch
from typing import List

# Optional: Set your PubMed API key as an environment variable
PUBMED_API_KEY = os.environ.get("PUBMED_API_KEY", "<YOUR_NCBI_API_KEY>")

#############################################
# 1) FETCHING ABSTRACTS FROM PUBMED
#############################################
def fetch_pubmed_abstracts(query: str, max_results: int = 5, timeout: float = 10.0) -> List[str]:
    """
    Fetch PubMed abstracts for *query* using NCBI's E-utilities.

    Performs one ESearch call to collect PMIDs, then one EFetch call per PMID
    to retrieve the abstract text.

    Args:
        query: Search term passed to ESearch.
        max_results: Maximum number of PMIDs to request.
        timeout: Per-request timeout in seconds (added; defaulted, so existing
            callers are unaffected).

    Returns:
        A list of non-empty abstract texts, one per PMID that had one.

    Raises:
        requests.HTTPError: If either E-utilities endpoint returns an error status.
        requests.Timeout: If a request exceeds *timeout*.
    """
    search_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
    params = {
        "db": "pubmed",
        "term": query,
        "retmax": max_results,
        "api_key": PUBMED_API_KEY,
        "retmode": "json"
    }
    # Always pass a timeout: without one, a stalled connection to NCBI would
    # hang the whole retrieval pipeline indefinitely.
    r = requests.get(search_url, params=params, timeout=timeout)
    r.raise_for_status()
    data = r.json()

    pmid_list = data["esearchresult"].get("idlist", [])
    abstracts = []
    fetch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
    for pmid in pmid_list:
        fetch_params = {
            "db": "pubmed",
            "id": pmid,
            "rettype": "abstract",
            "retmode": "text",
            "api_key": PUBMED_API_KEY
        }
        fetch_resp = requests.get(fetch_url, params=fetch_params, timeout=timeout)
        fetch_resp.raise_for_status()
        abstract_text = fetch_resp.text.strip()
        # Some PMIDs have no abstract; skip the empty responses.
        if abstract_text:
            abstracts.append(abstract_text)
    return abstracts

#############################################
# 2) CHROMA VECTOR STORE SETUP
#############################################
# Hugging Face model used for all embeddings (indexing and querying must use
# the same model so distances are comparable).
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

class EmbedFunction:
    """
    Wraps a Hugging Face embedding model behind Chroma's embedding interface:

        __call__(self, input: List[str]) -> List[List[float]]

    The parameter is named ``input`` because recent Chroma versions validate
    the embedding function's signature by that exact name.
    """

    def __init__(self, model_name: str):
        """Load the tokenizer and model once and put the model in eval mode."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    def __call__(self, input: List[str]) -> List[List[float]]:
        """
        Batch-embed a list of strings; returns one List[float] per input.

        Pooling is an attention-mask-weighted mean: padding tokens added when
        batching texts of different lengths are excluded from the average.
        (The previous unmasked mean let a text's embedding vary with the
        lengths of the other texts in the batch — a real correctness bug for
        similarity search. Masked mean pooling is the documented usage for
        sentence-transformers models such as all-MiniLM-L6-v2.)
        """
        if not input:
            return []
        tokenized = self.tokenizer(
            input,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )
        with torch.no_grad():
            outputs = self.model(**tokenized)
        # [batch, seq_len, hidden]; identical to hidden_states[-1].
        last_hidden = outputs.last_hidden_state
        # Zero out padding positions, then average over the real tokens only.
        mask = tokenized["attention_mask"].unsqueeze(-1).type_as(last_hidden)
        summed = (last_hidden * mask).sum(dim=1)
        token_counts = mask.sum(dim=1).clamp(min=1e-9)  # guard divide-by-zero
        pooled = summed / token_counts                  # [batch, hidden]
        return pooled.cpu().tolist()

# Instantiate the embedding function once at import time so the model weights
# are loaded a single time and shared by all collection operations.
embed_function = EmbedFunction(EMBED_MODEL_NAME)

# Chroma client persisting vectors under "chromadb_data", with anonymized
# telemetry disabled. NOTE(review): the Settings-based Client constructor is
# the older chromadb API — confirm it matches the installed chromadb version
# (newer releases prefer chromadb.PersistentClient).
client = chromadb.Client(
    settings=Settings(
        persist_directory="chromadb_data",
        anonymized_telemetry=False
    )
)

# Create (or reopen) the abstracts collection; the custom embedding function
# is used both when adding documents and when embedding query texts.
collection = client.get_or_create_collection(
    name="pubmed_abstracts",
    embedding_function=embed_function
)

def index_pubmed_docs(docs: List[str], prefix: str = "doc"):
    """
    Add non-blank documents to the Chroma collection with unique IDs.

    IDs are ``"<prefix>-<position>"``, matching the original scheme, so
    re-indexing with the same prefix targets the same IDs.

    Args:
        docs: Document texts; blank/whitespace-only entries are skipped.
        prefix: ID prefix (the caller passes the user query here).
    """
    # Batch the insert: a single collection.add embeds and stores all
    # documents in one call instead of paying per-document overhead.
    kept = [(f"{prefix}-{i}", doc) for i, doc in enumerate(docs) if doc.strip()]
    if not kept:
        return
    ids, documents = zip(*kept)
    collection.add(documents=list(documents), ids=list(ids))

def query_similar_docs(query: str, top_k: int = 3) -> List[str]:
    """
    Return up to *top_k* stored documents closest to *query* by embedding
    distance, or an empty list when the collection yields no results.
    """
    hits = collection.query(query_texts=[query], n_results=top_k)
    if not hits or not hits["documents"]:
        return []
    # Chroma nests results per query text; we issued exactly one query.
    return hits["documents"][0]

#############################################
# 3) MAIN RETRIEVAL PIPELINE
#############################################
def get_relevant_pubmed_docs(user_query: str) -> List[str]:
    """
    End-to-end retrieval: fetch PubMed abstracts for *user_query*, index them
    in Chroma (IDs prefixed with the query), and return the top matches.

    Returns an empty list when PubMed yields no abstracts.
    """
    abstracts = fetch_pubmed_abstracts(user_query, max_results=5)
    if not abstracts:
        return []
    index_pubmed_docs(abstracts, prefix=user_query)
    return query_similar_docs(user_query, top_k=3)