Geek7 committed on
Commit
4d94881
·
verified ·
1 Parent(s): 045f54c

Update app.py

Files changed (1)
  1. app.py +34 -29
app.py CHANGED
@@ -14,24 +14,24 @@ import torch
 import cv2
 from gradio_client import Client, file
 
+# Function to generate an image using another model
 def image_gen(prompt):
     client = Client("KingNish/Image-Gen-Pro")
-    return client.predict("Image Generation",None, prompt, api_name="/image_gen_pro")
+    return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")
 
+# Load the processor and model for image-based QnA (LLaVA model)
 model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
-
 processor = LlavaProcessor.from_pretrained(model_id)
-
 model = LlavaForConditionalGeneration.from_pretrained(model_id)
 model.to("cpu")
 
-
+# Function to process images with text input
 def llava(message, history):
     if message["files"]:
         image = message["files"][0]
     else:
         for hist in history:
-            if type(hist[0])==tuple:
+            if type(hist[0]) == tuple:
                 image = hist[0][0]
 
     txt = message["text"]
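Review note: in `llava()`, if the user sends text only and no tuple entry exists in `history`, `image` is never assigned and the later `processor(prompt, image, ...)` call raises `NameError`. A minimal defensive sketch; the guard and the `gr.Error` message are illustrative additions, not part of this commit:

```python
import gradio as gr

# Sketch of a defensive variant of the history scan in llava().
# Assumes Gradio's tuple-style history [(user, assistant), ...].
image = None
if message["files"]:
    image = message["files"][0]
else:
    for hist in history:
        if isinstance(hist[0], tuple):  # idiomatic form of type(...) == tuple
            image = hist[0][0]          # keeps the last image found, as above
if image is None:
    raise gr.Error("Please upload an image first.")  # illustrative guard
```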
@@ -43,12 +43,14 @@ def llava(message, history):
     inputs = processor(prompt, image, return_tensors="pt")
     return inputs
 
+# Helper function to extract text from a webpage
 def extract_text_from_webpage(html_content):
     soup = BeautifulSoup(html_content, 'html.parser')
     for tag in soup(["script", "style", "header", "footer"]):
         tag.extract()
     return soup.get_text(strip=True)
 
+# Function to search the web using Google
 def search(query):
     term = query
     start = 0
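Review note: `soup.get_text(strip=True)` joins text nodes with no separator, so words from adjacent tags run together in the search snippets. BeautifulSoup's `get_text()` accepts a separator argument; a variant under the assumption that word boundaries matter here:

```python
from bs4 import BeautifulSoup

def extract_text_from_webpage(html_content):
    # Same helper as above, but the " " separator keeps words from
    # adjacent tags apart (assumption: run-together words are unwanted).
    soup = BeautifulSoup(html_content, "html.parser")
    for tag in soup(["script", "style", "header", "footer"]):
        tag.extract()
    return soup.get_text(" ", strip=True)
```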
@@ -88,8 +90,8 @@ client_yi = InferenceClient("01-ai/Yi-1.5-34B-Chat")
 # Define the main chat function
 def respond(message, history):
     func_caller = []
-
     user_prompt = message
+
     # Handle image processing
     if message["files"]:
         inputs = llava(message, history)
@@ -101,9 +103,11 @@ def respond(message, history):
 
         buffer = ""
         for new_text in streamer:
-            buffer += new_text
-            yield buffer
+            if new_text not in ["<|im_end|>", "<|endoftext|>"]:  # Ignore special tokens
+                buffer += new_text
+                yield buffer
     else:
+        # Functions metadata for invoking different models or functions
         functions_metadata = [
             {"type": "function", "function": {"name": "web_search", "description": "Search query on google", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "web search query"}}, "required": ["query"]}}},
             {"type": "function", "function": {"name": "general_query", "description": "Reply general query of USER", "parameters": {"type": "object", "properties": {"prompt": {"type": "string", "description": "A detailed prompt"}}, "required": ["prompt"]}}},
@@ -120,45 +124,41 @@ def respond(message, history):
 
         response = client_gemma.chat_completion(func_caller, max_tokens=200)
         response = str(response)
+
+        # Filtering and processing response
         try:
             response = response[int(response.find("{")):int(response.rindex("</"))]
         except:
             response = response[int(response.find("{")):(int(response.rfind("}"))+1)]
-        response = response.replace("\\n", "")
-        response = response.replace("\\'", "'")
-        response = response.replace('\\"', '"')
-        response = response.replace('\\', '')
+        response = response.replace("\\n", "").replace("\\'", "'").replace('\\"', '"').replace('\\', '')
         print(f"\n{response}")
 
         try:
             json_data = json.loads(str(response))
             if json_data["name"] == "web_search":
                 query = json_data["arguments"]["query"]
-                # gr.Info("Searching Web")
                 web_results = search(query)
-                # gr.Info("Extracting relevant Info")
                 web2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results])
                 messages = f"<|im_start|>system\n Hi 👋, I am Nora,mini a helpful assistant.Ask me! I will do my best!! <|im_end|>"
                 for msg in history:
                     messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
                     messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
-                messages+=f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>web_result\n{web2}<|im_end|>\n<|im_start|>assistant\n"
+                messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>web_result\n{web2}<|im_end|>\n<|im_start|>assistant\n"
                 stream = client_mixtral.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
                 output = ""
                 for response in stream:
-                    if not response.token.text == "<|im_end|>":
+                    if not response.token.text in ["<|im_end|>", "<|endoftext|>"]:  # Exclude special tokens
                         output += response.token.text
                         yield output
             elif json_data["name"] == "image_generation":
                 query = json_data["arguments"]["query"]
-                gr.Info("Generating Image, Please wait 10 sec...")
                 yield "Generating Image, Please wait 10 sec..."
                 try:
                     image = image_gen(f"{str(query)}")
                     yield gr.Image(image[1])
                 except:
                     client_sd3 = InferenceClient("stabilityai/stable-diffusion-3-medium-diffusers")
-                    seed = random.randint(0,999999)
+                    seed = random.randint(0, 999999)
                     image = client_sd3.text_to_image(query, negative_prompt=f"{seed}")
                     yield gr.Image(image)
             elif json_data["name"] == "image_qna":
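Review note: the slice-and-`replace` extraction above stays fragile when the model's reply contains nested braces or escaped quotes; consolidating the `replace` chain does not change that. `json.JSONDecoder.raw_decode` can pull the first well-formed JSON object out of surrounding text; a sketch, where the helper name is hypothetical. Separately, the stable-diffusion fallback still passes the random seed as `negative_prompt`, i.e. as text for the model to avoid, which looks unintended.

```python
import json

def extract_first_json(text):
    # Hypothetical helper: decode the first JSON object embedded in text.
    # raw_decode() parses from a given index and tolerates trailing junk.
    decoder = json.JSONDecoder()
    for start, char in enumerate(text):
        if char == "{":
            try:
                obj, _ = decoder.raw_decode(text, start)
                return obj
            except json.JSONDecodeError:
                continue
    return None
```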
@@ -168,33 +168,35 @@ def respond(message, history):
 
             thread = Thread(target=model.generate, kwargs=generation_kwargs)
             thread.start()
-
+
             buffer = ""
             for new_text in streamer:
-                buffer += new_text
-                yield buffer
+                if new_text not in ["<|im_end|>", "<|endoftext|>"]:  # Ignore special tokens
+                    buffer += new_text
+                    yield buffer
         else:
             messages = f"<|im_start|>system\n 👋, I am Nora,mini a helpful assistant.Ask me! I will do my best!!<|im_end|>"
             for msg in history:
                 messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
                 messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
-            messages+=f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
+            messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
            stream = client_yi.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
             output = ""
             for response in stream:
-                if not response.token.text == "<|endoftext|>":
+                if response.token.text not in ["<|im_end|>", "<|endoftext|>"]:  # Ignore special tokens
                     output += response.token.text
                     yield output
     except:
-        messages = f"<|start_header_id|>system\nHi 👋, I am Nora,mini a helpful assistant.Ask me! I will do my best!!<|end_header_id|>"
+        # Handle the case where JSON parsing or function calling fails
+        messages = f"<|im_start|>system\nHi 👋, I am Nora,mini a helpful assistant.Ask me! I will do my best!!<|im_end|>"
         for msg in history:
-            messages += f"\n<|start_header_id|>user\n{str(msg[0])}<|end_header_id|>"
-            messages += f"\n<|start_header_id|>assistant\n{str(msg[1])}<|end_header_id|>"
-        messages+=f"\n<|start_header_id|>user\n{message_text}<|end_header_id|>\n<|start_header_id|>assistant\n"
+            messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
+            messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
+        messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
         stream = client_llama.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
         output = ""
         for response in stream:
-            if not response.token.text == "<|eot_id|>":
+            if response.token.text not in ["<|eot_id|>", "<|im_end|>"]:  # Ignore special tokens
                 output += response.token.text
                 yield output
 
@@ -205,6 +207,9 @@ demo = gr.ChatInterface(
     textbox=gr.MultimodalTextbox(),
     multimodal=True,
     concurrency_limit=200,
-    cache_examples=False,css="footer{display:none !important}"
+    cache_examples=False,
+    css="footer{display:none !important}"
 )
+
+# Launch the Gradio app
 demo.launch()
 