Update app.py
app.py CHANGED
@@ -14,6 +14,8 @@ import tempfile
 from astroquery.nasa_ads import ADS
 import pyvo as vo
 import pandas as pd
+import faiss
+from PyPDF2 import PdfReader
 
 # Load the NASA-specific bi-encoder model and tokenizer
 bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
@@ -71,54 +73,105 @@ def encode_text(text):
     outputs = bi_model(**inputs)
     return outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten()
 
-def get_chunks(text, chunk_size=…):
+def get_chunks(text, chunk_size=500):
     """
-    …
+    Splits a long text into smaller chunks of approximately 'chunk_size' characters.
+    Ensures that chunks do not cut off words abruptly.
     """
     if not text.strip():
-        raise ValueError("The provided …")
-
+        raise ValueError("The provided text is empty or blank.")
+
     # Split the text into chunks of approximately 'chunk_size' characters
     chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
+
     return chunks
 
-def retrieve_relevant_context(user_input, context_texts, chunk_size, similarity_threshold):
+# Initialize FAISS index with cosine similarity
+
+embedding_dim = 768  # NASA Bi-Encoder outputs 768-dimensional embeddings
+index = faiss.IndexFlatIP(embedding_dim)  # FAISS inner product (cosine similarity)
+pdf_chunks = []  # Store extracted chunks for later reference
+chunk_embeddings = []  # Store embeddings for similarity retrieval
+
+def load_and_process_uploaded_pdfs(pdf_files):
+    """Extracts text from uploaded PDFs, splits into chunks, generates embeddings, and stores in FAISS."""
+    global index, pdf_chunks, chunk_embeddings
+
+    # Reset the FAISS index and stored data for new uploads
+    index.reset()
+    pdf_chunks.clear()
+    chunk_embeddings.clear()
+
+    text_data = []
+
+    for pdf_file in pdf_files:
+        if pdf_file is None:
+            continue  # Skip if no file is uploaded
+
+        reader = PdfReader(pdf_file)
+        pdf_text = ""
+        for page in reader.pages:
+            pdf_text += page.extract_text() + "\n"
+
+        # Split extracted text into chunks
+        chunks = get_chunks(pdf_text, chunk_size=500)  # Adjust chunk size if needed
+        pdf_chunks.extend(chunks)  # Store for retrieval
+
+        # Generate embeddings for each chunk and store in FAISS
+        for chunk in chunks:
+            chunk_embedding = encode_text(chunk).reshape(1, -1)
+
+            # Normalize the embedding for cosine similarity
+            chunk_embedding = chunk_embedding / np.linalg.norm(chunk_embedding)
+
+            index.add(chunk_embedding)  # Add normalized embeddings to FAISS
+            chunk_embeddings.append(chunk_embedding)  # Store for reference
+
+        text_data.extend(chunks)
+
+    return text_data
+
+
+def retrieve_relevant_context(user_input, context_text, science_objectives="", k=3):
     """
-    …
-    If no chunk meets the similarity threshold, return a fallback message.
+    Retrieve the most relevant document chunks using cosine similarity search.
+    Uses combined user inputs (Science Goal + Context + Optional Science Objectives).
     """
-    # …
-    …
-    # Split the long context text into chunks using the chunking function
-    context_chunks = get_chunks(context_texts, chunk_size)
-
-    # Handle single context case
-    if len(context_chunks) == 1:
-        return context_chunks[0], 1.0  # Return the single chunk with perfect similarity
-
-    # Encode the user input to create a query embedding
-    user_embedding = encode_text(user_input).reshape(1, -1)
-
-    # Encode all context chunks to create embeddings
-    chunk_embeddings = np.array([encode_text(chunk) for chunk in context_chunks])
-
-    # Compute cosine similarity between the user input and each chunk
-    similarities = cosine_similarity(user_embedding, chunk_embeddings).flatten()
-
-    # Check if any similarity scores are above the threshold
-    if max(similarities) < similarity_threshold:
-        return "No relevant context found for the user input.", None
-
-    # Identify the most relevant chunk based on the highest cosine similarity score
-    most_relevant_idx = np.argmax(similarities)
-    most_relevant_chunk = context_chunks[most_relevant_idx]
+    global chunk_embeddings, pdf_chunks  # Ensure we're using the globally stored embeddings and text chunks
+
+    # Combine all user inputs into a single query
+    query_text = f"Science Goal: {user_input}\nContext: {context_text}"
+
+    # Append Science Objectives only if provided
+    if science_objectives.strip():
+        query_text += f"\nScience Objectives: {science_objectives}"
+
+    # Generate query embedding
+    query_embedding = encode_text(query_text).reshape(1, -1)
+
+    # Normalize the query embedding for cosine similarity
+    query_embedding = query_embedding / np.linalg.norm(query_embedding)
+
+    # Convert stored chunk embeddings into a NumPy array
+    if len(chunk_embeddings) == 0:
+        return "No preloaded document data available.", None
+
+    chunk_embeddings_array = np.array(chunk_embeddings).reshape(len(chunk_embeddings), -1)
+
+    # Compute cosine similarity between the query and all stored chunk embeddings
+    similarities = cosine_similarity(query_embedding, chunk_embeddings_array).flatten()
+
+    # Get indices of top k most relevant chunks
+    top_indices = similarities.argsort()[-k:][::-1]  # Sort in descending order
 
-    # …
-    …
+    # Retrieve the most relevant chunks
+    retrieved_context = "\n\n".join([pdf_chunks[i] for i in top_indices])
+
+    # If no relevant chunk is found, return a default message
+    if not retrieved_context.strip():
+        return "No relevant context found for the query."
+
+    return retrieved_context
 
 def extract_keywords_with_gpt(user_input, max_tokens=100, temperature=0.3):
     # Define a prompt to ask GPT-4 to extract keywords and important terms
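
Taken together, this hunk sets up an ingest-then-query pipeline: PDF text is chunked, each chunk embedding is L2-normalized and added to a `faiss.IndexFlatIP` (inner product over unit vectors equals cosine similarity), and queries are scored against the stored embeddings. Two quirks worth noting: `get_chunks` slices at fixed character offsets, so despite its docstring it can split words mid-chunk; and retrieval ranks chunks with scikit-learn's `cosine_similarity` over the stored list rather than querying the populated FAISS index (for normalized vectors, `index.search(query_embedding, k)` would return the same ranking). A minimal round-trip sketch, where `sample.pdf` is a stand-in path and the model loaded above is assumed available:

```python
# Hypothetical round trip through the new helpers; sample.pdf is a stand-in.
chunks = load_and_process_uploaded_pdfs(["sample.pdf"])
print(f"Indexed {index.ntotal} vectors from {len(chunks)} chunks")

context = retrieve_relevant_context(
    user_input="Characterize exoplanet atmospheres in the near-infrared",
    context_text="Transit spectroscopy of hot Jupiters",
    k=3,  # stitch together the top-3 chunks
)
print(context[:300])
```
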
@@ -367,20 +420,21 @@ def gpt_response_to_dataframe(gpt_response):
     return df
 
 def chatbot(user_input, science_objectives="", context="", subdomain="", use_encoder=False, max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0):
-    …
+    """
+    Handles the full workflow: retrieves relevant context, generates response, processes output.
+    """
+
+    # Retrieve relevant context from FAISS using all user inputs
+    relevant_context = retrieve_relevant_context(user_input, context, science_objectives)
 
     # Fetch NASA ADS references using the full prompt
     references = fetch_nasa_ads_references(subdomain)
 
-    # Generate response from GPT-4
+    # Generate response from GPT-4, ensuring we pass all relevant inputs
     response = generate_response(
         user_input=user_input,
-        science_objectives=science_objectives,
-        relevant_context=relevant_context,  # …
+        science_objectives=science_objectives,
+        relevant_context=relevant_context,  # Ensure retrieved FAISS context is passed
         references=references,
         max_tokens=max_tokens,
         temperature=temperature,
@@ -389,6 +443,7 @@ def chatbot(user_input, science_objectives="", context="", subdomain="", use_enc
         presence_penalty=presence_penalty
     )
 
+    # Append manually entered science objectives to the response (if provided)
     if science_objectives.strip():
         response = f"### Science Objectives (User-Defined):\n\n{science_objectives}\n\n" + response
 
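
With these two hunks, every submission first runs FAISS retrieval and then hands the retrieved context to `generate_response`. A direct-call sketch for illustration (it assumes the OpenAI and ADS credentials used elsewhere in app.py are configured; the five return values mirror the Gradio outputs wired up below):

```python
full_response, table_df, doc_path, iframe, mapify_btn = chatbot(
    user_input="Detect biosignatures on temperate rocky exoplanets",
    science_objectives="",  # left empty so objectives are generated
    context="Direct imaging with a large UV/optical/IR space telescope",
    subdomain="exoplanet atmospheres",
    max_tokens=800,
)
print(full_response[:200])
```
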
@@ -443,13 +498,16 @@ def chatbot(user_input, science_objectives="", context="", subdomain="", use_enc
     return full_response, extracted_table_df, word_doc_path, iframe_html, mapify_button_html
 
 with gr.Blocks() as demo:
-    gr.Markdown("# ExosAI - NASA SMD SCDD Generator with RAG [version-1.…]")
-
+    gr.Markdown("# ExosAI - NASA SMD SCDD Generator with RAG [version-1.2]")
+
     # User Inputs
     user_input = gr.Textbox(lines=5, placeholder="Enter your Science Goal...", label="Science Goal")
     context = gr.Textbox(lines=10, placeholder="Enter Context Text...", label="Context")
     subdomain = gr.Textbox(lines=2, placeholder="Define your Subdomain...", label="Subdomain Definition")
 
+    # PDF Upload Section (Up to 3 PDFs)
+    uploaded_pdfs = gr.File(file_types=[".pdf"], label="Upload Reference PDFs (Up to 3)", interactive=True, file_count="multiple")
+
     # Science Objectives Button & Input (Initially Hidden)
     science_objectives_button = gr.Button("Manually Enter Science Objectives")
     science_objectives_input = gr.Textbox(
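
One fix above: the committed line passed `multiple=True`, but `gr.File` takes `file_count="multiple"` to accept several files; with it, the callback receives a list of files (or `None` when nothing is uploaded), which is why `load_and_process_uploaded_pdfs` skips `None` entries. The "Up to 3" limit also lives only in the label; a small guard could enforce it (a sketch — `MAX_PDFS` and `limit_uploads` are illustrative names, not part of the commit):

```python
MAX_PDFS = 3  # the commit states this limit only in the UI label

def limit_uploads(pdf_files):
    """Drop empty slots and keep at most MAX_PDFS uploaded files."""
    pdf_files = [f for f in (pdf_files or []) if f is not None]
    return pdf_files[:MAX_PDFS]
```
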
@@ -459,15 +517,14 @@ with gr.Blocks() as demo:
         visible=False  # Initially hidden
     )
 
-    # …
+    # Event to Show Science Objectives Input
    science_objectives_button.click(
         fn=lambda: gr.update(visible=True),  # Show textbox when clicked
         inputs=[],
         outputs=[science_objectives_input]
     )
 
-    # …
-    use_encoder = gr.Checkbox(label="Use NASA SMD Bi-Encoder for Context")
+    # Additional Model Parameters
     max_tokens = gr.Slider(50, 2000, value=150, step=10, label="Max Tokens")
     temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Temperature")
     top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.1, label="Top-p")
@@ -490,8 +547,8 @@ with gr.Blocks() as demo:
     submit_button.click(
         fn=chatbot,
         inputs=[
-            user_input, science_objectives_input, context, subdomain,
-            …
+            user_input, science_objectives_input, context, subdomain, uploaded_pdfs,
+            max_tokens, temperature, top_p, frequency_penalty, presence_penalty
         ],
         outputs=[full_response, extracted_table_df, word_doc_path, iframe_html, mapify_button_html]
     )
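
Wiring caveat: `uploaded_pdfs` is now passed as the fifth positional input, but `chatbot`'s unchanged signature still has `use_encoder=False` in that slot, and nothing ever calls `load_and_process_uploaded_pdfs`, so the FAISS index stays empty on submit. One way to close the loop (a sketch that renames the fifth parameter; this is not what the commit does):

```python
def chatbot(user_input, science_objectives="", context="", subdomain="",
            uploaded_pdfs=None, max_tokens=150, temperature=0.7, top_p=0.9,
            frequency_penalty=0.5, presence_penalty=0.0):
    # Index the uploaded PDFs before retrieving context from them.
    if uploaded_pdfs:
        load_and_process_uploaded_pdfs(uploaded_pdfs)
    relevant_context = retrieve_relevant_context(user_input, context, science_objectives)
    ...  # remainder unchanged
```
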
@@ -503,7 +560,7 @@ with gr.Blocks() as demo:
             "",  # science_objectives_input
             "",  # context
             "",  # subdomain
-            …
+            None,  # uploaded_pdfs
             150,  # max_tokens
             0.7,  # temperature
             0.9,  # top_p
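
`clear_all` must return one value per component in the Reset outputs list, in the same order; the new `None` keeps the tuple aligned with `uploaded_pdfs`. A condensed sketch of that shape (the slider defaults come from the code above; the elided tail is not shown in this hunk, so it stays elided):

```python
def clear_all():
    # One return value per Reset output component, in order.
    return (
        "",    # user_input
        "",    # science_objectives_input
        "",    # context
        "",    # subdomain
        None,  # uploaded_pdfs
        150,   # max_tokens
        0.7,   # temperature
        0.9,   # top_p
        # ... remaining slider defaults and output placeholders ...
    )
```
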
@@ -521,11 +578,12 @@ with gr.Blocks() as demo:
         fn=clear_all,
         inputs=[],
         outputs=[
-            user_input, science_objectives_input, context, subdomain,
-            …
+            user_input, science_objectives_input, context, subdomain, uploaded_pdfs,
+            max_tokens, temperature, top_p, frequency_penalty, presence_penalty,
             full_response, extracted_table_df, word_doc_path, iframe_html, mapify_button_html
         ]
     )
 
     # Launch the app
     demo.launch(share=True)
+