aquibmoin committed
Commit 0e05b66 · verified · 1 parent: d9abe62

Update app.py

Files changed (1)
app.py  +113 −55
app.py CHANGED
@@ -14,6 +14,8 @@ import tempfile
 from astroquery.nasa_ads import ADS
 import pyvo as vo
 import pandas as pd
+import faiss
+from PyPDF2 import PdfReader
 
 # Load the NASA-specific bi-encoder model and tokenizer
 bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
@@ -71,54 +73,105 @@ def encode_text(text):
     outputs = bi_model(**inputs)
     return outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten()
 
-def get_chunks(text, chunk_size=300):
+def get_chunks(text, chunk_size=500):
     """
-    Split a long piece of text into smaller chunks of approximately 'chunk_size' characters.
+    Splits a long text into smaller chunks of approximately 'chunk_size' characters.
+    Note: chunks are cut at fixed character offsets, so words may be split at boundaries.
     """
     if not text.strip():
-        raise ValueError("The provided context is empty or blank.")
-
+        raise ValueError("The provided text is empty or blank.")
+
     # Split the text into chunks of approximately 'chunk_size' characters
     chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
+
     return chunks
 
+# Initialize FAISS index with cosine similarity
+embedding_dim = 768  # The NASA bi-encoder outputs 768-dimensional embeddings
+index = faiss.IndexFlatIP(embedding_dim)  # FAISS inner product (cosine similarity on normalized vectors)
+pdf_chunks = []  # Store extracted chunks for later reference
+chunk_embeddings = []  # Store embeddings for similarity retrieval
+
+def load_and_process_uploaded_pdfs(pdf_files):
+    """Extract text from uploaded PDFs, split it into chunks, generate embeddings, and store them in FAISS."""
+    global index, pdf_chunks, chunk_embeddings
+
+    # Reset the FAISS index and stored data for new uploads
+    index.reset()
+    pdf_chunks.clear()
+    chunk_embeddings.clear()
+
+    text_data = []
+
+    for pdf_file in pdf_files:
+        if pdf_file is None:
+            continue  # Skip if no file is uploaded
+
+        reader = PdfReader(pdf_file)
+        pdf_text = ""
+        for page in reader.pages:
+            pdf_text += (page.extract_text() or "") + "\n"  # extract_text() may return None
+
+        # Split extracted text into chunks
+        chunks = get_chunks(pdf_text, chunk_size=500)  # Adjust chunk size if needed
+        pdf_chunks.extend(chunks)  # Store for retrieval
+
+        # Generate embeddings for each chunk and store them in FAISS
+        for chunk in chunks:
+            chunk_embedding = encode_text(chunk).reshape(1, -1)
+
+            # Normalize the embedding so that inner product equals cosine similarity
+            chunk_embedding = chunk_embedding / np.linalg.norm(chunk_embedding)
+
+            index.add(chunk_embedding)  # Add the normalized embedding to FAISS
+            chunk_embeddings.append(chunk_embedding)  # Store for reference
+
+        text_data.extend(chunks)
+
+    return text_data
+
-def retrieve_relevant_context(user_input, context_texts, chunk_size=300, similarity_threshold=0.3):
+def retrieve_relevant_context(user_input, context_text, science_objectives="", k=3):
     """
-    Split the context text into smaller chunks, find the most relevant chunk
-    using cosine similarity, and return the most relevant chunk.
-    If no chunk meets the similarity threshold, return a fallback message.
+    Retrieve the most relevant document chunks via cosine-similarity search,
+    using the combined user inputs (Science Goal + Context + optional Science Objectives) as the query.
     """
-    # Check if the context is empty or just whitespace
-    if not context_texts.strip():
-        return "Error: Context is empty or improperly formatted.", None
-
-    # Split the long context text into chunks using the chunking function
-    context_chunks = get_chunks(context_texts, chunk_size)
-
-    # Handle the single-chunk case
-    if len(context_chunks) == 1:
-        return context_chunks[0], 1.0  # Return the single chunk with perfect similarity
-
-    # Encode the user input to create a query embedding
-    user_embedding = encode_text(user_input).reshape(1, -1)
-
-    # Encode all context chunks to create embeddings
-    chunk_embeddings = np.array([encode_text(chunk) for chunk in context_chunks])
-
-    # Compute cosine similarity between the user input and each chunk
-    similarities = cosine_similarity(user_embedding, chunk_embeddings).flatten()
-
-    # Check if any similarity scores are above the threshold
-    if max(similarities) < similarity_threshold:
-        return "No relevant context found for the user input.", None
-
-    # Identify the most relevant chunk based on the highest cosine similarity score
-    most_relevant_idx = np.argmax(similarities)
-    most_relevant_chunk = context_chunks[most_relevant_idx]
-
-    # Return the most relevant chunk and the similarity score
-    return most_relevant_chunk
+    global chunk_embeddings, pdf_chunks  # Use the globally stored embeddings and text chunks
+
+    # Combine all user inputs into a single query
+    query_text = f"Science Goal: {user_input}\nContext: {context_text}"
+
+    # Append Science Objectives only if provided
+    if science_objectives.strip():
+        query_text += f"\nScience Objectives: {science_objectives}"
+
+    # Generate the query embedding
+    query_embedding = encode_text(query_text).reshape(1, -1)
+
+    # Normalize the query embedding for cosine similarity
+    query_embedding = query_embedding / np.linalg.norm(query_embedding)
+
+    # Bail out if no documents have been indexed yet
+    if len(chunk_embeddings) == 0:
+        return "No preloaded document data available."
+
+    # Stack the stored chunk embeddings into a single NumPy array
+    chunk_embeddings_array = np.array(chunk_embeddings).reshape(len(chunk_embeddings), -1)
+
+    # Compute cosine similarity between the query and all stored chunk embeddings
+    similarities = cosine_similarity(query_embedding, chunk_embeddings_array).flatten()
+
+    # Get the indices of the top-k most relevant chunks, sorted in descending order
+    top_indices = similarities.argsort()[-k:][::-1]
+
+    # Retrieve the most relevant chunks
+    retrieved_context = "\n\n".join([pdf_chunks[i] for i in top_indices])
+
+    # If no relevant chunk is found, return a default message
+    if not retrieved_context.strip():
+        return "No relevant context found for the query."
+
+    return retrieved_context
 
 def extract_keywords_with_gpt(user_input, max_tokens=100, temperature=0.3):
     # Define a prompt to ask GPT-4 to extract keywords and important terms
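Worth noting: retrieve_relevant_context ranks chunks with scikit-learn's cosine_similarity over the stored array, so the populated FAISS index is never actually searched. Since the chunk embeddings are L2-normalized before index.add, inner product in the IndexFlatIP already equals cosine similarity; below is a minimal sketch of a direct FAISS top-k query (hypothetical helper name, reusing the module-level globals above):

```python
import numpy as np

def retrieve_with_faiss(query_text, k=3):
    """Hypothetical variant: top-k retrieval via the FAISS index built above."""
    # Embed and L2-normalize the query, matching how the chunks were indexed
    query = encode_text(query_text).reshape(1, -1).astype("float32")
    query /= np.linalg.norm(query)

    # Inner product over L2-normalized vectors is cosine similarity
    scores, ids = index.search(query, k)

    # Map FAISS row ids back to the stored text chunks (-1 means "no result")
    return [pdf_chunks[i] for i in ids[0] if i != -1]
```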
@@ -367,20 +420,21 @@ def gpt_response_to_dataframe(gpt_response):
     return df
 
 def chatbot(user_input, science_objectives="", context="", subdomain="", use_encoder=False, max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0):
-    if use_encoder and context:
-        context_texts = context
-        relevant_context = retrieve_relevant_context(user_input, context_texts)
-    else:
-        relevant_context = ""
+    """
+    Handles the full workflow: retrieves relevant context, generates response, processes output.
+    """
+
+    # Retrieve relevant context from FAISS using all user inputs
+    relevant_context = retrieve_relevant_context(user_input, context, science_objectives)
 
     # Fetch NASA ADS references using the full prompt
     references = fetch_nasa_ads_references(subdomain)
 
-    # Generate response from GPT-4
+    # Generate response from GPT-4, ensuring we pass all relevant inputs
     response = generate_response(
         user_input=user_input,
-        science_objectives=science_objectives,  # Pass Science Objectives
-        relevant_context=relevant_context,  # Pass retrieved context (if any)
+        science_objectives=science_objectives,
+        relevant_context=relevant_context,  # Ensure the retrieved FAISS context is passed
         references=references,
         max_tokens=max_tokens,
         temperature=temperature,
@@ -389,6 +443,7 @@ def chatbot(user_input, science_objectives="", context="", subdomain="", use_enc
         presence_penalty=presence_penalty
     )
 
+    # Prepend manually entered science objectives to the response (if provided)
     if science_objectives.strip():
         response = f"### Science Objectives (User-Defined):\n\n{science_objectives}\n\n" + response
 
@@ -443,13 +498,16 @@ def chatbot(user_input, science_objectives="", context="", subdomain="", use_enc
     return full_response, extracted_table_df, word_doc_path, iframe_html, mapify_button_html
 
 with gr.Blocks() as demo:
-    gr.Markdown("# ExosAI - NASA SMD SCDD Generator with RAG [version-1.1]")
-
+    gr.Markdown("# ExosAI - NASA SMD SCDD Generator with RAG [version-1.2]")
+
     # User Inputs
     user_input = gr.Textbox(lines=5, placeholder="Enter your Science Goal...", label="Science Goal")
     context = gr.Textbox(lines=10, placeholder="Enter Context Text...", label="Context")
     subdomain = gr.Textbox(lines=2, placeholder="Define your Subdomain...", label="Subdomain Definition")
 
+    # PDF Upload Section (Up to 3 PDFs)
+    uploaded_pdfs = gr.File(file_types=[".pdf"], label="Upload Reference PDFs (Up to 3)", interactive=True, file_count="multiple")
+
     # Science Objectives Button & Input (Initially Hidden)
     science_objectives_button = gr.Button("Manually Enter Science Objectives")
     science_objectives_input = gr.Textbox(
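Nothing in this diff appears to invoke load_and_process_uploaded_pdfs, so the FAISS index would stay empty after an upload. A minimal sketch of the missing wiring, assuming Gradio's File.upload event listener (this would sit inside the gr.Blocks context):

```python
# Hypothetical wiring inside `with gr.Blocks() as demo:`:
# index the PDFs as soon as they are uploaded
uploaded_pdfs.upload(
    fn=load_and_process_uploaded_pdfs,
    inputs=[uploaded_pdfs],
    outputs=[],  # indexing mutates module-level globals; nothing to render
)
```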
@@ -459,15 +517,14 @@ with gr.Blocks() as demo:
         visible=False  # Initially hidden
     )
 
-    # Define event inside Blocks (Fix for the Error)
+    # Event to Show Science Objectives Input
     science_objectives_button.click(
         fn=lambda: gr.update(visible=True),  # Show textbox when clicked
         inputs=[],
         outputs=[science_objectives_input]
     )
 
-    # More Inputs
-    use_encoder = gr.Checkbox(label="Use NASA SMD Bi-Encoder for Context")
+    # Additional Model Parameters
     max_tokens = gr.Slider(50, 2000, value=150, step=10, label="Max Tokens")
     temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Temperature")
     top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.1, label="Top-p")
@@ -490,8 +547,8 @@ with gr.Blocks() as demo:
     submit_button.click(
         fn=chatbot,
         inputs=[
-            user_input, science_objectives_input, context, subdomain,
-            use_encoder, max_tokens, temperature, top_p, frequency_penalty, presence_penalty
+            user_input, science_objectives_input, context, subdomain, uploaded_pdfs,
+            max_tokens, temperature, top_p, frequency_penalty, presence_penalty
         ],
         outputs=[full_response, extracted_table_df, word_doc_path, iframe_html, mapify_button_html]
     )
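One mismatch to flag: uploaded_pdfs is now the fifth positional input, but chatbot still declares use_encoder in that slot, so the file list would arrive as use_encoder. A sketch of a matching signature, assuming the PDFs should be (re)indexed on each submission (only the top of the existing function shown):

```python
# Hypothetical signature aligned with the inputs list above
def chatbot(user_input, science_objectives="", context="", subdomain="",
            uploaded_pdfs=None, max_tokens=150, temperature=0.7, top_p=0.9,
            frequency_penalty=0.5, presence_penalty=0.0):
    # (Re)build the FAISS index from the uploaded PDFs before retrieval
    if uploaded_pdfs:
        load_and_process_uploaded_pdfs(uploaded_pdfs)

    relevant_context = retrieve_relevant_context(user_input, context, science_objectives)
    ...  # remainder of the existing function unchanged
```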
@@ -503,7 +560,7 @@ with gr.Blocks() as demo:
             "",  # science_objectives_input
             "",  # context
             "",  # subdomain
-            False,  # use_encoder
+            None,  # uploaded_pdfs
             150,  # max_tokens
             0.7,  # temperature
             0.9,  # top_p
@@ -521,11 +578,12 @@ with gr.Blocks() as demo:
         fn=clear_all,
         inputs=[],
         outputs=[
-            user_input, science_objectives_input, context, subdomain,
-            use_encoder, max_tokens, temperature, top_p, frequency_penalty, presence_penalty,
+            user_input, science_objectives_input, context, subdomain, uploaded_pdfs,
+            max_tokens, temperature, top_p, frequency_penalty, presence_penalty,
             full_response, extracted_table_df, word_doc_path, iframe_html, mapify_button_html
         ]
     )
 
     # Launch the app
     demo.launch(share=True)
+
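For a quick check of the PDF → chunk → FAISS → retrieval round trip outside the UI, a usage sketch (hypothetical file name; assumes the functions above are defined in the current session):

```python
# Hypothetical smoke test for the retrieval pipeline
chunks = load_and_process_uploaded_pdfs(["mission_concept.pdf"])  # hypothetical local file
print(f"Indexed {len(chunks)} chunks")

# Query with a sample science goal and context
top = retrieve_relevant_context(
    user_input="Characterize exoplanet atmospheres in the near-infrared",
    context_text="Transit spectroscopy of temperate sub-Neptunes",
    k=3,
)
print(top[:500])
```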