# Check / app.py — author: Alfasign, revision a499ce9 (5.85 kB).
import base64
import csv
import html
import re
import threading
from concurrent.futures import ThreadPoolExecutor
from io import StringIO
from queue import Queue

import openai
import streamlit as st
st.title("EinfachChatProjekt")

# --- Sidebar: credentials, pasted data, action buttons, model settings ------
# type="password" masks the key on screen instead of displaying it in clear text.
api_key = st.sidebar.text_input("API Key:", value="sk-", type="password")
openai.api_key = api_key
# value expects a bool; the former string "TRUE" only worked through truthiness
# (any non-empty string, even "FALSE", would have been truthy as well).
show_notes = st.sidebar.checkbox("Show Notes", value=True)
data_section = st.sidebar.text_area("CSV or Text Data:")
paste_data = st.sidebar.button("Paste Data")
# Upper bound on parallel API requests used by "Generate All".
num_concurrent_calls = st.sidebar.number_input("Concurrent Calls:", min_value=1, max_value=2000, value=50, step=1)
generate_all = st.sidebar.button("Generate All")
reset = st.sidebar.button("Reset")
add_row = st.sidebar.button("Add row")
# Request parameters forwarded to every chat completion call.
model = st.sidebar.selectbox("Model:", ["gpt-4", "gpt-3.5-turbo"])
temperature = st.sidebar.slider("Temperature:", 0.0, 1.0, 0.6, step=0.01)
max_tokens = st.sidebar.number_input("Max Tokens:", min_value=1, max_value=8192, value=2000, step=1)
top_p = st.sidebar.slider("Top P:", 0.0, 1.0, 1.0, step=0.01)
system_message = st.sidebar.text_area("System Message:")
# Number of message rows on the page; persisted in session state across reruns
# and starting at a single row.
row_count = st.session_state.get("row_count", 1)
if add_row:
    # Grow by one row and persist the new count in the same statement.
    st.session_state.row_count = row_count = row_count + 1
if paste_data:
    # One entry per non-blank line; a double-quoted entry may span several lines.
    # csv is used only for that quote handling. The reader already ends a record
    # at every line break, so the former "\n" delimiter never split anything
    # (and newer CPython versions reject a line-break delimiter). The ASCII unit
    # separator is used instead and re-joined below, so an unquoted line that
    # happens to contain it is kept intact.
    reader = csv.reader(StringIO(data_section.strip()), delimiter="\x1f", quotechar='"')
    # Blank lines come back as empty rows; skip them instead of failing on row[0].
    messages = ["\x1f".join(row) for row in reader if row]
    if show_notes:
        # Entries alternate note, message; an unpaired trailing note is dropped.
        row_count = len(messages) // 2
        for i in range(row_count):
            st.session_state[f"note{i}"] = messages[2 * i]
            st.session_state[f"message{i}"] = messages[2 * i + 1]
    else:
        row_count = len(messages)
        for i, message in enumerate(messages):
            st.session_state[f"message{i}"] = message
    st.session_state.row_count = row_count
if reset:
    # Collapse the page back to a single row.
    row_count = 1
    st.session_state.row_count = row_count
    # Per-row session-state key prefixes and the value each is reset to
    # (written in this order for every row).
    row_defaults = {
        "note": "",
        "message": "",
        "response": "",
        "prompt_tokens": 0,
        "response_tokens": 0,
        "word_count": 0,
    }
    # Rows 0-99 are always cleared, as before. Paste Data has no row limit, so
    # also clear any higher row index that currently has per-row state.
    row_key = re.compile(r"(note|message|response|prompt_tokens|response_tokens|word_count)(\d+)")
    rows_to_clear = set(range(100))
    for key in list(st.session_state.keys()):
        found = row_key.fullmatch(key) if isinstance(key, str) else None
        if found:
            rows_to_clear.add(int(found.group(2)))
    for i in sorted(rows_to_clear):
        for prefix, default in row_defaults.items():
            st.session_state[f"{prefix}{i}"] = default
def generate_response(i, message):
    """Send one chat-completion request for row *i* and summarise the reply.

    The request uses the module-level sidebar settings ``model``,
    ``system_message``, ``temperature``, ``max_tokens`` and ``top_p``.

    Returns a 5-tuple ``(i, response_text, prompt_tokens, response_tokens,
    word_count)``. Any exception is caught and reported in-band: the tuple then
    carries ``str(exc)`` as the response text and zeros for all counters.
    """
    try:
        conversation = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": message},
        ]
        completion = openai.ChatCompletion.create(
            model=model,
            messages=conversation,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )
        text = completion.choices[0].message.content
        usage = completion.usage
        used_by_prompt = usage['prompt_tokens']
        # Reply tokens are derived from the total rather than read directly.
        used_by_reply = usage['total_tokens'] - used_by_prompt
        words = re.findall(r'\w+', text)
        return (i, text, used_by_prompt, used_by_reply, len(words))
    except Exception as exc:
        return (i, str(exc), 0, 0, 0)
