# EinfachChatProjekt — Streamlit UI for running many OpenAI chat completions
# concurrently over a list of pasted messages.
import base64
import csv
import re
import threading
from io import StringIO
from queue import Queue

import openai
import streamlit as st
# --- Sidebar: global configuration widgets ---------------------------------
st.title("EinfachChatProjekt")

# OpenAI credentials, entered by the user at runtime.
api_key = st.sidebar.text_input("API Key:", value="sk-")
openai.api_key = api_key

# FIX: the default was the string "TRUE" (truthy, but not a bool);
# st.checkbox takes a real boolean default.
show_notes = st.sidebar.checkbox("Show Notes", value=True)

data_section = st.sidebar.text_area("CSV or Text Data:")
paste_data = st.sidebar.button("Paste Data")
num_concurrent_calls = st.sidebar.number_input("Concurrent Calls:", min_value=1, max_value=2000, value=50, step=1)
generate_all = st.sidebar.button("Generate All")
reset = st.sidebar.button("Reset")
add_row = st.sidebar.button("Add row")

# Model / sampling parameters forwarded verbatim to the ChatCompletion call.
model = st.sidebar.selectbox("Model:", ["gpt-4", "gpt-3.5-turbo"])
temperature = st.sidebar.slider("Temperature:", 0.0, 1.0, 0.6, step=0.01)
max_tokens = st.sidebar.number_input("Max Tokens:", min_value=1, max_value=8192, value=2000, step=1)
top_p = st.sidebar.slider("Top P:", 0.0, 1.0, 1.0, step=0.01)
system_message = st.sidebar.text_area("System Message:")

# Number of message rows currently shown; persisted across Streamlit reruns.
row_count = st.session_state.get("row_count", 1)
if add_row:
    row_count += 1
    st.session_state.row_count = row_count
if paste_data:
    # Parse the pasted text line-by-line: delimiter='\n' makes each physical
    # line a single-field CSV row while still honouring quoting.
    data = StringIO(data_section.strip())
    reader = csv.reader(data, delimiter='\n', quotechar='"')
    # FIX: blank lines in the paste yield empty rows; row[0] on an empty row
    # raised IndexError. Skip them.
    messages = [row[0] for row in reader if row]
    if show_notes:
        # Lines alternate note/message, so each row consumes a pair.
        row_count = len(messages) // 2
        for i in range(row_count):
            st.session_state[f"note{i}"] = messages[i * 2]
            st.session_state[f"message{i}"] = messages[i * 2 + 1]
    else:
        row_count = len(messages)
        for i, message in enumerate(messages):
            st.session_state[f"message{i}"] = message
    st.session_state.row_count = row_count
if reset:
    # Collapse the UI back to a single empty row and wipe per-row state.
    row_count = 1
    st.session_state.row_count = row_count
    for idx in range(100):  # assumes at most 100 rows ever existed
        for text_field in ("note", "message", "response"):
            st.session_state[f"{text_field}{idx}"] = ""
        for counter in ("prompt_tokens", "response_tokens", "word_count"):
            st.session_state[f"{counter}{idx}"] = 0
def generate_response(i, message):
    """Run one ChatCompletion call for row *i*.

    Returns a 5-tuple ``(i, response_text, prompt_tokens, response_tokens,
    word_count)``. On any failure the exception text is returned in place of
    the response and all counters are zero, so a single bad call never
    aborts a batch.
    """
    try:
        chat = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": message},
        ]
        completion = openai.ChatCompletion.create(
            model=model,
            messages=chat,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )
        text = completion.choices[0].message.content
        used_prompt = completion.usage['prompt_tokens']
        used_response = completion.usage['total_tokens'] - used_prompt
        words = len(re.findall(r'\w+', text))
        return (i, text, used_prompt, used_response, words)
    except Exception as exc:
        # Best-effort: surface the error text as the "response" so the UI
        # shows what went wrong instead of crashing the run.
        return (i, str(exc), 0, 0, 0)
def worker(q, results):
    """Drain *q* until a ``None`` sentinel, pushing each job's result to *results*.

    NOTE(review): appears unused — superseded by WorkerThread below, and its
    name is shadowed by the loop variable in the "Generate All" handler.
    Kept for backward compatibility.
    """
    while (item := q.get()) is not None:
        results.put(generate_response(*item))
