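# A minimal RAG chatbot: embed a small document set with sentence-transformers,
# retrieve the closest matches from a FAISS index, and pass them as context to
# a hosted LLM through the Hugging Face InferenceClient, streaming the reply
# into a Gradio ChatInterface.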
import gradio as gr
import faiss
from huggingface_hub import InferenceClient
from sentence_transformers import SentenceTransformer


# Documents converted into FAISS vectors
documents = [
    "The class starts at 2PM Wednesday.",
    "Python is our main programming language.",
    "Our university is located in Szeged.",
    "We are making things with RAG, Rasa and LLMs.",
    "Gabor Toth is the author of this chatbot."
]
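# Encode the documents once and index them for exact L2 nearest-neighbour
# search; all-MiniLM-L6-v2 returns 384-dimensional float32 vectors, which is
# what faiss.IndexFlatL2 expects.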
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
document_embeddings = embedding_model.encode(documents)
index = faiss.IndexFlatL2(document_embeddings.shape[1])
index.add(document_embeddings)

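# Note: meta-llama models are gated on the Hugging Face Hub, so a token is
# typically required; InferenceClient picks one up from the HF_TOKEN
# environment variable or a cached `huggingface-cli login`.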
client = InferenceClient("meta-llama/Llama-3.2-3B-Instruct")

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Retrieve the two documents closest to the query in embedding space
    query_embedding = embedding_model.encode([message])
    distances, indices = index.search(query_embedding, k=2)
    relevant_documents = [documents[i] for i in indices[0]]

    # Build the prompt: system message, retrieved context, then the chat history
    messages = [
        {"role": "system", "content": system_message},
        {"role": "system", "content": f"context: {' '.join(relevant_documents)}"},
    ]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    # Stream the completion and yield the growing partial response; the loop
    # variable is renamed from `message` so it no longer shadows the user query
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content or ""  # delta content can be None on the final chunk
        response += token
        yield response

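# The extra controls below are passed to respond() as positional arguments
# after (message, history), in the order they are listed here.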
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

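# launch() starts a local web server; pass share=True for a temporary public URL.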
if __name__ == "__main__":
    demo.launch()