import os
import re
from dotenv import load_dotenv
load_dotenv()

import gradio as gr

from langchain.agents.openai_assistant import OpenAIAssistantRunnable

api_key = os.getenv('OPENAI_API_KEY')
extractor_agent = os.getenv('ASSISTANT_ID_SOLUTION_SPECIFIER_A')

# Wrap the existing assistant. We don't pass a thread_id here, so the
# first call that doesn't include one will create a new thread.
extractor_llm = OpenAIAssistantRunnable(
    assistant_id=extractor_agent,
    api_key=api_key,
    as_agent=True
)
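# With as_agent=True, invoke() returns agent-style objects (carrying
# return_values and a thread_id) rather than raw message lists.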

# Store the thread_id in a module-level global so subsequent calls reuse the same thread.
THREAD_ID = None
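# NOTE: a module-level global is shared by every Gradio session; that is fine
# for a single-user demo, but concurrent users would end up on the same thread.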

def remove_citation(text):
    """Replace Assistants API citation markers with a book emoji."""
    pattern = r"【\d+†\w+】"
    return re.sub(pattern, "📚", text)
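# Example: remove_citation("See 【12†source】 for details.") -> "See 📚 for details."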

def predict(message, history):
    """
    Receives the new user message plus the entire conversation history 
    from Gradio. If no thread_id is set, we create a new thread. 
    Otherwise we pass the existing thread_id.
    """
    global THREAD_ID
    
    # debug print
    print("current history:", history)
    
    # An empty history means Gradio has started a new conversation, so reset the thread
    if not history:
        THREAD_ID = None
    
    # 1) Decide if we are creating a new thread or continuing the old one
    if THREAD_ID is None:
        # No thread_id yet -> this is the first user message
        response = extractor_llm.invoke({"content": message})
        THREAD_ID = response.thread_id  # store for subsequent calls
    else:
        # We already have a thread_id -> continue that same thread
        response = extractor_llm.invoke({"content": message, "thread_id": THREAD_ID})
    
    # 2) Extract the text output from the response
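    # With as_agent=True, a finished run comes back as an OpenAIAssistantFinish,
    # whose return_values dict holds the assistant's final text under "output".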
    output = response.return_values["output"]
    non_cited_output = remove_citation(output)
    
    # 3) Return the model's text to display in Gradio
    return non_cited_output

# Create a Gradio ChatInterface using our predict function
chat = gr.ChatInterface(
    fn=predict, 
    title="Solution Specifier A", 
    #description="Testing threaded conversation"
)
chat.launch(share=True)
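# share=True also exposes a temporary public *.gradio.live URL in addition to the local server.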