File size: 5,546 Bytes
4bd484d
ead511a
fd85cd0
b67323a
01aaa56
 
ead511a
4bd484d
d905796
d6d68dd
 
1f15159
 
 
 
 
 
d6d68dd
 
 
 
 
1f15159
 
 
d6d68dd
 
 
 
 
 
 
 
 
 
 
 
528891b
16a0a0c
ead511a
d905796
 
 
 
 
fd85cd0
4bd484d
9f1d09e
ead511a
f059f48
 
8d4e339
 
 
 
 
 
 
 
b67323a
 
 
 
 
 
f059f48
8d4e339
 
 
 
 
f059f48
 
 
 
 
 
b67323a
ead511a
4bd484d
a01800e
35da82e
 
d905796
35da82e
d25fa67
 
35da82e
d25fa67
 
35da82e
 
01aaa56
35da82e
b67323a
e8a2242
b67323a
01aaa56
e8a2242
 
b67323a
 
35da82e
a01800e
d25fa67
35da82e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d25fa67
e8a2242
 
 
 
35da82e
 
 
01aaa56
35da82e
a01800e
35da82e
a01800e
f059f48
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import hmac
import json
import os
import tempfile
import time
from datetime import datetime

import streamlit as st
from openai import OpenAI

# Streamlit Page Config -- Streamlit expects this before any other st.* call.
st.set_page_config(layout="wide", page_title="Schlager ContractAi")

# Authentication: sign-in e-mail -> password, checked by login() below.
# NOTE(review): credentials are hard-coded in plain text; consider moving them
# to st.secrets / environment configuration and storing hashes, not passwords.
# NOTE(review): all four keys are the identical literal, so this dict holds a
# single entry (the real addresses look like they were lost to e-mail
# obfuscation when the file was copied) -- confirm against the deployed copy.
AUTHORIZED_USERS = {
    "[email protected]": "Pass.123",
    "[email protected]": "Pass.123",
    "[email protected]": "Pass.123",
    "[email protected]": "Pass.123",
}

# A session that has never signed in starts with the flag explicitly False.
st.session_state.setdefault("authenticated", False)

def login():
    """Check the sign-in form and mark the session authenticated on success.

    Wired as the "Login" button's ``on_click`` callback, so it takes no
    arguments: it reads the ``"email"`` and ``"password"`` widget values from
    ``st.session_state``. A match against ``AUTHORIZED_USERS`` sets
    ``st.session_state["authenticated"] = True``; anything else shows an
    error message.
    """
    # Surrounding whitespace from copy/paste should not make a valid e-mail fail.
    email = st.session_state.get("email", "").strip()
    password = st.session_state.get("password", "")
    expected = AUTHORIZED_USERS.get(email)

    # Constant-time comparison so response timing does not reveal how much of
    # the password matched. Compare bytes: hmac.compare_digest() raises
    # TypeError for str arguments that contain non-ASCII characters.
    password_ok = hmac.compare_digest(
        password.encode("utf-8"), (expected or "").encode("utf-8")
    )
    if expected is not None and password_ok:
        st.session_state["authenticated"] = True
    else:
        st.error("Invalid email or password. Please try again.")

# Sign-in gate: until login() sets the flag, draw only the form and halt the
# script here so nothing below (API key check, chat tabs) is executed.
if not st.session_state["authenticated"]:
    st.title(body="Sign In")
    st.text_input(label="Email", key="email")
    st.text_input(label="Password", key="password", type="password")
    st.button(label="Login", on_click=login)
    st.stop()

# Main App
st.title("Schlager ContractAi")
st.caption("Chat with your contract or manage meeting minutes")

# Load API Key from Environment Variable. Strip it: a trailing newline (common
# when the value comes from a file) would pass the emptiness check below but
# produce an invalid Authorization header on the first request.
OPENAI_API_KEY = (os.getenv("OPENAI_API_KEY") or "").strip()
if not OPENAI_API_KEY:
    st.error("Missing OpenAI API key. Please set it as an environment variable.")
    st.stop()

# Tabs: one chat section per tab.
tab1, tab2 = st.tabs(["Contract", "Technical"])

