|
import hmac
import os

import pandas as pd
import streamlit as st

from util.assistants import GPTAgent
|
|
|
|
|
|
|
def _build_prompt(question, template, custom_template=None):
    """Return the LLM prompt for a single question under the chosen template.

    ``question`` is coerced to ``str`` first, so non-string CSV values (numbers,
    NaN from blank cells) cannot crash the ``str.replace`` in the custom path.
    """
    question = str(question)

    if template == "Chain of Thought":
        return (
            "Generate an explanation using the Chain of Thought template for the following question:\n"
            "\n"
            f"Question: {question}\n"
            "\n"
            "Let's think step by step.\n"
            "\n"
            "Explanation:\n"
        )

    if template == "Custom" and custom_template:
        # Plain substitution rather than str.format, so any other literal braces
        # a user writes in the template are left untouched.
        return custom_template.replace("{question}", question)

    # "Default", any unknown template name, or "Custom" with an empty template.
    return (
        "Generate an explanation for the following question:\n"
        "\n"
        f"Question: {question}\n"
        "\n"
        "Explanation:\n"
    )


def generate_explanations(model_name, questions, template, temperature, max_tokens, custom_template=None):
    """Generate one explanation per question with a ``GPTAgent``.

    Args:
        model_name: Model identifier passed to ``GPTAgent``.
        questions: Sequence of questions; one prompt is sent per item.
        template: ``"Chain of Thought"``, ``"Custom"``, or anything else for the
            default prompt.
        temperature: Sampling temperature forwarded to ``agent.invoke``.
        max_tokens: Token limit forwarded to ``agent.invoke``.
        custom_template: Template text used when ``template == "Custom"``; every
            ``{question}`` occurrence is replaced with the question text. An
            empty/``None`` value falls back to the default prompt.

    Returns:
        List of whitespace-stripped responses, in the same order as ``questions``.

    A Streamlit progress bar is rendered and advanced after each question.
    """
    agent = GPTAgent(model_name)
    progress_bar = st.progress(0)
    total_questions = len(questions)

    explanations = []
    for done, question in enumerate(questions, start=1):
        prompt = _build_prompt(question, template, custom_template)
        # Assumes invoke returns a str (the original code also calls .strip() on it).
        response = agent.invoke(prompt, temperature=temperature, max_tokens=max_tokens)
        explanations.append(response.strip())
        progress_bar.progress(done / total_questions)

    return explanations
|
|
|
|
|
|
|
# Two sample explanations for the same question, labelled 'good' and 'bad'.
# NOTE(review): not referenced by the code visible in this file; presumably kept
# as reference material for explanation quality. Confirm before removing.
examples = dict(
    good=dict(
        question="What causes rainbows to appear in the sky?",
        explanation="Rainbows appear when sunlight is refracted, dispersed, and reflected inside water droplets in the atmosphere, resulting in a spectrum of light appearing in the sky.",
    ),
    bad=dict(
        question="What causes rainbows to appear in the sky?",
        explanation="Rainbows happen because light in the sky gets mixed up and sometimes shows colors when it's raining or when there is water around.",
    ),
)
|
|
|
|
|
|
|
def check_password():
    """Render a password form that unlocks the demo for the current session.

    On a successful submit, ``st.session_state['password_correct']`` is set to
    ``True``. The expected password is read from the ``PASSWORD`` environment
    variable; if it is unset, no input is accepted.
    """

    def password_entered():
        # Runs as the Submit button's on_click callback, before the script
        # re-executes. Read the value from session state via the widget key:
        # a closure over text_input's return value holds the value from the
        # previous script run, which can be stale.
        entered = st.session_state.get('password', '') or ''
        expected = os.getenv('PASSWORD')
        # Constant-time comparison (on bytes, so non-ASCII passwords work).
        if expected is not None and hmac.compare_digest(entered.encode('utf-8'), expected.encode('utf-8')):
            st.session_state['password_correct'] = True
            # Do not keep the plaintext password around once it has been verified.
            del st.session_state['password']
        else:
            st.error("Incorrect Password, please try again.")

    st.text_input("Enter Password:", type="password", key='password')
    submit_button = st.button("Submit", on_click=password_entered)

    if submit_button and not st.session_state.get('password_correct', False):
        st.error("Please enter a valid password to access the demo.")
|
|
|
|
|
|
|
# Markdown for the sidebar introduction shown on every page load.
_SIDEBAR_INTRO = """
### Welcome to the Natural Language Explanation Generation Demo
This application allows you to generate high-quality explanations for various questions using different templates. Upload a CSV of questions, select an explanation template, and generate explanations.
"""

st.title('Explanation Generation')
st.sidebar.write(_SIDEBAR_INTRO)
|
|
|
|
|
# Gate the whole demo behind the password check; the rest only renders once
# check_password() has set st.session_state['password_correct'].
if not st.session_state.get('password_correct', False):
    check_password()
else:
    st.sidebar.success("Password Verified. Proceed with the demo.")

    st.write("""
    ### Instructions for Uploading CSV
    Please upload a CSV file with the following column:
    - `question`: The question you want explanations for.

    **Example CSV Format:**
    """)

    # Small sample table showing the expected CSV layout.
    example_data_gen = {
        "question": [
            "What causes rainbows to appear in the sky?",
            "Why is the sky blue?",
        ]
    }
    example_df_gen = pd.DataFrame(example_data_gen)
    st.dataframe(example_df_gen)

    uploaded_file_gen = st.file_uploader("Upload CSV file with 'question' column", type=['csv'])

    if uploaded_file_gen is not None:
        template = st.selectbox("Select an explanation template", ["Default", "Chain of Thought", "Custom"])
        model_name = st.selectbox('Select a model:', ['gpt4-1106', 'gpt35-1106'])

        # Generation settings live in the sidebar.
        temperature = st.sidebar.slider('Temperature', min_value=0.0, max_value=1.0, value=0.8)
        max_tokens = st.sidebar.slider('Max Tokens', min_value=50, max_value=500, value=150)

        custom_template = ""
        if template == "Custom":
            custom_template = st.text_area(
                "Enter your custom template",
                value="Generate an explanation for the following question:\n\nQuestion: {question}\n\nExplanation:",
            )
            # Without the placeholder, every row would be sent the identical prompt.
            if custom_template and "{question}" not in custom_template:
                st.warning("Your custom template has no {question} placeholder, so every question will get the same prompt.")

        if st.button('Generate Explanations'):
            try:
                questions_df = pd.read_csv(uploaded_file_gen)
            except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
                st.error(f"Could not read the uploaded CSV: {err}")
                st.stop()

            if 'question' not in questions_df.columns:
                st.error("The uploaded CSV must contain a 'question' column.")
                st.stop()

            # Drop blank cells (read as NaN) and whitespace-only entries; coerce the
            # rest to str so every prompt template receives text.
            questions = [
                q for q in questions_df['question'].dropna().astype(str).tolist() if q.strip()
            ]
            if not questions:
                st.error("The 'question' column contains no questions.")
                st.stop()

            explanations = generate_explanations(model_name, questions, template, temperature, max_tokens, custom_template)

            result_df_gen = pd.DataFrame({
                'question': questions,
                'explanation': explanations,
            })

            st.write('### Generated Explanations')
            st.dataframe(result_df_gen)

            csv_gen = result_df_gen.to_csv(index=False)
            st.download_button(
                label="Download generated explanations as CSV",
                data=csv_gen,
                file_name='generated_explanations.csv',
                mime='text/csv',
            )
|
|