import gradio as gr
from transformers import AutoTokenizer, AutoModel
from openai import OpenAI
import os
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load the NASA-specific bi-encoder model and tokenizer
bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
bi_tokenizer = AutoTokenizer.from_pretrained(bi_encoder_model_name)
bi_model = AutoModel.from_pretrained(bi_encoder_model_name)

# Set up OpenAI client
api_key = os.getenv('OPENAI_API_KEY')
client = OpenAI(api_key=api_key)

# Define a system message to introduce Exos
system_message = "You are Exos, a helpful assistant specializing in exoplanet research. Provide detailed and accurate responses related to exoplanet research."

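# Embed a text string with the bi-encoder by mean-pooling its token embeddings.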
def encode_text(text):
    inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
    outputs = bi_model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten()  # Mean-pool token embeddings into a single 1-D vector

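# Return the context line whose embedding is most cosine-similar to the user's question.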
def retrieve_relevant_context(user_input, context_texts):
    user_embedding = encode_text(user_input).reshape(1, -1)
    context_embeddings = np.array([encode_text(text) for text in context_texts])
    context_embeddings = context_embeddings.reshape(len(context_embeddings), -1)  # Shape (n_contexts, embedding_dim) for cosine_similarity
    similarities = cosine_similarity(user_embedding, context_embeddings).flatten()
    most_relevant_idx = np.argmax(similarities)
    return context_texts[most_relevant_idx]

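# Build a context-augmented prompt and query the OpenAI chat model for an answer.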
def generate_response(user_input, relevant_context="", max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0):
    if relevant_context:
        combined_input = f"Context: {relevant_context}\nQuestion: {user_input}\nAnswer:"
    else:
        combined_input = f"Question: {user_input}\nAnswer:"
    
    response = client.chat.completions.create(
        model="gpt-4-turbo",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": combined_input}
        ],
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty
    )
    return response.choices[0].message.content.strip()

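# Gradio callback: optionally retrieve relevant context with the bi-encoder, then generate a response.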
def chatbot(user_input, context="", use_encoder=False, max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0):
    if use_encoder and context:
        context_texts = context.split("\n")
        relevant_context = retrieve_relevant_context(user_input, context_texts)
    else:
        relevant_context = ""
    response = generate_response(user_input, relevant_context, max_tokens, temperature, top_p, frequency_penalty, presence_penalty)
    return response

# Create the Gradio interface
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your message here...", label="Your Question"),
        gr.Textbox(lines=5, placeholder="Enter context here, separated by new lines...", label="Context"),
        gr.Checkbox(label="Use NASA SMD Bi-Encoder for Context"),
        gr.Slider(50, 500, value=150, step=10, label="Max Tokens"),
        gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.0, 1.0, value=0.9, step=0.1, label="Top-p"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Frequency Penalty"),
        gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Presence Penalty")
    ],
    outputs=gr.Textbox(label="Exos says..."),
    title="Exos - Your Exoplanet Research Assistant",
    description="Exos is a helpful assistant specializing in exoplanet research. Provide context to get more refined and relevant responses.",
)

# Launch the interface
iface.launch(share=True)