import gradio as gr
from huggingface_hub import InferenceClient
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.vectorstores import Chroma

# Client for the hosted chat model (serverless Inference API, not local ZeroGPU)
client = InferenceClient("meta-llama/Llama-3.2-3B-Instruct")
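# NOTE (assumption): Llama-3.2-3B-Instruct is gated on the Hub, so this Space
# will likely need a token with access granted, e.g. an HF_TOKEN secret passed as
# InferenceClient("meta-llama/Llama-3.2-3B-Instruct", token=os.environ["HF_TOKEN"]).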

# Initialize embeddings
embeddings = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-base",
    model_kwargs={"device": "cpu"},  # use CPU on Spaces
)

# Load the persisted Chroma database
db = Chroma(
    persist_directory="db",
    embedding_function=embeddings,
)
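# NOTE (assumption): the "db" directory is expected to hold a Chroma index built
# offline with the same hkunlp/instructor-base embeddings; querying a store built
# with a different embedding model would make the similarity scores meaningless.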

# Default system prompt
DEFAULT_SYSTEM_PROMPT = """
You are a ROS2 expert assistant. Based on the information provided in the context,
answer questions accurately and concisely. If the information is not in the context,
acknowledge that you don't know.
""".strip()


def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    try:
        # Retrieve the most relevant chunks from the vector store
        docs = db.similarity_search(message, k=2)
        context = "\n".join(doc.page_content for doc in docs)

        # Build the chat history in the OpenAI-style messages format
        messages = [{"role": "system", "content": system_message}]
        for user_msg, assistant_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})

        # Prepend the retrieved context to the current user message
        augmented_message = f"Context: {context}\n\nQuestion: {message}"
        messages.append({"role": "user", "content": augmented_message})

        # Stream the response, accumulating tokens as they arrive.
        # Use "chunk", not "message", so the loop variable does not shadow
        # the user's input.
        response = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token:  # the final stream chunk may carry no content
                response += token
                yield response
    except Exception as e:
        yield f"An error occurred: {e}"
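# NOTE (assumption): the history loop in respond() expects Gradio's tuple-style
# history ([(user, assistant), ...]); on Gradio 5+ with type="messages" each
# entry is a dict instead, and the unpacking would need to change accordingly.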

# Create the Gradio chat interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value=DEFAULT_SYSTEM_PROMPT,
            label="System message",
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=500,
            step=1,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.1,
            step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title="ROS2 Expert Assistant",
    description="Ask questions about ROS2, navigation, and robotics. I'll answer based on my knowledge base.",
)

if __name__ == "__main__":
    demo.launch()
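# A requirements.txt along these lines is assumed for this Space (names only,
# versions unpinned): gradio, huggingface_hub, langchain, langchain-community,
# chromadb, InstructorEmbedding, sentence-transformers.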