mgbam committed on
Commit ca7c5dc · verified · 1 Parent(s): 113401c

Upload 4 files

Files changed (4)
  1. config.py +13 -0
  2. image_pipeline.py +16 -0
  3. models.py +48 -0
  4. pubmed_utils.py +96 -0
config.py ADDED
@@ -0,0 +1,13 @@
+ import os
+
+ # In a production Hugging Face Space, set these as Secrets
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
+ GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
+ MY_PUBMED_EMAIL = os.getenv("MY_PUBMED_EMAIL", "[email protected]")
+
+ # Default LLM models
+ OPENAI_DEFAULT_MODEL = "gpt-3.5-turbo"
+ GEMINI_DEFAULT_MODEL = "models/chat-bison-001"
+
+ # Summarization chunk size
+ DEFAULT_CHUNK_SIZE = 512
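For context (not part of this commit), a minimal sketch of how these settings can be overridden locally before the module is imported; the override values are placeholders:

import os

# Hypothetical local overrides; in a deployed Space these come from Secrets.
os.environ["MY_PUBMED_EMAIL"] = "you@example.org"
os.environ["OPENAI_API_KEY"] = "sk-placeholder"

import config  # reads the environment at import time
print(config.MY_PUBMED_EMAIL, config.DEFAULT_CHUNK_SIZE)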
image_pipeline.py ADDED
@@ -0,0 +1,16 @@
+ from transformers import pipeline
+
+ def load_image_model():
+     """
+     Loads HuggingFaceTB/SmolVLM-500M-Instruct or another image-to-text model.
+     """
+     return pipeline("image-to-text", model="HuggingFaceTB/SmolVLM-500M-Instruct")
+
+ def analyze_image(image_file, image_model):
+     """
+     Pass an image file to the image model pipeline and return the text/caption.
+     """
+     result = image_model(image_file)
+     if isinstance(result, list) and len(result) > 0:
+         return result[0].get("generated_text", "No caption generated.")
+     return "Unable to process image."
models.py ADDED
@@ -0,0 +1,48 @@
+ import openai
+ import google.generativeai as genai
+ from config import OPENAI_API_KEY, GEMINI_API_KEY, OPENAI_DEFAULT_MODEL, GEMINI_DEFAULT_MODEL
+
+ def configure_llms():
+     """
+     Call this at startup or inside your main app file to configure
+     OpenAI and Gemini if keys are available.
+     """
+     if OPENAI_API_KEY:
+         openai.api_key = OPENAI_API_KEY
+     if GEMINI_API_KEY:
+         genai.configure(api_key=GEMINI_API_KEY)
+
+ def openai_chat(system_prompt, user_prompt, model=None, temperature=0.3):
+     """
+     Basic ChatCompletion with system + user roles for OpenAI
+     (uses the legacy pre-1.0 openai.ChatCompletion interface).
+     """
+     if not OPENAI_API_KEY:
+         return "Error: OpenAI API key not provided."
+     chat_model = model or OPENAI_DEFAULT_MODEL
+     try:
+         response = openai.ChatCompletion.create(
+             model=chat_model,
+             messages=[
+                 {"role": "system", "content": system_prompt},
+                 {"role": "user", "content": user_prompt}
+             ],
+             temperature=temperature
+         )
+         return response.choices[0].message["content"].strip()
+     except Exception as e:
+         return f"Error calling OpenAI: {str(e)}"
+
+ def gemini_chat(system_prompt, user_prompt, model_name=None, temperature=0.3):
+     """
+     Basic chat call via google.generativeai (defaults to the PaLM 2 chat-bison model).
+     """
+     if not GEMINI_API_KEY:
+         return "Error: Gemini API key not provided."
+     final_model_name = model_name or GEMINI_DEFAULT_MODEL
+     try:
+         model = genai.GenerativeModel(model_name=final_model_name)
+         # start_chat history only accepts "user"/"model" roles, so fold the
+         # system prompt into the message itself.
+         chat_session = model.start_chat(history=[])
+         reply = chat_session.send_message(
+             f"{system_prompt}\n\n{user_prompt}",
+             generation_config=genai.types.GenerationConfig(temperature=temperature),
+         )
+         return reply.text
+     except Exception as e:
+         return f"Error calling Gemini: {str(e)}"
pubmed_utils.py ADDED
@@ -0,0 +1,96 @@
+ import requests
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ import nltk
+ nltk.download('punkt', quiet=True)
+ from nltk.tokenize import sent_tokenize
+
+ from transformers import pipeline
+ from config import MY_PUBMED_EMAIL
+
+ # Build a summarization pipeline at module load (caching recommended)
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn", tokenizer="facebook/bart-large-cnn")
+
+ def search_pubmed(query, max_results=3):
+     """
+     Searches PubMed via ESearch and returns a list of PMIDs.
+     """
+     base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi"
+     params = {
+         "db": "pubmed",
+         "term": query,
+         "retmax": max_results,
+         "retmode": "json",
+         "tool": "ElysiumRAG",
+         "email": MY_PUBMED_EMAIL
+     }
+     resp = requests.get(base_url, params=params)
+     resp.raise_for_status()
+     data = resp.json()
+     return data.get("esearchresult", {}).get("idlist", [])
+
+ def fetch_one_abstract(pmid):
+     """
+     Fetches the abstract for a given PMID. Returns (pmid, text).
+     """
+     base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi"
+     params = {
+         "db": "pubmed",
+         "retmode": "text",
+         "rettype": "abstract",
+         "id": pmid,
+         "tool": "ElysiumRAG",
+         "email": MY_PUBMED_EMAIL
+     }
+     resp = requests.get(base_url, params=params)
+     resp.raise_for_status()
+     raw_text = resp.text.strip() or "No abstract text found."
+     return (pmid, raw_text)
+
+ def fetch_pubmed_abstracts(pmids):
+     """
+     Parallel retrieval of multiple PMIDs. Returns {pmid: abstract_text}.
+     """
+     if not pmids:
+         return {}
+
+     results_map = {}
+     with ThreadPoolExecutor(max_workers=min(len(pmids), 5)) as executor:
+         future_to_pmid = {executor.submit(fetch_one_abstract, pmid): pmid for pmid in pmids}
+         for future in as_completed(future_to_pmid):
+             pmid = future_to_pmid[future]
+             try:
+                 pmid_result, text = future.result()
+                 results_map[pmid_result] = text
+             except Exception as e:
+                 results_map[pmid] = f"Error: {str(e)}"
+     return results_map
+
+ def chunk_and_summarize(abstract_text, chunk_size=512):
+     """
+     Chunk large abstracts by sentence, summarize each chunk, then combine.
+     """
+     sentences = sent_tokenize(abstract_text)
+     chunks = []
+
+     current_chunk = []
+     current_length = 0
+     for sent in sentences:
+         tokens_in_sent = len(sent.split())
+         # Start a new chunk once the running word count would exceed chunk_size
+         # (guard against appending an empty chunk when a single sentence is oversized).
+         if current_length + tokens_in_sent > chunk_size and current_chunk:
+             chunks.append(" ".join(current_chunk))
+             current_chunk = []
+             current_length = 0
+         current_chunk.append(sent)
+         current_length += tokens_in_sent
+
+     if current_chunk:
+         chunks.append(" ".join(current_chunk))
+
+     summarized_pieces = []
+     for c in chunks:
+         summary_out = summarizer(
+             c, max_length=100, min_length=30, do_sample=False
+         )
+         summarized_pieces.append(summary_out[0]['summary_text'])
+
+     return " ".join(summarized_pieces).strip()