|
import hmac
import json
import os
import tempfile
import time
from datetime import datetime

import streamlit as st
from openai import OpenAI
|
|
|
|
|
# Page-level settings; Streamlit requires this before any other st.* output.
st.set_page_config(layout="wide", page_title="Schlager ContractAi")
|
|
|
|
|
# Sign-in allow-list mapping email -> password, checked by login().
# NOTE(review): these are plaintext credentials committed to source. Prefer
# setting AUTHORIZED_USERS_JSON in the environment to a JSON object such as
# {"user@example.com": "secret"}. When it is set, it replaces this built-in list.
_DEFAULT_AUTHORIZED_USERS = {
    "[email protected]": "Pass.123",
    "[email protected]": "Pass.123",
    "[email protected]": "Pass.123",
    "[email protected]": "Pass.123",
}

_authorized_users_override = os.getenv("AUTHORIZED_USERS_JSON")
if _authorized_users_override:
    AUTHORIZED_USERS = json.loads(_authorized_users_override)
    # Fail fast at startup rather than letting login() break on odd types.
    if not isinstance(AUTHORIZED_USERS, dict) or not all(
        isinstance(email, str) and isinstance(secret, str)
        for email, secret in AUTHORIZED_USERS.items()
    ):
        raise ValueError(
            "AUTHORIZED_USERS_JSON must be a JSON object mapping email strings "
            "to password strings."
        )
else:
    AUTHORIZED_USERS = _DEFAULT_AUTHORIZED_USERS
|
|
|
# Every new browser session starts signed out.
if "authenticated" not in st.session_state:
    st.session_state["authenticated"] = False


def login():
    """Check the sign-in form and mark the session as authenticated.

    Used as the ``on_click`` callback of the Login button. Reads the
    ``email`` and ``password`` widget values from ``st.session_state``.

    If they match an entry in ``AUTHORIZED_USERS``, it sets
    ``st.session_state["authenticated"] = True``. Otherwise it shows an
    error message.
    """
    # Surrounding whitespace in an email is almost always a paste artifact.
    email = st.session_state.get("email", "").strip()
    password = st.session_state.get("password", "")
    expected = AUTHORIZED_USERS.get(email)
    # compare_digest takes the same time wherever the inputs differ, so
    # response timing does not reveal how much of a guessed password matched.
    # Compare bytes, since compare_digest rejects non-ASCII str arguments.
    if expected is not None and hmac.compare_digest(
        expected.encode("utf-8"), password.encode("utf-8")
    ):
        st.session_state["authenticated"] = True
    else:
        st.error("Invalid email or password. Please try again.")
|
|
|
# Anonymous visitors only ever see the sign-in form. st.stop() ends this
# script run before anything below is rendered.
_is_signed_in = st.session_state["authenticated"]
if not _is_signed_in:
    st.title("Sign In")
    st.text_input(label="Email", key="email")
    st.text_input(label="Password", key="password", type="password")
    st.button(label="Login", on_click=login)
    st.stop()
|
|
|
|
|
# Main app header (only reached by signed-in users).
st.title("Schlager ContractAi")
st.caption("Chat with your contract or manage meeting minutes")

# The OpenAI key comes from the environment. Without it neither chat tab
# can work, so halt this run with an explanation.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if OPENAI_API_KEY is None or OPENAI_API_KEY == "":
    st.error("Missing OpenAI API key. Please set it as an environment variable.")
    st.stop()

# One tab per assistant.
tab1, tab2 = st.tabs(("Contract", "Technical"))
|
|
|
# JSON file that stores responses users flag for review. The default lives
# under /mnt/data. Set the FLAGGED_RESPONSES_FILE environment variable to
# store it elsewhere, e.g. on hosts where /mnt/data is not writable.
FLAGGED_RESPONSES_FILE = (
    os.getenv("FLAGGED_RESPONSES_FILE") or "/mnt/data/flagged_responses.json"
)
|
|
|
|
|
def ensure_flagged_file(path=None):
    """Create the flagged-responses file as an empty JSON array if missing.

    Args:
        path: File to create. Defaults to ``FLAGGED_RESPONSES_FILE``.

    Missing parent directories are created too, so startup does not crash
    when the configured directory does not exist yet. An existing file is
    left untouched.
    """
    if path is None:
        path = FLAGGED_RESPONSES_FILE
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    if not os.path.exists(path):
        with open(path, "w", encoding="utf-8") as file:
            json.dump([], file)