class WorkerThread(threading.Thread):
    """Daemon thread that turns queued ``(index, message)`` jobs into results."""

    def __init__(self, input_queue, output_queue):
        super().__init__()
        self.input_queue = input_queue    # Queue of (i, message) jobs
        self.output_queue = output_queue  # Queue receiving result 5-tuples
        self.daemon = True                # never block interpreter exit

    def run(self):
        # Consume jobs forever; task_done() is signalled even if the API
        # call raises, so Queue.join() in the caller cannot deadlock.
        while True:
            idx, msg = self.input_queue.get()
            try:
                self.output_queue.put(generate_response(idx, msg))
            finally:
                self.input_queue.task_done()
if generate_all:
    # Fan all rows out over a pool of daemon worker threads.
    jobs = Queue()
    results = Queue()
    pool = [WorkerThread(jobs, results) for _ in range(num_concurrent_calls)]
    # FIX: the loop variable was previously named `worker`, shadowing the
    # module-level worker() function.
    for thread in pool:
        thread.start()
    for i in range(row_count):
        jobs.put((i, st.session_state.get(f"message{i}", "")))
    jobs.join()  # block until every queued job has called task_done()
    # All workers are done; copy results into per-row session state.
    while not results.empty():
        i, response, prompt_tokens, response_tokens, word_count = results.get()
        st.session_state[f"response{i}"] = response
        st.session_state[f"prompt_tokens{i}"] = prompt_tokens
        st.session_state[f"response_tokens{i}"] = response_tokens
        st.session_state[f"word_count{i}"] = word_count
def create_download_link(text, filename):
    """Return an HTML anchor that downloads *text* as *filename*.

    The payload is base64-embedded in a ``data:`` URI so no server
    round-trip is needed.
    """
    b64 = base64.b64encode(text.encode()).decode()
    # FIX: the filename parameter was ignored (a placeholder literal was
    # baked into the tag); interpolate it into both the download attribute
    # and the visible label.
    return f'<a href="data:file/txt;base64,{b64}" download="{filename}">Download {filename}</a>'
# --- Per-row UI: note, message input, generate button, response display ----
for i in range(row_count):
    if show_notes:
        st.text_input(f"Note {i + 1}:", key=f"note{i}", value=st.session_state.get(f"note{i}", ""))
    col1, col2 = st.columns(2)
    with col1:
        message = st.text_area(f"Message {i + 1}:", key=f"message{i}", value=st.session_state.get(f"message{i}", ""))
        # Only generate when the row has no response yet.
        if st.button(f"Generate Response {i + 1}") and not st.session_state.get(f"response{i}", ""):
            # FIX: generate_response returns a 5-tuple whose first element is
            # the row index; the previous 4-name unpack raised ValueError.
            _, response, prompt_tokens, response_tokens, word_count = generate_response(i, message)
            st.session_state[f"response{i}"] = response
            st.session_state[f"prompt_tokens{i}"] = prompt_tokens
            st.session_state[f"response_tokens{i}"] = response_tokens
            st.session_state[f"word_count{i}"] = word_count
    with col2:
        st.text_area(f"Response {i + 1}:", value=st.session_state.get(f"response{i}", ""))
        st.write(f"Tokens: {st.session_state.get(f'prompt_tokens{i}', 0)} / {st.session_state.get(f'response_tokens{i}', 0)} + Words: {st.session_state.get(f'word_count{i}', 0)}")
# Assemble every row's response (prefixed with its note when notes are on)
# into a single downloadable text file.
if show_notes:
    sections = [
        f"{st.session_state.get(f'note{i}', '')}\n{st.session_state.get(f'response{i}', '')}"
        for i in range(row_count)
    ]
else:
    sections = [st.session_state.get(f"response{i}", "") for i in range(row_count)]
responses_text = "\n\n".join(sections)
download_filename = "GPT-4 Responses.txt"
download_link = create_download_link(responses_text, download_filename)
st.markdown(download_link, unsafe_allow_html=True)