mannadamay12 committed on
Commit 85b8a02 · verified · 1 Parent(s): b005487

Update app.py

Files changed (1)
  1. app.py +53 -32
app.py CHANGED
@@ -3,20 +3,25 @@ import torch
 import gradio as gr
 import spaces
 from huggingface_hub import InferenceClient
-from langchain_community.embeddings import HuggingFaceInstructEmbeddings
-from langchain_community.vectorstores import Chroma
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.vectorstores import Chroma
 from langchain.prompts import PromptTemplate
 
-# Configure ZeroGPU client
-client = InferenceClient("meta-llama/Llama-3.2-3B-Instruct")
+# Verify PyTorch version compatibility
+TORCH_VERSION = torch.__version__
+SUPPORTED_TORCH_VERSIONS = ['2.0.1', '2.1.2', '2.2.2', '2.4.0']
+if TORCH_VERSION.rsplit('+')[0] not in SUPPORTED_TORCH_VERSIONS:
+    print(f"Warning: Current PyTorch version {TORCH_VERSION} may not be compatible with ZeroGPU. "
+          f"Supported versions are: {', '.join(SUPPORTED_TORCH_VERSIONS)}")
 
-# Initialize embeddings
-embeddings = HuggingFaceInstructEmbeddings(
-    model_name="hkunlp/instructor-base",
-    model_kwargs={"device": "cpu"}  # Use CPU for Spaces
+# Initialize components outside of GPU scope
+client = InferenceClient("meta-llama/Llama-3.2-3B-Instruct")
+embeddings = HuggingFaceEmbeddings(
+    model_name="sentence-transformers/all-MiniLM-L6-v2",
+    model_kwargs={"device": "cpu"}  # Keep embeddings on CPU
 )
 
-# Load the persisted database
+# Load database
 db = Chroma(
     persist_directory="db",
     embedding_function=embeddings
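
The new version guard compares only the base version string, since builds on Spaces often carry a local tag such as "+cu121". A minimal standalone sketch of the same check (the allowlist is the one hard-coded above; whether it matches ZeroGPU's current requirements is an assumption):

import torch

SUPPORTED_TORCH_VERSIONS = ['2.0.1', '2.1.2', '2.2.2', '2.4.0']

# "2.4.0+cu121".rsplit('+')[0] -> "2.4.0": drop any local build tag before comparing
base_version = torch.__version__.rsplit('+')[0]
if base_version not in SUPPORTED_TORCH_VERSIONS:
    print(f"Warning: torch {torch.__version__} may not be compatible with ZeroGPU.")
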
@@ -24,11 +29,31 @@ db = Chroma(
 
 # Prompt templates
 DEFAULT_SYSTEM_PROMPT = """
-You are a ROS2 expert assistant. Based on the information provided in the context, answer questions
-accurately and concisely. If the information is not in the context, acknowledge that you don't know.
+Based on the information provided in the context, answer the question as accurately as possible in one or two lines. If the information is not in the context,
+respond with "I don't know" or a similar acknowledgment that the answer is not available.
+""".strip()
+
+def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
+    return f"""
+[INST] <<SYS>>
+{system_prompt}
+<</SYS>>
+
+{prompt} [/INST]
 """.strip()
 
-@spaces.GPU(duration=60)
+template = generate_prompt(
+    """
+{context}
+
+Question: {question}
+""",
+    system_prompt="Use the following pieces of context to answer the question at the end. Do not provide commentary or elaboration of more than one or two lines."
+)
+
+prompt_template = PromptTemplate(template=template, input_variables=["context", "question"])
+
+@spaces.GPU(duration=30)  # Reduced duration for faster queue priority
 def respond(
     message,
     history,
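
The [INST]/<<SYS>> wrapper built by generate_prompt() follows the Llama 2 chat convention; whether Llama 3.2 still honors that format is an assumption carried over from the commit. Rendering the assembled template with toy values shows exactly what string reaches the model (illustrative only, not part of the commit):

from langchain.prompts import PromptTemplate

template = """[INST] <<SYS>>
Use the following pieces of context to answer the question at the end.
<</SYS>>

{context}

Question: {question} [/INST]"""

prompt = PromptTemplate(template=template, input_variables=["context", "question"])
print(prompt.format(
    context="ROS2 nodes exchange messages over topics.",
    question="How do ROS2 nodes communicate?",
))
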
@@ -37,34 +62,28 @@
     temperature,
     top_p,
 ):
+    """GPU-accelerated response generation"""
     try:
-        # Retrieve relevant context
+        # Retrieve context (CPU operation)
         docs = db.similarity_search(message, k=2)
         context = "\n".join([doc.page_content for doc in docs])
 
-        # Build messages
-        messages = [{"role": "system", "content": system_message}]
-        for val in history:
-            if val[0]:
-                messages.append({"role": "user", "content": val[0]})
-            if val[1]:
-                messages.append({"role": "assistant", "content": val[1]})
-
-        # Add context to the user message
-        augmented_message = f"Context: {context}\n\nQuestion: {message}"
-        messages.append({"role": "user", "content": augmented_message})
+        # Format prompt
+        formatted_prompt = prompt_template.format(
+            context=context,
+            question=message
+        )
 
-        # Stream the response
+        # Stream response (GPU operation)
         response = ""
-        for message in client.chat_completion(
-            messages,
-            max_tokens=max_tokens,
+        for token in client.text_generation(
+            prompt=formatted_prompt,
+            max_new_tokens=max_tokens,
             stream=True,
             temperature=temperature,
             top_p=top_p,
         ):
-            token = message.choices[0].delta.content
-            response += token
+            response += token
             yield response
 
     except Exception as e:
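
The switch from chat_completion to text_generation also changes the streaming shape: with stream=True and the default details=False, InferenceClient.text_generation yields plain token strings, so there is no .choices[0].delta to unpack. A minimal sketch of that consumer loop (assumes an HF token with access to the gated Llama model):

from huggingface_hub import InferenceClient

client = InferenceClient("meta-llama/Llama-3.2-3B-Instruct")

def stream_answer(prompt: str, max_tokens: int = 256):
    response = ""
    for token in client.text_generation(
        prompt=prompt,
        max_new_tokens=max_tokens,
        stream=True,
    ):
        response += token
        yield response  # each yield is the full text so far, which is what gr.ChatInterface renders
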
@@ -76,7 +95,9 @@ demo = gr.ChatInterface(
     additional_inputs=[
         gr.Textbox(
             value=DEFAULT_SYSTEM_PROMPT,
-            label="System message"
+            label="System Message",
+            lines=3,
+            visible=False
         ),
         gr.Slider(
             minimum=1,
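
Marking the system-message Textbox visible=False hides it from the UI while its value is still passed into respond(); gr.ChatInterface appends additional_inputs to the callback arguments in declaration order, after message and history. A toy sketch of that wiring (the echo body is hypothetical):

import gradio as gr

def respond(message, history, system_message, max_tokens):
    # hidden Textbox value and Slider value arrive as the 3rd and 4th arguments
    yield f"(system={system_message!r}, max_tokens={max_tokens}) {message}"

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a ROS2 expert.", label="System Message", lines=3, visible=False),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max tokens"),
    ],
)
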
@@ -101,7 +122,7 @@
         ),
     ],
     title="ROS2 Expert Assistant",
-    description="Ask questions about ROS2, navigation, and robotics. I'll answer based on my knowledge base.",
+    description="Ask questions about ROS2, navigation, and robotics. I'll provide concise answers based on the available documentation.",
 )
 
 if __name__ == "__main__":
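
End to end, a query now flows retrieve, then format, then generate. A condensed, non-streaming sketch using the names defined in the new app.py (db, prompt_template, client), assuming the persisted db directory and model access are in place:

question = "How do I create a ROS2 publisher?"
docs = db.similarity_search(question, k=2)
context = "\n".join(doc.page_content for doc in docs)
answer = client.text_generation(
    prompt=prompt_template.format(context=context, question=question),
    max_new_tokens=256,
)
print(answer)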
 