def worker(q, results):
    """Process ``(i, message)`` jobs from *q* until a ``None`` sentinel arrives.

    Each job's ``generate_response`` tuple is put onto *results*.
    NOTE(review): not referenced anywhere in this file.
    """
    while True:
        job = q.get()
        if job == None:  # same equality test iter(callable, sentinel) performs
            break
        results.put(generate_response(*job))
class WorkerThread(threading.Thread):
    """Daemon thread that answers ``(index, message)`` jobs from a queue.

    Each job's ``generate_response`` result tuple is put onto *output_queue*.
    Putting ``None`` on *input_queue* makes the thread acknowledge it and exit.
    Without a sentinel, the thread blocks on ``get()`` indefinitely.
    ``task_done()`` is called for every item taken, including malformed jobs,
    so ``input_queue.join()`` cannot hang on them.
    """

    def __init__(self, input_queue, output_queue):
        super().__init__(daemon=True)
        self.input_queue = input_queue
        self.output_queue = output_queue

    def run(self):
        while True:
            job = self.input_queue.get()
            try:
                if job is None:
                    # Shutdown sentinel: the finally clause still marks it done.
                    return
                # Unpacked inside try so a malformed job still reaches task_done().
                i, message = job
                self.output_queue.put(generate_response(i, message))
            finally:
                self.input_queue.task_done()
if generate_all:
    # Fan the visible rows out to a bounded thread pool. Leaving the `with`
    # block shuts the pool down, so no threads outlive this rerun. The former
    # WorkerThread loop never exited, leaving up to num_concurrent_calls blocked
    # daemon threads behind on every click.
    jobs = [(i, st.session_state.get(f"message{i}", "")) for i in range(row_count)]
    # Never start more threads than there are rows; the pool needs at least one.
    pool_size = max(1, min(int(num_concurrent_calls), len(jobs)))
    with ThreadPoolExecutor(max_workers=pool_size) as pool:
        # generate_response converts exceptions into an error-text result, so
        # map() yields exactly one result tuple per row.
        results = list(pool.map(lambda job: generate_response(*job), jobs))
    for i, response, prompt_tokens, response_tokens, word_count in results:
        st.session_state[f"response{i}"] = response
        st.session_state[f"prompt_tokens{i}"] = prompt_tokens
        st.session_state[f"response_tokens{i}"] = response_tokens
        st.session_state[f"word_count{i}"] = word_count
def create_download_link(text, filename):
    """Return an HTML ``<a>`` tag that downloads *text* as *filename*.

    The text is UTF-8 encoded and embedded as a base64 ``data:`` URI, so the
    link works without a server round-trip. The filename is HTML-escaped
    because the result is rendered as raw HTML.
    """
    b64 = base64.b64encode(text.encode("utf-8")).decode("ascii")
    # "file/txt" is not a real MIME type; text/plain with an explicit charset
    # keeps non-ASCII responses readable when the browser opens the file.
    safe_name = html.escape(filename, quote=True)
    return (
        f'<a href="data:text/plain;charset=utf-8;base64,{b64}" '
        f'download="{safe_name}">Download {safe_name}</a>'
    )
# Render one note/message/response row per visible row index.
for i in range(row_count):
    if show_notes:
        # key= alone binds the widget to st.session_state, which defaults to "".
        # Also passing value= makes Streamlit warn once the key was pre-filled
        # through the Session State API (Paste Data / Reset do exactly that).
        st.text_input(f"Note {i + 1}:", key=f"note{i}")
    col1, col2 = st.columns(2)
    with col1:
        message = st.text_area(f"Message {i + 1}:", key=f"message{i}")
        # Only generate when this row has no stored response yet.
        if st.button(f"Generate Response {i + 1}") and not st.session_state.get(f"response{i}", ""):
            # generate_response returns a 5-tuple led by the row index;
            # unpacking it into four names raised ValueError.
            _, response, prompt_tokens, response_tokens, word_count = generate_response(i, message)
            st.session_state[f"response{i}"] = response
            st.session_state[f"prompt_tokens{i}"] = prompt_tokens
            st.session_state[f"response_tokens{i}"] = response_tokens
            st.session_state[f"word_count{i}"] = word_count
    with col2:
        # Unkeyed: this area only displays the stored response text.
        st.text_area(f"Response {i + 1}:", value=st.session_state.get(f"response{i}", ""))
        st.write(f"Tokens: {st.session_state.get(f'prompt_tokens{i}', 0)} / {st.session_state.get(f'response_tokens{i}', 0)} + Words: {st.session_state.get(f'word_count{i}', 0)}")
# Build the download file: with notes shown, each section is "note\nresponse";
# otherwise it is just the response. Sections are separated by a blank line.
if show_notes:
    sections = [
        f"{st.session_state.get(f'note{i}', '')}\n{st.session_state.get(f'response{i}', '')}"
        for i in range(row_count)
    ]
else:
    sections = [st.session_state.get(f"response{i}", "") for i in range(row_count)]
responses_text = "\n\n".join(sections)
download_filename = "GPT-4 Responses.txt"
download_link = create_download_link(responses_text, download_filename)
st.markdown(download_link, unsafe_allow_html=True)