# JSON file holding a list of flagged {"query", "response", "timestamp"}
# records. The FLAGGED_RESPONSES_FILE environment variable relocates it for
# hosts where /mnt/data is not available; the default is unchanged.
FLAGGED_RESPONSES_FILE = os.getenv(
    "FLAGGED_RESPONSES_FILE", "/mnt/data/flagged_responses.json"
)

# Ensure the flagged responses file exists
def ensure_flagged_file(path=None):
    """Create the flagged-responses store as an empty JSON list if it is missing.

    Args:
        path: File to initialise. Defaults to the module-level
            ``FLAGGED_RESPONSES_FILE`` so zero-argument calls keep working.

    An existing file is left untouched, whatever its contents.
    """
    if path is None:
        path = FLAGGED_RESPONSES_FILE

    # The parent directory may not exist on a fresh host; without this the
    # open() below raises FileNotFoundError and the app fails at start-up.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # Mode "x" creates the file only if nothing is there yet, so a session
    # starting concurrently can never truncate a store another one just wrote
    # (unlike a separate exists() check followed by open(..., "w")).
    try:
        with open(path, "x", encoding="utf-8") as file:
            json.dump([], file)  # Initialize with an empty list
    except FileExistsError:
        pass  # already present: nothing to do

# Executed every time the script runs, so the store exists before
# save_flagged_response() reads it or the download section below looks for it.
ensure_flagged_file()

def save_flagged_response(user_query, ai_response, path=None):
    """Append a flagged query/response pair to the JSON store and confirm it.

    Args:
        user_query: The user message that produced the flagged answer.
        ai_response: The assistant answer being flagged.
        path: Store to append to; defaults to ``FLAGGED_RESPONSES_FILE``.

    Side effects:
        Rewrites the JSON list at *path* with one record appended and shows a
        Streamlit success message.
    """
    if path is None:
        path = FLAGGED_RESPONSES_FILE

    flagged_data = {
        "query": user_query,
        "response": ai_response,
        "timestamp": datetime.now().isoformat(),
    }

    # A store deleted since start-up is recreated instead of crashing the
    # click; an unparseable one is treated as empty.
    try:
        with open(path, "r", encoding="utf-8") as file:
            existing_data = json.load(file)
    except (FileNotFoundError, json.JSONDecodeError):
        existing_data = []

    existing_data.append(flagged_data)

    # Write a temp file in the same directory and atomically swap it in. A
    # crash mid-write then cannot leave a truncated file, which the read above
    # would otherwise turn into silent loss of every earlier record. (The
    # replacement file gets mkstemp's owner-only permissions.)
    directory = os.path.dirname(path) or "."
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".flagged-", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as file:
            json.dump(existing_data, file, indent=4)
        os.replace(tmp_path, path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; the original error is re-raised below
        raise

    st.success("Response flagged for review.")

# Contract Chat Section
def _render_assistant_message(session_key, idx, content, user_query):
    """Show one assistant reply with a "🚩 Flag" button that records it.

    The button key depends only on ``session_key`` and the message's index in
    the history. The button drawn right after a reply arrives and the one
    drawn from history on the rerun its click triggers are therefore the same
    widget, which is what lets that click register.
    """
    with st.container():
        st.chat_message("assistant").write(content)
        if st.button("🚩 Flag", key=f"flag_{session_key}_{idx}"):
            save_flagged_response(user_query, content)


def _ask_assistant(client, assistant_id, prompt, poll_interval, timeout):
    """Send *prompt* on a fresh thread, wait for the run, and return the reply text.

    Each call creates a brand-new thread, so the assistant does not see earlier
    turns of the chat. NOTE(review): confirm that is intended.

    Raises:
        RuntimeError: if the run ends in any status other than "completed",
            does not finish within *timeout* seconds, or yields no text.
    """
    thread = client.beta.threads.create()
    client.beta.threads.messages.create(
        thread_id=thread.id,
        role="user",
        content=prompt,
    )
    run = client.beta.threads.runs.create(
        thread_id=thread.id,
        assistant_id=assistant_id,
    )

    deadline = time.monotonic() + timeout
    while True:
        run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)
        if run.status == "completed":
            break
        # Stop on every status that will never become "completed"; waiting for
        # "completed" alone polls forever after a failure. This app registers
        # no tool handlers, so "requires_action" cannot progress either.
        if run.status in ("failed", "cancelled", "expired", "incomplete", "requires_action"):
            last_error = getattr(run, "last_error", None)
            detail = getattr(last_error, "message", None) if last_error else None
            suffix = f": {detail}" if detail else "."
            raise RuntimeError(f"Assistant run ended with status '{run.status}'{suffix}")
        if time.monotonic() >= deadline:
            try:
                client.beta.threads.runs.cancel(thread_id=thread.id, run_id=run.id)
            except Exception:
                pass  # best effort; the timeout is reported either way
            raise RuntimeError(f"Assistant did not respond within {timeout:g} seconds.")
        time.sleep(poll_interval)

    # Messages are listed newest first. Use the newest assistant message and
    # join its text blocks, skipping non-text content (e.g. image files) that
    # has no .text attribute.
    messages = client.beta.threads.messages.list(thread_id=thread.id)
    for message in messages.data:
        if message.role != "assistant":
            continue
        texts = [block.text.value for block in message.content if block.type == "text"]
        if texts:
            return "\n\n".join(texts)
    raise RuntimeError("The assistant returned no text response.")


