mannadamay12 commited on
Commit
837e37f
·
verified ·
1 Parent(s): 882d9bb

Update app.py

Browse files
Files changed (1): app.py (+82, −59)
app.py CHANGED
@@ -16,6 +16,13 @@ If the information is not in the context, respond with "I don't find that inform
16
  Keep responses to 1-2 lines maximum.
17
  """.strip()
18
 
 
 
 
 
 
 
 
19
  def generate_prompt(context: str, question: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
20
  return f"""
21
  [INST] <<SYS>>
@@ -50,14 +57,10 @@ def initialize_model():
50
 
51
  return model, tokenizer
52
 
53
- class CustomTextStreamer(TextStreamer):
54
- def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
55
- super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
56
- self.output_text = ""
57
-
58
- def put(self, value):
59
- self.output_text += value
60
- super().put(value)
61
 
62
  @spaces.GPU
63
  def respond(message, history, system_message, max_tokens, temperature, top_p):
@@ -94,77 +97,97 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
94
 
95
  except Exception as e:
96
  yield f"An error occurred: {str(e)}"
97
- # def respond(message, history, system_message, max_tokens, temperature, top_p):
98
- # try:
99
- # model, tokenizer = initialize_model()
100
-
101
- # # Get relevant context from the database
102
- # retriever = db.as_retriever(search_kwargs={"k": 2})
103
- # docs = retriever.get_relevant_documents(message)
104
- # context = "\n".join([doc.page_content for doc in docs])
105
-
106
- # # Generate the complete prompt
107
- # prompt = generate_prompt(context=context, question=message, system_prompt=system_message)
108
-
109
- # # Set up the streamer
110
- # streamer = CustomTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
111
-
112
- # # Set up the pipeline
113
- # text_pipeline = pipeline(
114
- # "text-generation",
115
- # model=model,
116
- # tokenizer=tokenizer,
117
- # max_new_tokens=max_tokens,
118
- # temperature=temperature,
119
- # top_p=top_p,
120
- # repetition_penalty=1.15,
121
- # streamer=streamer,
122
- # )
123
 
124
- # # Generate response
125
- # _ = text_pipeline(prompt, max_new_tokens=max_tokens)
 
 
 
 
 
 
 
 
 
126
 
127
- # # Return only the generated response
128
- # yield streamer.output_text.strip()
 
129
 
130
- # except Exception as e:
131
- # yield f"An error occurred: {str(e)}"
132
-
133
- # Create Gradio interface
134
- demo = gr.ChatInterface(
135
- respond,
136
- additional_inputs=[
137
- gr.Textbox(
138
  value=DEFAULT_SYSTEM_PROMPT,
139
  label="System Message",
140
- lines=3,
141
- visible=False
142
- ),
143
- gr.Slider(
144
  minimum=1,
145
  maximum=2048,
146
  value=500,
147
  step=1,
148
  label="Max new tokens"
149
- ),
150
- gr.Slider(
151
  minimum=0.1,
152
  maximum=4.0,
153
  value=0.1,
154
  step=0.1,
155
  label="Temperature"
156
- ),
157
- gr.Slider(
158
  minimum=0.1,
159
  maximum=1.0,
160
  value=0.95,
161
  step=0.05,
162
  label="Top-p"
163
- ),
164
- ],
165
- title="ROS2 Expert Assistant",
166
- description="Ask questions about ROS2, navigation, and robotics. I'll provide concise answers based on the available documentation.",
167
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  if __name__ == "__main__":
170
  demo.launch(share=True)
 
16
  Keep responses to 1-2 lines maximum.
17
  """.strip()
18
 
19
# Questions offered in the dropdown. The first entry is a placeholder that
# maps to an empty input box (see question_selected).
PREDEFINED_QUESTIONS = [
    "Select a question...",
    "Tell me how can I navigate to a specific pose - include replanning aspects in your answer.",
    "Can you provide me with code for this task?",
]
25
+
26
  def generate_prompt(context: str, question: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
27
  return f"""
28
  [INST] <<SYS>>
 
57
 
58
  return model, tokenizer
59
 
60
def question_selected(question):
    """Map a dropdown selection to textbox content.

    The placeholder entry clears the input box (empty string); any real
    question is echoed unchanged so the user can edit it before submitting.
    """
    return "" if question == "Select a question..." else question
 
 
 
 
64
 
65
  @spaces.GPU
66
  def respond(message, history, system_message, max_tokens, temperature, top_p):
 
97
 
98
  except Exception as e:
99
  yield f"An error occurred: {str(e)}"
100
+
101
# Create the Gradio interface.
with gr.Blocks(title="ROS2 Expert Assistant") as demo:
    gr.Markdown("# ROS2 Expert Assistant")
    gr.Markdown("Ask questions about ROS2, navigation, and robotics. I'll provide concise answers based on the available documentation.")

    with gr.Row():
        # Dropdown of predefined questions; selecting one fills the textbox below.
        question_dropdown = gr.Dropdown(
            choices=PREDEFINED_QUESTIONS,
            value="Select a question...",
            label="Pre-defined Questions",
        )

    with gr.Row():
        # Chat history display.
        chatbot = gr.Chatbot()

    with gr.Row():
        # Free-form message input.
        msg = gr.Textbox(
            label="Your Question",
            placeholder="Type your question here or select one from the dropdown above...",
            lines=2,
        )

    with gr.Row():
        submit = gr.Button("Submit")
        clear = gr.Button("Clear")

    with gr.Accordion("Advanced Settings", open=False):
        system_message = gr.Textbox(
            value=DEFAULT_SYSTEM_PROMPT,
            label="System Message",
            lines=3,
        )
        max_tokens = gr.Slider(
            minimum=1,
            maximum=2048,
            value=500,
            step=1,
            label="Max new tokens",
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.1,
            step=0.1,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
        )

    def _chat(message, history, system_message, max_tokens, temperature, top_p):
        """Adapt respond()'s string stream to gr.Chatbot's history format.

        BUG FIX: respond is a generator yielding plain strings (see its
        error path: `yield f"An error occurred: ..."`), but a gr.Chatbot
        output expects the full history as [user, assistant] pairs. Wiring
        respond directly to outputs=[chatbot] would hand the Chatbot a bare
        string. Here we append a new turn, stream partial text into it, and
        also clear the input box.
        """
        prior = list(history or [])
        updated = prior + [[message, ""]]
        for partial in respond(message, prior, system_message, max_tokens, temperature, top_p):
            updated[-1][1] = partial
            yield updated, ""

    # Shared input list for both submit paths (button click and Enter key).
    chat_inputs = [msg, chatbot, system_message, max_tokens, temperature, top_p]

    # Event handlers.
    question_dropdown.change(
        question_selected,
        inputs=[question_dropdown],
        outputs=[msg],
    )
    submit.click(_chat, inputs=chat_inputs, outputs=[chatbot, msg])
    msg.submit(_chat, inputs=chat_inputs, outputs=[chatbot, msg])
    clear.click(lambda: None, None, chatbot, queue=False)
191
 
192
if __name__ == "__main__":
    # Launch with a public share link.
    demo.launch(share=True)