mgbam committed (verified)
Commit ff77b73 · 1 Parent(s): 834aa30

Rename pubmed_utils.py to pubmed_rag.py

Files changed (2)
  1. pubmed_rag.py +195 -0
  2. pubmed_utils.py +0 -84
pubmed_rag.py ADDED
@@ -0,0 +1,195 @@
+ import requests
+ import nltk
+ nltk.download("punkt")
+ from nltk.tokenize import sent_tokenize
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+
+ from transformers import pipeline, AutoTokenizer, AutoModel
+ from sentence_transformers import SentenceTransformer
+ import os
+ import faiss
+ import numpy as np
+ import json
+
+ from config import (
+     PUBMED_EMAIL,
+     MAX_PUBMED_RESULTS,
+     DEFAULT_SUMMARIZATION_CHUNK,
+     VECTORDB_PATH,
+     EMBEDDING_MODEL_NAME
+ )
+
+ ###############################################################################
+ #                        SUMMARIZATION & EMBEDDINGS                           #
+ ###############################################################################
+
+ summarizer = pipeline(
+     "summarization",
+     model="facebook/bart-large-cnn",
+     tokenizer="facebook/bart-large-cnn",
+ )
+
+ embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
+
+ ###############################################################################
+ #                          PUBMED UTIL FUNCTIONS                              #
+ ###############################################################################
+
+ def search_pubmed(query, max_results=MAX_PUBMED_RESULTS):
+     """
+     Search PubMed for PMIDs matching a query. Returns a list of PMIDs.
+     """
+     url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
+     params = {
+         "db": "pubmed",
+         "term": query,
+         "retmax": max_results,
+         "retmode": "json",
+         "tool": "AdvancedMedicalAI",
+         "email": PUBMED_EMAIL
+     }
+     resp = requests.get(url, params=params)
+     resp.raise_for_status()
+     data = resp.json()
+     return data.get("esearchresult", {}).get("idlist", [])
+
+ def fetch_abstract(pmid):
+     """
+     Fetches an abstract for a single PMID via EFetch.
+     """
+     url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
+     params = {
+         "db": "pubmed",
+         "id": pmid,
+         "retmode": "text",
+         "rettype": "abstract",
+         "tool": "AdvancedMedicalAI",
+         "email": PUBMED_EMAIL
+     }
+     resp = requests.get(url, params=params)
+     resp.raise_for_status()
+     return resp.text.strip()
+
+ def fetch_pubmed_abstracts(pmids):
+     """
+     Parallel fetch for multiple PMIDs. Returns dict {pmid: text}.
+     """
+     results = {}
+     with ThreadPoolExecutor(max_workers=min(len(pmids), 5)) as executor:
+         future_to_pmid = {executor.submit(fetch_abstract, pmid): pmid for pmid in pmids}
+         for future in as_completed(future_to_pmid):
+             pmid = future_to_pmid[future]
+             try:
+                 results[pmid] = future.result()
+             except Exception as e:
+                 results[pmid] = f"Error fetching PMID {pmid}: {str(e)}"
+     return results
+
+ ###############################################################################
+ #                         SUMMARIZE & CHUNK TEXT                              #
+ ###############################################################################
+
+ def chunk_and_summarize(raw_text, chunk_size=DEFAULT_SUMMARIZATION_CHUNK):
+     """
+     Splits large text into chunks by sentences, then summarizes each chunk, merging results.
+     """
+     sentences = sent_tokenize(raw_text)
+     chunks = []
+     current_chunk = []
+     current_length = 0
+
+     for sent in sentences:
+         token_count = len(sent.split())
+         if current_length + token_count > chunk_size:
+             chunks.append(" ".join(current_chunk))
+             current_chunk = []
+             current_length = 0
+         current_chunk.append(sent)
+         current_length += token_count
+
+     if current_chunk:
+         chunks.append(" ".join(current_chunk))
+
+     summary_list = []
+     for c in chunks:
+         summ = summarizer(c, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
+         summary_list.append(summ)
+     return " ".join(summary_list)
+
+ ###############################################################################
+ #                    SIMPLE VECTOR STORE (FAISS) FOR RAG                      #
+ ###############################################################################
+
+ def create_or_load_faiss_index():
+     """
+     Creates a new FAISS index or loads from disk if it exists.
+     """
+     index_path = os.path.join(VECTORDB_PATH, "faiss_index.bin")
+     meta_path = os.path.join(VECTORDB_PATH, "faiss_meta.json")
+
+     if not os.path.exists(VECTORDB_PATH):
+         os.makedirs(VECTORDB_PATH)
+
+     if os.path.exists(index_path) and os.path.exists(meta_path):
+         # Load existing index
+         index = faiss.read_index(index_path)
+         with open(meta_path, "r") as f:
+             meta_data = json.load(f)
+         return index, meta_data
+     else:
+         # Create new index
+         index = faiss.IndexFlatL2(embed_model.get_sentence_embedding_dimension())
+         meta_data = {}
+         return index, meta_data
+
+ def save_faiss_index(index, meta_data):
+     """
+     Saves the FAISS index and metadata to disk.
+     """
+     index_path = os.path.join(VECTORDB_PATH, "faiss_index.bin")
+     meta_path = os.path.join(VECTORDB_PATH, "faiss_meta.json")
+
+     faiss.write_index(index, index_path)
+     with open(meta_path, "w") as f:
+         json.dump(meta_data, f)
+
+ def upsert_documents(docs):
+     """
+     Takes in a dict of {pmid: text}, embeds and upserts them into the FAISS index.
+     Each doc is stored in 'meta_data' with pmid as key.
+     """
+     index, meta_data = create_or_load_faiss_index()
+
+     texts = list(docs.values())
+     pmids = list(docs.keys())
+
+     embeddings = embed_model.encode(texts, convert_to_numpy=True)
+     index.add(embeddings)
+
+     # Maintain a simple meta_data: { int_id: { 'pmid': X, 'text': Y } }
+     # Where int_id is the row in the index
+     start_id = len(meta_data)
+     for i, pmid in enumerate(pmids):
+         meta_data[str(start_id + i)] = {"pmid": pmid, "text": texts[i]}
+
+     save_faiss_index(index, meta_data)
+
+ def semantic_search(query, top_k=3):
+     """
+     Embeds 'query' and searches the FAISS index for top_k similar docs.
+     Returns a list of dict with 'pmid' and 'text'.
+     """
+     index, meta_data = create_or_load_faiss_index()
+
+     query_embedding = embed_model.encode([query], convert_to_numpy=True)
+     distances, indices = index.search(query_embedding, top_k)
+
+     results = []
+     for dist, idx_list in zip(distances, indices):
+         for d, i in zip(dist, idx_list):
+             # i is row in the index, look up meta_data
+             doc_info = meta_data[str(i)]
+             results.append({"pmid": doc_info["pmid"], "text": doc_info["text"], "score": float(d)})
+     # Sort by ascending distance => best match first
+     results.sort(key=lambda x: x["score"])
+     return results
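
The new module pulls its settings from config.py, which is not part of this diff. A minimal sketch of values that would satisfy the imports above — the constant names come from the diff, but every concrete value here is an illustrative assumption, not something defined by this commit:

# config.py — hypothetical example values; adjust for your own deployment
PUBMED_EMAIL = "you@example.org"           # contact email requested by NCBI E-utilities
MAX_PUBMED_RESULTS = 5                     # default retmax passed to ESearch
DEFAULT_SUMMARIZATION_CHUNK = 512          # approximate word budget per summarization chunk
VECTORDB_PATH = "vectordb"                 # directory holding faiss_index.bin / faiss_meta.json
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"  # any sentence-transformers model name
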
pubmed_utils.py DELETED
@@ -1,84 +0,0 @@
- import requests
- from concurrent.futures import ThreadPoolExecutor, as_completed
- from transformers import pipeline
- from config import PUBMED_EMAIL, CHUNK_SIZE
-
- # Summarization pipeline
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
-
-
- def search_pubmed(query, max_results=5):
-     """
-     Search PubMed for PMIDs matching the query.
-     """
-     url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
-     params = {
-         "db": "pubmed",
-         "term": query,
-         "retmax": max_results,
-         "retmode": "json",
-         "tool": "MedicalAI",
-         "email": PUBMED_EMAIL,
-     }
-     response = requests.get(url, params=params)
-     response.raise_for_status()
-     return response.json().get("esearchresult", {}).get("idlist", [])
-
-
- def fetch_abstract(pmid):
-     """
-     Fetch abstract for a given PubMed ID.
-     """
-     url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
-     params = {
-         "db": "pubmed",
-         "id": pmid,
-         "retmode": "text",
-         "rettype": "abstract",
-         "tool": "MedicalAI",
-         "email": PUBMED_EMAIL,
-     }
-     response = requests.get(url, params=params)
-     response.raise_for_status()
-     return response.text.strip()
-
-
- def fetch_pubmed_abstracts(pmids):
-     """
-     Fetch multiple PubMed abstracts concurrently.
-     """
-     results = {}
-     with ThreadPoolExecutor(max_workers=5) as executor:
-         future_to_pmid = {executor.submit(fetch_abstract, pmid): pmid for pmid in pmids}
-         for future in as_completed(future_to_pmid):
-             pmid = future_to_pmid[future]
-             try:
-                 results[pmid] = future.result()
-             except Exception as e:
-                 results[pmid] = f"Error fetching PMID {pmid}: {str(e)}"
-     return results
-
-
- def summarize_text(text, chunk_size=CHUNK_SIZE):
-     """
-     Summarize long text using a chunking strategy.
-     """
-     sentences = text.split(". ")
-     chunks = []
-     current_chunk = []
-     current_length = 0
-
-     for sentence in sentences:
-         tokens = len(sentence.split())
-         if current_length + tokens > chunk_size:
-             chunks.append(" ".join(current_chunk))
-             current_chunk = []
-             current_length = 0
-         current_chunk.append(sentence)
-         current_length += tokens
-
-     if current_chunk:
-         chunks.append(" ".join(current_chunk))
-
-     summaries = [summarizer(chunk, max_length=100, min_length=30)[0]["summary_text"] for chunk in chunks]
-     return " ".join(summaries)