File size: 4,568 Bytes
cf5e123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import gradio as gr

from langchain_astradb import AstraDBVectorStore

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import SystemMessage, AIMessage, HumanMessage
from langchain_openai import OpenAIEmbeddings, ChatOpenAI

import os

# Text of the system prompt, read from the PROMPT_TEMPLATE environment
# variable. os.environ.get returns None when the variable is unset.
# NOTE(review): presumably the template uses {context} and {question}
# placeholders, since those are the keys the chain in ai_setup() feeds into
# this prompt -- confirm against the deployed value.
prompt_template = os.environ.get("PROMPT_TEMPLATE")

# Chat prompt made of a single system message built from the template above.
prompt = ChatPromptTemplate.from_messages([('system', prompt_template)])

# Selects the document source used by ai_setup(): True retrieves from the
# Astra DB vector store, False loads the documents through just_read().
AI = False

def ai_setup():
    """Create the chat model and the prompt-building chain as module globals.

    Sets two globals:
      * ``llm`` -- the ChatOpenAI model.
      * ``prompt_chain`` -- a runnable that takes the user's question, fetches
        documents (from the Astra DB vector store when ``AI`` is true,
        otherwise via ``just_read``), formats them with ``format_context`` and
        fills in ``prompt``.
    """
    global llm, prompt_chain

    llm = ChatOpenAI(model="gpt-4o", temperature=0.8)

    if not AI:
        # Offline mode: documents come from the local loader instead of a
        # vector search.
        retriever = RunnableLambda(just_read)
    else:
        store = AstraDBVectorStore(
            embedding=OpenAIEmbeddings(),
            collection_name=os.environ.get("ASTRA_DB_COLLECTION"),
            token=os.environ.get("ASTRA_DB_APPLICATION_TOKEN"),
            api_endpoint=os.environ.get("ASTRA_DB_API_ENDPOINT"),
        )
        retriever = store.as_retriever(search_kwargs={'k': 10})

    chain_inputs = {"context": retriever, "question": RunnablePassthrough()}

    # The chain deliberately ends at the prompt (no model, no output parser):
    # chat() takes the rendered system message from it and streams from `llm`
    # itself.
    prompt_chain = chain_inputs | RunnableLambda(format_context) | prompt

def group_and_sort(documents):
    """Merge document fragments into one text per title.

    Fragments are grouped by ``metadata["Title"]``, ordered by the start of
    ``metadata["range"]`` and stitched together: text already emitted is
    trimmed from an overlapping fragment, and a gap between fragments (or
    before the first one, when it does not start at 0) is marked with
    ``' [...] '``.

    NOTE(review): the trimming assumes ``range`` is a ``(start, end)`` pair of
    character offsets with ``len(page_content) == end - start`` -- confirm
    against the code that produced the documents.

    Args:
        documents: Iterable of objects exposing ``page_content`` (str) and
            ``metadata`` (mapping with ``"Title"`` and ``"range"`` keys).

    Returns:
        dict mapping each title, in order of first appearance, to its merged
        text.
    """
    fragments_by_title = {}
    for document in documents:
        fragments_by_title.setdefault(document.metadata["Title"], []).append(
            (document.page_content, document.metadata["range"])
        )

    merged = {}
    for title, fragments in fragments_by_title.items():
        fragments.sort(key=lambda item: item[1][0])

        parts = []
        covered_up_to = 0
        for fragment, (start, last) in fragments:
            if start < covered_up_to:
                # Overlap: keep only the part not emitted yet.
                parts.append(fragment[covered_up_to - start:])
            elif start == covered_up_to:
                parts.append(fragment)
            else:
                parts.append(' [...] ')
                parts.append(fragment)
            # Never move the mark backwards: a fragment lying entirely inside
            # text already emitted must not cause the following fragment to be
            # treated as a gap and repeated.
            covered_up_to = max(covered_up_to, last)

        merged[title] = ''.join(parts)

    return merged
        
def format_context(pipeline_state):
    """Replace the retrieved documents in the state with one context string.

    The documents under ``pipeline_state["context"]`` are merged per title by
    ``group_and_sort``; each title becomes a section of the form
    ``"\\nTitle: <title>\\n<text>\\n\\n---\\n"`` and the sections are
    concatenated. The state dict is updated in place and also returned so the
    function can be used as a step in a chain.
    """
    merged = group_and_sort(pipeline_state["context"])

    sections = [
        f"\nTitle: {title}\n" + text + '\n\n---\n'
        for title, text in merged.items()
    ]

    pipeline_state["context"] = ''.join(sections)
    return pipeline_state

