mgbam committed on
Commit 6f578ad · verified · 1 Parent(s): 504e4af

Update pubmed_rag.py

Files changed (1)
  1. pubmed_rag.py +35 -145
pubmed_rag.py CHANGED
@@ -1,43 +1,18 @@
 import requests
-import nltk
-nltk.download("punkt")
+from transformers import pipeline
 from nltk.tokenize import sent_tokenize
-from concurrent.futures import ThreadPoolExecutor, as_completed
-
-from transformers import pipeline, AutoTokenizer, AutoModel
-from sentence_transformers import SentenceTransformer
-import os
-import faiss
-import numpy as np
-import json
-
-from config import (
-    PUBMED_EMAIL,
-    MAX_PUBMED_RESULTS,
-    DEFAULT_SUMMARIZATION_CHUNK,
-    VECTORDB_PATH,
-    EMBEDDING_MODEL_NAME
-)
-
-###############################################################################
-#                        SUMMARIZATION & EMBEDDINGS                           #
-###############################################################################
+import nltk
 
-summarizer = pipeline(
-    "summarization",
-    model="facebook/bart-large-cnn",
-    tokenizer="facebook/bart-large-cnn",
-)
+from config import MY_PUBMED_EMAIL, MAX_PUBMED_RESULTS, SUMMARIZATION_CHUNK_SIZE
 
-embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
+nltk.download("punkt")
 
-###############################################################################
-#                           PUBMED UTIL FUNCTIONS                             #
-###############################################################################
+# Summarization pipeline
+summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
 
 def search_pubmed(query, max_results=MAX_PUBMED_RESULTS):
     """
-    Search PubMed for PMIDs matching a query. Returns a list of PMIDs.
+    Search PubMed for articles matching the query.
     """
     url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
     params = {
@@ -46,16 +21,16 @@ def search_pubmed(query, max_results=MAX_PUBMED_RESULTS):
         "retmax": max_results,
         "retmode": "json",
         "tool": "AdvancedMedicalAI",
-        "email": PUBMED_EMAIL
+        "email": MY_PUBMED_EMAIL,
     }
-    resp = requests.get(url, params=params)
-    resp.raise_for_status()
-    data = resp.json()
+    response = requests.get(url, params=params)
+    response.raise_for_status()
+    data = response.json()
     return data.get("esearchresult", {}).get("idlist", [])
 
 def fetch_abstract(pmid):
     """
-    Fetches an abstract for a single PMID via EFetch.
+    Fetch the abstract of a given PubMed ID.
     """
     url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
     params = {
@@ -64,132 +39,47 @@ def fetch_abstract(pmid):
         "retmode": "text",
         "rettype": "abstract",
         "tool": "AdvancedMedicalAI",
-        "email": PUBMED_EMAIL
+        "email": MY_PUBMED_EMAIL,
     }
-    resp = requests.get(url, params=params)
-    resp.raise_for_status()
-    return resp.text.strip()
+    response = requests.get(url, params=params)
+    response.raise_for_status()
+    return response.text.strip()
 
 def fetch_pubmed_abstracts(pmids):
     """
-    Parallel fetch for multiple PMIDs. Returns dict {pmid: text}.
+    Fetch multiple abstracts for a list of PMIDs.
    """
     results = {}
-    with ThreadPoolExecutor(max_workers=min(len(pmids), 5)) as executor:
-        future_to_pmid = {executor.submit(fetch_abstract, pmid): pmid for pmid in pmids}
-        for future in as_completed(future_to_pmid):
-            pmid = future_to_pmid[future]
-            try:
-                results[pmid] = future.result()
-            except Exception as e:
-                results[pmid] = f"Error fetching PMID {pmid}: {str(e)}"
+    for pmid in pmids:
+        try:
+            results[pmid] = fetch_abstract(pmid)
+        except Exception as e:
+            results[pmid] = f"Error fetching PMID {pmid}: {e}"
     return results
 
-###############################################################################
-#                          SUMMARIZE & CHUNK TEXT                             #
-###############################################################################
-
-def chunk_and_summarize(raw_text, chunk_size=DEFAULT_SUMMARIZATION_CHUNK):
+def summarize_text(text, chunk_size=SUMMARIZATION_CHUNK_SIZE):
     """
-    Splits large text into chunks by sentences, then summarizes each chunk, merging results.
+    Summarize long text using a chunking strategy.
     """
-    sentences = sent_tokenize(raw_text)
+    sentences = sent_tokenize(text)
     chunks = []
     current_chunk = []
     current_length = 0
 
-    for sent in sentences:
-        token_count = len(sent.split())
-        if current_length + token_count > chunk_size:
+    for sentence in sentences:
+        tokens = len(sentence.split())
+        if current_length + tokens > chunk_size:
             chunks.append(" ".join(current_chunk))
             current_chunk = []
             current_length = 0
-        current_chunk.append(sent)
-        current_length += token_count
+        current_chunk.append(sentence)
+        current_length += tokens
 
     if current_chunk:
         chunks.append(" ".join(current_chunk))
 
-    summary_list = []
-    for c in chunks:
-        summ = summarizer(c, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
-        summary_list.append(summ)
-    return " ".join(summary_list)
-
-###############################################################################
-#                    SIMPLE VECTOR STORE (FAISS) FOR RAG                      #
-###############################################################################
-
-def create_or_load_faiss_index():
-    """
-    Creates a new FAISS index or loads from disk if it exists.
-    """
-    index_path = os.path.join(VECTORDB_PATH, "faiss_index.bin")
-    meta_path = os.path.join(VECTORDB_PATH, "faiss_meta.json")
-
-    if not os.path.exists(VECTORDB_PATH):
-        os.makedirs(VECTORDB_PATH)
-
-    if os.path.exists(index_path) and os.path.exists(meta_path):
-        # Load existing index
-        index = faiss.read_index(index_path)
-        with open(meta_path, "r") as f:
-            meta_data = json.load(f)
-        return index, meta_data
-    else:
-        # Create new index
-        index = faiss.IndexFlatL2(embed_model.get_sentence_embedding_dimension())
-        meta_data = {}
-        return index, meta_data
-
-def save_faiss_index(index, meta_data):
-    """
-    Saves the FAISS index and metadata to disk.
-    """
-    index_path = os.path.join(VECTORDB_PATH, "faiss_index.bin")
-    meta_path = os.path.join(VECTORDB_PATH, "faiss_meta.json")
-
-    faiss.write_index(index, index_path)
-    with open(meta_path, "w") as f:
-        json.dump(meta_data, f)
-
-def upsert_documents(docs):
-    """
-    Takes in a dict of {pmid: text}, embeds and upserts them into the FAISS index.
-    Each doc is stored in 'meta_data' with pmid as key.
-    """
-    index, meta_data = create_or_load_faiss_index()
-
-    texts = list(docs.values())
-    pmids = list(docs.keys())
-
-    embeddings = embed_model.encode(texts, convert_to_numpy=True)
-    index.add(embeddings)
-
-    # Maintain a simple meta_data: { int_id: { 'pmid': X, 'text': Y } }
-    # Where int_id is the row in the index
-    start_id = len(meta_data)
-    for i, pmid in enumerate(pmids):
-        meta_data[str(start_id + i)] = {"pmid": pmid, "text": texts[i]}
-
-    save_faiss_index(index, meta_data)
-
-def semantic_search(query, top_k=3):
-    """
-    Embeds 'query' and searches the FAISS index for top_k similar docs.
-    Returns a list of dict with 'pmid' and 'text'.
-    """
-    index, meta_data = create_or_load_faiss_index()
-
-    query_embedding = embed_model.encode([query], convert_to_numpy=True)
-    distances, indices = index.search(query_embedding, top_k)
-
-    results = []
-    for dist, idx_list in zip(distances, indices):
-        for d, i in zip(dist, idx_list):
-            # i is row in the index, look up meta_data
-            doc_info = meta_data[str(i)]
-            results.append({"pmid": doc_info["pmid"], "text": doc_info["text"], "score": float(d)})
-    # Sort by ascending distance => best match first
-    results.sort(key=lambda x: x["score"])
-    return results
+    summaries = [
+        summarizer(chunk, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
+        for chunk in chunks
+    ]
+    return " ".join(summaries)
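For reference, a minimal usage sketch of the module after this change. It assumes config.py defines MY_PUBMED_EMAIL, MAX_PUBMED_RESULTS, and SUMMARIZATION_CHUNK_SIZE as imported above; the script name and query string are illustrative only and not part of the commit.

# demo_pubmed_rag.py -- illustrative sketch, not part of this commit
from pubmed_rag import search_pubmed, fetch_pubmed_abstracts, summarize_text

# 1) ESearch: look up PMIDs for a free-text query
pmids = search_pubmed("metformin alzheimer disease", max_results=3)

# 2) EFetch: download each abstract as plain text, one request per PMID
abstracts = fetch_pubmed_abstracts(pmids)

# 3) Chunked BART summarization of each abstract
for pmid, abstract in abstracts.items():
    print(pmid, "->", summarize_text(abstract))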