|
|
|
# Make sure the flagged-responses file exists before the UI reads or appends to it.
ensure_flagged_file()


def save_flagged_response(user_query, ai_response):
    """Append a flagged query/response pair to ``FLAGGED_RESPONSES_FILE``.

    The file holds a JSON array of ``{"query", "response", "timestamp"}``
    objects.

    - A missing or unparsable file is treated as an empty list, so an
      unparsable file is replaced.
    - A valid non-list document is kept as the first element instead of
      crashing on ``append``.

    The new contents go to a temporary file in the same directory and are
    swapped in with ``os.replace``. A failure mid-write therefore cannot
    truncate the flags already saved. NOTE: ``tempfile.mkstemp`` creates the
    file owner-read/write only, and those permissions carry over to the
    replaced file.

    Shows a Streamlit success message once the entry is saved.
    """
    flagged_data = {
        "query": user_query,
        "response": ai_response,
        "timestamp": datetime.now().isoformat(),
    }

    try:
        with open(FLAGGED_RESPONSES_FILE, "r", encoding="utf-8") as file:
            existing_data = json.load(file)
    except (FileNotFoundError, json.JSONDecodeError):
        existing_data = []
    if not isinstance(existing_data, list):
        existing_data = [existing_data]
    existing_data.append(flagged_data)

    directory = os.path.dirname(FLAGGED_RESPONSES_FILE) or "."
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as file:
            json.dump(existing_data, file, indent=4)
        os.replace(tmp_path, FLAGGED_RESPONSES_FILE)
    finally:
        # After a successful os.replace the temp file is gone; otherwise
        # remove the leftover so failed writes do not litter the directory.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    st.success("Response flagged for review.")
|
|
|
|
|
# Run statuses in which the Assistants API is still working on a reply. Any
# other status ("completed", "failed", "cancelled", "expired",
# "requires_action", "incomplete", ...) ends polling.
_ACTIVE_RUN_STATUSES = frozenset({"queued", "in_progress", "cancelling"})


def _wait_for_run(client, thread_id, run_id, timeout=300.0, poll_interval=1.0):
    """Poll an Assistants run until it leaves the active states and return it.

    Args:
        client: ``OpenAI`` client used to query the run.
        thread_id: Id of the thread the run belongs to.
        run_id: Id of the run to poll.
        timeout: Seconds to wait before giving up.
        poll_interval: Seconds to sleep between status checks.

    Returns:
        The last retrieved run object. Its ``status`` may be any value outside
        ``_ACTIVE_RUN_STATUSES``, so callers must check for "completed".

    Raises:
        TimeoutError: If the run is still active after ``timeout`` seconds.
    """
    deadline = time.monotonic() + timeout
    while True:
        run = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run_id)
        if run.status not in _ACTIVE_RUN_STATUSES:
            return run
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Assistant run {run_id} did not finish within {timeout:.0f} seconds."
            )
        time.sleep(poll_interval)


def _run_reply_text(client, thread_id, run_id):
    """Return the text of the newest assistant message produced by ``run_id``.

    Only text content parts are used, joined with blank lines. Restricting the
    search to this run keeps an older reply in a reused thread from being
    returned by mistake.

    Raises:
        RuntimeError: If the run produced no assistant text.
    """
    messages = client.beta.threads.messages.list(thread_id=thread_id, order="desc")
    for message in messages.data:
        if message.role == "assistant" and getattr(message, "run_id", None) == run_id:
            texts = [part.text.value for part in message.content if part.type == "text"]
            if texts:
                return "\n\n".join(texts)
    raise RuntimeError("The assistant returned no text reply.")


