import os

import openai
import streamlit as st


def get_completion(client, model_id, messages, args):
    """Request a chat completion, forwarding only the sampling arguments that are set."""
    completion_args = {
        "model": model_id,
        "messages": messages,
        "frequency_penalty": args.frequency_penalty,
        "max_tokens": args.max_tokens,
        "n": args.n,
        "presence_penalty": args.presence_penalty,
        "seed": args.seed,
        "stop": args.stop,
        "stream": args.stream,
        "temperature": args.temperature,
        "top_p": args.top_p,
    }

    # Drop arguments left as None so the server falls back to its defaults.
    completion_args = {k: v for k, v in completion_args.items() if v is not None}

    try:
        return client.chat.completions.create(**completion_args)
    except Exception as e:
        print(f"Error during API call: {e}")
        return None
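

# Page metadata plus the sidebar: endpoint setup, system prompt, and sampling controls.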
st.set_page_config(page_title="Turing Test")

with st.sidebar:
    st.title('🦙💬 Welcome to Turing Test')

    # Placeholder token for the Modal-hosted vLLM endpoint, which exposes an
    # OpenAI-compatible API; replace it if your server enforces authentication.
    openai_api_key = "super-secret-token"
    os.environ['OPENAI_API_KEY'] = openai_api_key

    base_url = "https://turingtest--example-vllm-openai-compatible-serve.modal.run/v1"
    client = openai.OpenAI(api_key=openai_api_key, base_url=base_url)
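
    # Sidebar controls: system prompt, model choice, and sampling parameters.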
    st.subheader('System Prompt')
    system_prompt = st.text_area(
        "Enter a system prompt:",
        "You are roleplaying as an old grandma.",
        help="This message sets the behavior of the AI.",
    )

    st.subheader('Models and parameters')
    selected_model = st.sidebar.selectbox('Choose a model', ['meta-llama/Meta-Llama-3.1-8B-Instruct'], key='selected_model')
    temperature = st.sidebar.slider('temperature', min_value=0.01, max_value=5.0, value=0.8, step=0.1)
    top_p = st.sidebar.slider('top_p', min_value=0.01, max_value=1.0, value=0.95, step=0.01)
    max_length = st.sidebar.slider('max_length', min_value=32, max_value=1024, value=32, step=8)
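

# Seed the chat history in session state; index 0 holds the hidden system message.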
if "messages" not in st.session_state.keys(): |
|
st.session_state.messages = [ |
|
{"role": "system", "content": system_prompt}, |
|
{"role": "assistant", "content": "Hello!"} |
|
] |
|
|
|
|
|
for message in st.session_state.messages[1:]: |
|
with st.chat_message(message["role"]): |
|
st.write(message["content"]) |
|
|
|
def clear_chat_history():
    """Reset the conversation, keeping the current system prompt."""
    st.session_state.messages = [
        {"role": "system", "content": system_prompt},
        {"role": "assistant", "content": "Hello!"},
    ]


st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
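

# Build the assistant's reply from the full conversation history.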
def generate_llama2_response(prompt_input, model, temperature, top_p, max_length):
    # Bundle the sampling settings into the simple namespace that get_completion expects.
    class Args:
        def __init__(self):
            self.frequency_penalty = 0
            self.max_tokens = max_length
            self.n = 1
            self.presence_penalty = 0
            self.seed = 42
            self.stop = None
            self.stream = False
            self.temperature = temperature
            self.top_p = top_p

    args = Args()

    # Keep the system message in sync with the sidebar text area before each call.
    st.session_state.messages[0] = {"role": "system", "content": system_prompt}

    response = get_completion(client, model, st.session_state.messages, args)

    if response:
        return response.choices[0].message.content
    else:
        return "Sorry, there was an error generating a response."
if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = generate_llama2_response(prompt, selected_model, temperature, top_p, max_length)
            placeholder = st.empty()
            full_response = ''
            # response is a plain string (stream=False), so this loop writes it out
            # character by character to mimic a streaming display.
            for item in response:
                full_response += item
                placeholder.markdown(full_response)
            message = {"role": "assistant", "content": full_response}
            st.session_state.messages.append(message)