File size: 3,761 Bytes
d349925
5cdeca9
bee10b7
5cdeca9
1d831e0
 
d349925
927f7a3
 
11d423c
 
 
1d831e0
 
39d59b3
d349925
be1c6e7
 
e2c3d90
d349925
11d423c
 
a40023d
d349925
1d831e0
a40023d
1d831e0
a40023d
1d831e0
 
 
 
701b0ff
e2c3d90
1af8607
e2c3d90
1af8607
76e9582
39d59b3
1d831e0
 
e2c3d90
1d831e0
 
0f87001
 
 
 
 
4295e2c
bc9946f
e2c3d90
0f87001
e2c3d90
 
 
 
 
0f87001
11d423c
d349925
 
 
 
1d831e0
be1c6e7
 
684b0cb
701b0ff
0f87001
 
 
 
1d831e0
be1c6e7
1af8607
 
d349925
 
 
a40023d
 
b3a1556
76e9582
11d423c
4295e2c
927f7a3
5cdeca9
4815dab
1d831e0
7acb493
e2c3d90
0f87001
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
from transformers import AutoTokenizer, AutoModel
from openai import OpenAI
import os
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load the NASA-specific bi-encoder model and tokenizer.
# These run at import time; the resulting module-level objects are the ones
# encode_text() uses to build embeddings.
bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
bi_tokenizer = AutoTokenizer.from_pretrained(bi_encoder_model_name)
bi_model = AutoModel.from_pretrained(bi_encoder_model_name)

# Set up OpenAI client.
# The key is read from the OPENAI_API_KEY environment variable; os.getenv
# returns None when the variable is not set.
api_key = os.getenv('OPENAI_API_KEY')
client = OpenAI(api_key=api_key)

# Define a system message to introduce ExosAI.
# generate_response() sends this as the "system" message of every request.
system_message = "You are ExosAI, a helpful assistant specializing in Astrophysics and Exoplanet research. Provide detailed and accurate responses related to Astrophysics and Exoplanet research."

def encode_text(text):
    """Embed ``text`` with the module-level bi-encoder.

    The text is tokenized (truncated to at most 128 tokens), passed through
    ``bi_model``, and the last hidden state is averaged over the token axis
    to give one vector.

    Args:
        text: The string to embed. NOTE(review): the result is flattened, so
            a batch of strings would come back as one concatenated vector;
            callers in this file always pass a single string.

    Returns:
        A 1-D NumPy array containing the mean-pooled embedding.
    """
    # Local import: torch is not imported at the top of this file.
    import torch

    inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
    # Inference only: skip autograd bookkeeping rather than building a graph
    # and detaching from it afterwards.
    with torch.no_grad():
        outputs = bi_model(**inputs)
    # mean(dim=1) pools over tokens; flatten() drops the batch axis -> 1-D.
    return outputs.last_hidden_state.mean(dim=1).numpy().flatten()

def retrieve_relevant_context(user_input, context_texts):
    """Return the entry of ``context_texts`` most similar to ``user_input``.

    Each non-blank candidate is embedded with ``encode_text`` and compared to
    the embedding of ``user_input`` by cosine similarity; the candidate with
    the highest score wins.

    Args:
        user_input: The user's question.
        context_texts: A list of candidate context strings.

    Returns:
        The best-matching candidate string, or ``""`` when there is no
        non-blank candidate to choose from.
    """
    # Blank or whitespace-only entries carry no information but would still
    # get an embedding and could win the argmax, so drop them up front.
    candidates = [text for text in context_texts if text.strip()]
    if not candidates:
        # Nothing to rank; an empty embedding matrix would otherwise raise.
        return ""

    user_embedding = encode_text(user_input).reshape(1, -1)
    # encode_text returns 1-D vectors, so stacking yields (n_candidates, dim).
    context_embeddings = np.stack([encode_text(text) for text in candidates])
    similarities = cosine_similarity(user_embedding, context_embeddings).ravel()
    return candidates[int(np.argmax(similarities))]

def generate_response(user_input, relevant_context="", max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0, model="gpt-4"):
    """Ask the chat model to answer ``user_input`` and return its reply text.

    Args:
        user_input: The user's question.
        relevant_context: Optional context placed before the question in the
            prompt; omitted from the prompt when empty.
        max_tokens: Forwarded to the chat completions call.
        temperature: Forwarded to the chat completions call.
        top_p: Forwarded to the chat completions call.
        frequency_penalty: Forwarded to the chat completions call.
        presence_penalty: Forwarded to the chat completions call.
        model: Name of the chat model to query. Defaults to ``"gpt-4"``,
            the value that was previously hard-coded.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or ``""`` when the API returns no content.
    """
    if relevant_context:
        combined_input = f"Context: {relevant_context}\nQuestion: {user_input}\nAnswer:"
    else:
        combined_input = f"Question: {user_input}\nAnswer:"

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": combined_input},
        ],
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
    )
    # The API declares message.content as nullable; guard so a content-less
    # reply yields "" instead of raising AttributeError on None.strip().
    content = response.choices[0].message.content
    return (content or "").strip()

def chatbot(user_input, context="", use_encoder=False, max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0):
    """Answer ``user_input``, optionally grounding it in the supplied context.

    When ``use_encoder`` is set and ``context`` contains at least one
    non-blank line, each line is treated as a candidate and the one chosen by
    ``retrieve_relevant_context`` is passed to the model as context.
    Otherwise the question is sent without context.

    Args:
        user_input: The user's question.
        context: Free-form text, one candidate context per line.
        use_encoder: Whether to select a context line via the bi-encoder.
        max_tokens: Forwarded to ``generate_response``.
        temperature: Forwarded to ``generate_response``.
        top_p: Forwarded to ``generate_response``.
        frequency_penalty: Forwarded to ``generate_response``.
        presence_penalty: Forwarded to ``generate_response``.

    Returns:
        The model's reply as returned by ``generate_response``.
    """
    relevant_context = ""
    if use_encoder and context:
        # splitlines() also handles "\r\n"; blank lines are dropped so that a
        # whitespace-only line can never be selected as the "relevant" context.
        context_texts = [line for line in context.splitlines() if line.strip()]
        if context_texts:
            relevant_context = retrieve_relevant_context(user_input, context_texts)
    return generate_response(user_input, relevant_context, max_tokens, temperature, top_p, frequency_penalty, presence_penalty)

# Build the Gradio UI: one input component per chatbot() parameter, in the
# same order as the function signature.
iface = gr.Interface(
    title="ExosAI - NASA SMD SCDD Generator",
    description="ExosAI is a helpful AI assistant for the automated generation of Science Case Development Documents",
    fn=chatbot,
    inputs=[
        gr.Textbox(label="Prompt ExosAI", placeholder="Enter your Science Question here...", lines=2),
        gr.Textbox(label="Context", placeholder="Enter some context here...", lines=5),
        gr.Checkbox(label="Use NASA SMD Bi-Encoder for Context"),
        gr.Slider(minimum=50, maximum=500, step=10, value=150, label="Max Tokens"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.9, label="Top-p"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.5, label="Frequency Penalty"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="Presence Penalty"),
    ],
    outputs=gr.Textbox(label="ExosAI response..."),
)

# Start the app with sharing enabled
iface.launch(share=True)