def contract_chat_section(tab, assistant_id, session_key, input_key):
    """Render a chat UI inside ``tab`` that is answered by an OpenAI Assistant.

    Args:
        tab: Streamlit container (one of the ``st.tabs`` objects) to draw into.
        assistant_id: OpenAI Assistant id that answers prompts in this tab.
        session_key: ``st.session_state`` key holding this chat's history, a
            list of ``{"role": ..., "content": ...}`` dicts.
        input_key: Widget key for this chat's ``st.chat_input``.

    The OpenAI thread id is kept in ``st.session_state[f"{session_key}_thread_id"]``.
    Follow-up questions therefore reach the assistant with the earlier turns
    as context. "Clear Chat" drops both the visible history and that thread.
    """
    thread_key = f"{session_key}_thread_id"
    with tab:
        st.subheader("Chat")
        client = OpenAI(api_key=OPENAI_API_KEY)

        if session_key not in st.session_state:
            st.session_state[session_key] = []

        if st.button("Clear Chat", key=f"clear_chat_{session_key}"):
            st.session_state[session_key] = []
            st.session_state.pop(thread_key, None)
            st.rerun()

        history = st.session_state[session_key]
        for idx, message in enumerate(history):
            role, content = message["role"], message["content"]
            if role != "assistant":
                st.chat_message(role).write(content)
                continue
            with st.container():
                st.chat_message(role).write(content)
                if st.button("🚩 Flag", key=f"flag_{session_key}_{idx}"):
                    user_query = history[idx - 1]["content"] if idx > 0 else "Unknown"
                    save_flagged_response(user_query, content)

        prompt = st.chat_input("Enter your message:", key=input_key)
        if not prompt:
            return

        history.append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)

        try:
            thread_id = st.session_state.get(thread_key)
            if thread_id is None:
                thread_id = client.beta.threads.create().id
                st.session_state[thread_key] = thread_id

            client.beta.threads.messages.create(
                thread_id=thread_id,
                role="user",
                content=prompt,
            )
            run = client.beta.threads.runs.create(
                thread_id=thread_id,
                assistant_id=assistant_id,
            )
            run = _wait_for_run(client, thread_id, run.id)
            if run.status != "completed":
                raise RuntimeError(f"Assistant run ended with status '{run.status}'.")
            assistant_message = _run_reply_text(client, thread_id, run.id)
        except Exception as e:
            # The thread may still have an active run or be otherwise unusable.
            # Start a fresh one on the next prompt instead of failing forever.
            st.session_state.pop(thread_key, None)
            st.error(f"Error: {str(e)}")
            return

        history.append({"role": "assistant", "content": assistant_message})
        # Rerun so the reply is drawn by the history loop above, with a working
        # Flag button. A button created in this `if prompt` branch would not be
        # rendered on the rerun its own click triggers, so that click was lost.
        # Kept outside the try block so the rerun signal is never caught.
        st.rerun()
|
|
|
|
|
# Each tab is answered by its own OpenAI Assistant and keeps separate
# session-state keys for its history and input widget.
ASSISTANT_CONTRACT_ID = "asst_rd9h8PfYuOmHbkvOF3RTmVfn"
ASSISTANT_TECHNICAL_ID = "asst_xizNZBCJuy4TqdjqjwkxbAki"

for _chat_tab, _chat_assistant, _chat_prefix in (
    (tab1, ASSISTANT_CONTRACT_ID, "contract"),
    (tab2, ASSISTANT_TECHNICAL_ID, "technical"),
):
    contract_chat_section(
        _chat_tab,
        _chat_assistant,
        f"{_chat_prefix}_messages",
        f"{_chat_prefix}_input",
    )
|
|
|
|
|
# Offer the current flag log for download. It is re-read on every script
# run, so the button always serves the latest saved flags.
if os.path.exists(FLAGGED_RESPONSES_FILE):
    with open(FLAGGED_RESPONSES_FILE) as file:
        flagged_responses = file.read()
    st.download_button(
        "Download Flagged Responses",
        data=flagged_responses,
        file_name="flagged_responses.json",
        mime="application/json",
    )
|
|