def contract_chat_section(tab, assistant_id, session_key, input_key,
                          poll_interval=1.0, timeout=300.0):
    """Render a chat UI inside *tab* that is answered by an OpenAI Assistant.

    Args:
        tab: Streamlit container (one of the ``st.tabs`` entries) to draw into.
        assistant_id: OpenAI Assistant that answers prompts in this tab.
        session_key: ``st.session_state`` key holding this tab's history (a
            list of ``{"role", "content"}`` dicts); also namespaces widget keys.
        input_key: Widget key for this tab's ``st.chat_input``.
        poll_interval: Seconds between run-status checks.
        timeout: Seconds to wait for a run before cancelling it and showing an
            error.
    """
    with tab:
        st.subheader("Chat")
        client = OpenAI(api_key=OPENAI_API_KEY)

        if session_key not in st.session_state:
            st.session_state[session_key] = []

        if st.button("Clear Chat", key=f"clear_chat_{session_key}"):
            st.session_state[session_key] = []
            st.rerun()

        history = st.session_state[session_key]
        for idx, message in enumerate(history):
            role, content = message["role"], message["content"]
            if role == "assistant":
                # The message just before an assistant reply is the prompt it answered.
                user_query = history[idx - 1]["content"] if idx > 0 else "Unknown"
                _render_assistant_message(session_key, idx, content, user_query)
            else:
                st.chat_message(role).write(content)

        if prompt := st.chat_input("Enter your message:", key=input_key):
            history.append({"role": "user", "content": prompt})
            st.chat_message("user").write(prompt)

            try:
                assistant_message = _ask_assistant(
                    client, assistant_id, prompt, poll_interval, timeout
                )
            except Exception as e:
                st.error(f"Error: {str(e)}")
                return

            history.append({"role": "assistant", "content": assistant_message})
            _render_assistant_message(session_key, len(history) - 1, assistant_message, prompt)

# OpenAI Assistant backing each tab. These are assistant identifiers; the API
# key itself is read from the OPENAI_API_KEY environment variable.
ASSISTANT_CONTRACT_ID = "asst_rd9h8PfYuOmHbkvOF3RTmVfn"
ASSISTANT_TECHNICAL_ID = "asst_xizNZBCJuy4TqdjqjwkxbAki"

# Each tab gets its own history key and input key so the two chats never
# share session state or widget identities.
contract_chat_section(
    tab=tab1,
    assistant_id=ASSISTANT_CONTRACT_ID,
    session_key="contract_messages",
    input_key="contract_input",
)
contract_chat_section(
    tab=tab2,
    assistant_id=ASSISTANT_TECHNICAL_ID,
    session_key="technical_messages",
    input_key="technical_input",
)

# Offer the flag store as a download. This runs after both chat sections, so a
# flag recorded during the current run is already in the file.
try:
    with open(FLAGGED_RESPONSES_FILE, "r") as flagged_file:
        flagged_payload = flagged_file.read()
except FileNotFoundError:
    pass  # no store yet: simply omit the button
else:
    st.download_button(
        "Download Flagged Responses",
        data=flagged_payload,
        file_name="flagged_responses.json",
        mime="application/json",
    )