AbenzaFran committed on
Commit 91b4d66 · verified · 1 Parent(s): f7e68f3

Update app.py

Files changed (1)
app.py +56 -46
app.py CHANGED
@@ -1,66 +1,76 @@
 import os
 import re
+import streamlit as st
 from dotenv import load_dotenv
-load_dotenv()
-
-import gradio as gr
-
 from langchain.agents.openai_assistant import OpenAIAssistantRunnable
-from langchain.schema import HumanMessage, AIMessage
 
-api_key = os.getenv('OPENAI_API_KEY')
-extractor_agent = os.getenv('ASSISTANT_ID_SOLUTION_SPECIFIER_A')
+# Load environment variables
+load_dotenv()
+api_key = os.getenv("OPENAI_API_KEY")
+extractor_agent = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")
 
-# Create the assistant. By default, we don't specify a thread_id,
-# so the first call that doesn't pass one will create a new thread.
+# Create the assistant
 extractor_llm = OpenAIAssistantRunnable(
     assistant_id=extractor_agent,
     api_key=api_key,
     as_agent=True
 )
 
-# We will store thread_id globally or in a session variable.
-THREAD_ID = None
-
-def remove_citation(text):
+def remove_citation(text: str) -> str:
     pattern = r"【\d+†\w+】"
     return re.sub(pattern, "📚", text)
 
-def predict(message, history):
+# Initialize session state for messages and thread_id
+if "messages" not in st.session_state:
+    st.session_state["messages"] = []
+if "thread_id" not in st.session_state:
+    st.session_state["thread_id"] = None
+
+st.title("Solution Specifier A")
+
+def predict(user_input: str) -> str:
     """
-    Receives the new user message plus the entire conversation history
-    from Gradio. If no thread_id is set, we create a new thread.
-    Otherwise we pass the existing thread_id.
+    This function calls our OpenAIAssistantRunnable to get a response.
+    If we don't have a thread_id yet, we create a new thread on the first call.
+    Otherwise, we continue the existing thread.
     """
-    global THREAD_ID
-
-    # debug print
-    print("current history:", history)
-
-    # If history is empty, this means that it is probably a new conversation and therefore the thread shall be reset
-    if not history:
-        THREAD_ID = None
-
-    # 1) Decide if we are creating a new thread or continuing the old one
-    if THREAD_ID is None:
-        # No thread_id yet -> this is the first user message
-        response = extractor_llm.invoke({"content": message})
-        THREAD_ID = response.thread_id  # store for subsequent calls
+    if st.session_state["thread_id"] is None:
+        response = extractor_llm.invoke({"content": user_input})
+        st.session_state["thread_id"] = response.thread_id
     else:
-        # We already have a thread_id -> continue that same thread
-        response = extractor_llm.invoke({"content": message, "thread_id": THREAD_ID})
-
-    # 2) Extract the text output from the response
+        response = extractor_llm.invoke(
+            {"content": user_input, "thread_id": st.session_state["thread_id"]}
+        )
     output = response.return_values["output"]
-    non_cited_output = remove_citation(output)
-
-    # 3) Return the model's text to display in Gradio
-    return non_cited_output
+    return remove_citation(output)
 
-# Create a Gradio ChatInterface using our predict function
-chat = gr.ChatInterface(
-    fn=predict,
-    title="Solution Specifier A",
-    #description="Testing threaded conversation"
-)
-chat.launch(share=True)
+# Display any existing messages (from a previous run or refresh)
+for msg in st.session_state["messages"]:
+    if msg["role"] == "user":
+        with st.chat_message("user"):
+            st.write(msg["content"])
+    else:
+        with st.chat_message("assistant"):
+            st.write(msg["content"])
+
+# Create the chat input widget at the bottom of the page
+user_input = st.chat_input("Type your message here...")
+
+# When the user hits ENTER on st.chat_input
+if user_input:
+    # Add the user message to session state
+    st.session_state["messages"].append({"role": "user", "content": user_input})
+
+    # Display the user's message
+    with st.chat_message("user"):
+        st.write(user_input)
+
+    # Get the assistant's response
+    response_text = predict(user_input)
+
+    # Add the assistant response to session state
+    st.session_state["messages"].append({"role": "assistant", "content": response_text})
+
+    # Display the assistant's reply
+    with st.chat_message("assistant"):
+        st.write(response_text)
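
The core change here is moving the thread id out of a module-level global and into st.session_state. Streamlit re-executes the entire script on every widget interaction, so had the global pattern been kept, the module-level THREAD_ID = None assignment would have been re-run (resetting the thread) on every message, and a global that did survive would be shared by every connected user; st.session_state is scoped to one browser session and persists across reruns. A minimal sketch of the pattern, using a hypothetical counter rather than anything from this app:

import streamlit as st

# Re-executed on every rerun; the guard ensures the assignment happens only
# once per browser session, mirroring the thread_id initialization in app.py.
if "counter" not in st.session_state:
    st.session_state["counter"] = 0

if st.button("Increment"):
    st.session_state["counter"] += 1  # survives the rerun the click triggers

st.write("Clicks so far:", st.session_state["counter"])

Note also that the Gradio chat.launch(share=True) call is gone, so the script no longer serves itself; it is started through the Streamlit CLI, e.g. `streamlit run app.py`. The code reads OPENAI_API_KEY and ASSISTANT_ID_SOLUTION_SPECIFIER_A via load_dotenv(), so a .env file next to app.py along these lines is assumed (placeholder values):

OPENAI_API_KEY=sk-...
ASSISTANT_ID_SOLUTION_SPECIFIER_A=asst_...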