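"""Gradio app for a ROS2 question-answering assistant.

Retrieves context from a persisted Chroma vector store and streams answers
from Llama 3.2 via the Hugging Face InferenceClient.
"""
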
import gradio as gr
import spaces
from huggingface_hub import InferenceClient
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.vectorstores import Chroma

# Client for the hosted Llama 3.2 model on the Hugging Face Inference API
client = InferenceClient("meta-llama/Llama-3.2-3B-Instruct")
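# Note: Llama 3.2 is a gated model, so the Space may need a valid HF token
# (e.g. an HF_TOKEN secret) for the Inference API calls to succeed.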

# Initialize embeddings
embeddings = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-base",
    model_kwargs={"device": "cpu"}  # Use CPU for Spaces
)

# Load the persisted database
db = Chroma(
    persist_directory="db",
    embedding_function=embeddings
)
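# Assumes ./db holds a Chroma index built beforehand with the same
# instructor-base embeddings; a mismatched index would return poor results.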

# Prompt templates
DEFAULT_SYSTEM_PROMPT = """
You are a ROS2 expert assistant. Based on the information provided in the context, answer questions 
accurately and concisely. If the information is not in the context, acknowledge that you don't know.
""".strip()

@spaces.GPU(duration=60)
def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    try:
        # Retrieve relevant context
        docs = db.similarity_search(message, k=2)
        context = "\n".join([doc.page_content for doc in docs])
        
        # Build the transcript: system prompt, then prior (user, assistant) turns
        messages = [{"role": "system", "content": system_message}]
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        
        # Add context to the user message
        augmented_message = f"Context: {context}\n\nQuestion: {message}"
        messages.append({"role": "user", "content": augmented_message})
        
        # Stream the response token by token (the loop variable must not
        # shadow the `message` parameter, which is still needed above)
        response = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token:  # the final chunk may carry no content
                response += token
                yield response
            
    except Exception as e:
        yield f"An error occurred: {str(e)}"

# Create Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value=DEFAULT_SYSTEM_PROMPT,
            label="System message"
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=500,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.1,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
    ],
    title="ROS2 Expert Assistant",
    description="Ask questions about ROS2, navigation, and robotics. I'll answer based on my knowledge base.",
)

if __name__ == "__main__":
    demo.launch()