def just_read(pipeline_state):
    """Load the pre-saved documents from ``docs.pickle``.

    Args:
        pipeline_state: Input handed over by the chain; not used, it is only
            accepted so the function can be wrapped in a RunnableLambda.

    Returns:
        Whatever object was pickled into ``docs.pickle`` (resolved relative to
        the current working directory).

    Raises:
        FileNotFoundError: If ``docs.pickle`` does not exist.
    """
    import pickle

    fname = "docs.pickle"
    # NOTE: unpickling can execute arbitrary code; docs.pickle must come from a
    # trusted source.
    # The context manager closes the file handle, which the bare open() call
    # used previously never did.
    with open(fname, "rb") as handle:
        return pickle.load(handle)

def new_state():
    """Return a fresh Gradio State holding the conversation's initial data.

    The wrapped dict has a single ``"system"`` entry, initially ``None``;
    chat() stores the system message there on the first turn.
    """
    initial = {"system": None}
    return gr.State(initial)

def chat(message, history, state):
    """Stream the assistant's reply for one chat turn.

    On the first turn the system message is produced by ``prompt_chain`` from
    the user's message and cached in ``state["system"]``; later turns reuse the
    cached message so the whole conversation shares one system prompt.

    Args:
        message: The text the user just sent.
        history: Earlier turns as ``(human, ai)`` pairs; empty on the first
            turn.
        state: Per-conversation dict with a ``"system"`` entry (see
            new_state()).

    Yields:
        The reply accumulated so far, growing with every chunk streamed from
        ``llm``.
    """
    system_prompt = state["system"] if history else None
    if system_prompt is None:
        # First turn -- or a conversation whose cached prompt is missing, in
        # which case rebuilding it avoids putting None in the message list.
        system_prompt = prompt_chain.invoke(message).messages[0]
        state["system"] = system_prompt

    messages = [system_prompt]
    for human, ai in history:
        messages.append(HumanMessage(human))
        messages.append(AIMessage(ai))
    messages.append(HumanMessage(message))

    # Named `reply` rather than `all` so the builtin is not shadowed.
    reply = ''
    for chunk in llm.stream(messages):
        reply += chunk.content
        yield reply

def gr_main():
    """Build the Gradio chat interface and launch it.

    Loads a theme from the hub, overrides a few of its colours, wires chat()
    into a ChatInterface together with a per-session state object, and starts
    the app with the API page disabled.
    """
    # NOTE(review): the "[email protected]" part of this identifier looks like
    # e-mail obfuscation applied to an original "<theme>@<version>" string.
    # The real value cannot be recovered from this file -- restore it from the
    # project's history before running.
    theme = gr.Theme.from_hub("freddyaboulton/[email protected]")
    theme.set(
        color_accent_soft="#818eb6",            # ChatBot.svelte / .message-row.panel.user-row
        background_fill_secondary="#6272a4",    # ChatBot.svelte / .message-row.panel.bot-row
        button_primary_text_color="*button_secondary_text_color",
        button_primary_background_fill="*button_secondary_background_fill")

    with gr.Blocks(
        title="Sherlock Holmes stories",
        fill_height=True,
        theme=theme,
    ) as app:
        state = new_state()
        gr.ChatInterface(
            chat,
            chatbot=gr.Chatbot(show_label=False, render=False, scale=1),
            title="Sherlock Holmes stories",
            examples=[
                ["I arrived late last night and found a dead goose in my bed"],
                # Adjacent literals are concatenated as-is, so each piece must
                # end with its own space (previously "lady,and" and
                # "thingsof" were glued together).
                ["Help please sir. I'm about to get married, to the most lovely lady, "
                 "and I just received a letter threatening me to make public some things "
                 "of my past I'd rather keep quiet, unless I don't marry"],
            ],
            additional_inputs=[state],
        )
    app.launch(show_api=False)
if __name__ == "__main__":
    # ai_setup() must run first: it defines the `llm` and `prompt_chain`
    # globals that chat() uses once the interface is serving requests.
    ai_setup()
    gr_main()