Update app.py
app.py (CHANGED)
@@ -14,24 +14,24 @@ import torch
 import cv2
 from gradio_client import Client, file
 
+# Function to generate an image using another model
 def image_gen(prompt):
     client = Client("KingNish/Image-Gen-Pro")
-    return client.predict("Image Generation",None, prompt, api_name="/image_gen_pro")
+    return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")
 
+# Load the processor and model for image-based QnA (LLaVA model)
 model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
-
 processor = LlavaProcessor.from_pretrained(model_id)
-
 model = LlavaForConditionalGeneration.from_pretrained(model_id)
 model.to("cpu")
 
-
+# Function to process images with text input
 def llava(message, history):
     if message["files"]:
         image = message["files"][0]
     else:
         for hist in history:
-            if type(hist[0])==tuple:
+            if type(hist[0]) == tuple:
                 image = hist[0][0]
 
     txt = message["text"]
@@ -43,12 +43,14 @@ def llava(message, history):
     inputs = processor(prompt, image, return_tensors="pt")
     return inputs
 
+# Helper function to extract text from a webpage
 def extract_text_from_webpage(html_content):
     soup = BeautifulSoup(html_content, 'html.parser')
     for tag in soup(["script", "style", "header", "footer"]):
         tag.extract()
     return soup.get_text(strip=True)
 
+# Function to search the web using Google
 def search(query):
     term = query
     start = 0
@@ -88,8 +90,8 @@ client_yi = InferenceClient("01-ai/Yi-1.5-34B-Chat")
 # Define the main chat function
 def respond(message, history):
     func_caller = []
-
     user_prompt = message
+
     # Handle image processing
     if message["files"]:
         inputs = llava(message, history)
@@ -101,9 +103,11 @@ def respond(message, history):
 
         buffer = ""
         for new_text in streamer:
-            buffer += new_text
-            yield buffer
+            if new_text not in ["<|im_end|>", "<|endoftext|>"]: # Ignore special tokens
+                buffer += new_text
+                yield buffer
     else:
+        # Functions metadata for invoking different models or functions
         functions_metadata = [
             {"type": "function", "function": {"name": "web_search", "description": "Search query on google", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "web search query"}}, "required": ["query"]}}},
             {"type": "function", "function": {"name": "general_query", "description": "Reply general query of USER", "parameters": {"type": "object", "properties": {"prompt": {"type": "string", "description": "A detailed prompt"}}, "required": ["prompt"]}}},
@@ -120,45 +124,41 @@ def respond(message, history):
 
         response = client_gemma.chat_completion(func_caller, max_tokens=200)
         response = str(response)
+
+        # Filtering and processing response
         try:
             response = response[int(response.find("{")):int(response.rindex("</"))]
         except:
             response = response[int(response.find("{")):(int(response.rfind("}"))+1)]
-        response = response.replace("\\n", "")
-        response = response.replace("\\'", "'")
-        response = response.replace('\\"', '"')
-        response = response.replace('\\', '')
+        response = response.replace("\\n", "").replace("\\'", "'").replace('\\"', '"').replace('\\', '')
         print(f"\n{response}")
 
         try:
             json_data = json.loads(str(response))
             if json_data["name"] == "web_search":
                 query = json_data["arguments"]["query"]
-                # gr.Info("Searching Web")
                 web_results = search(query)
-                # gr.Info("Extracting relevant Info")
                 web2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results])
                 messages = f"<|im_start|>system\n Hi π, I am Nora,mini a helpful assistant.Ask me! I will do my best!! <|im_end|>"
                 for msg in history:
                     messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
                     messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
-                messages+=f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>web_result\n{web2}<|im_end|>\n<|im_start|>assistant\n"
+                messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>web_result\n{web2}<|im_end|>\n<|im_start|>assistant\n"
                 stream = client_mixtral.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
                 output = ""
                 for response in stream:
-                    if not response.token.text
+                    if not response.token.text in ["<|im_end|>", "<|endoftext|>"]: # Exclude special tokens
                         output += response.token.text
                         yield output
             elif json_data["name"] == "image_generation":
                 query = json_data["arguments"]["query"]
-                gr.Info("Generating Image, Please wait 10 sec...")
                 yield "Generating Image, Please wait 10 sec..."
                 try:
                     image = image_gen(f"{str(query)}")
                     yield gr.Image(image[1])
                 except:
                     client_sd3 = InferenceClient("stabilityai/stable-diffusion-3-medium-diffusers")
-                    seed = random.randint(0,999999)
+                    seed = random.randint(0, 999999)
                     image = client_sd3.text_to_image(query, negative_prompt=f"{seed}")
                     yield gr.Image(image)
             elif json_data["name"] == "image_qna":
@@ -168,33 +168,35 @@ def respond(message, history):
 
                 thread = Thread(target=model.generate, kwargs=generation_kwargs)
                 thread.start()
-
+
                 buffer = ""
                 for new_text in streamer:
-                    buffer += new_text
-                    yield buffer
+                    if new_text not in ["<|im_end|>", "<|endoftext|>"]: # Ignore special tokens
+                        buffer += new_text
+                        yield buffer
             else:
                 messages = f"<|im_start|>system\n π, I am Nora,mini a helpful assistant.Ask me! I will do my best!!<|im_end|>"
                 for msg in history:
                     messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
                     messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
-                messages+=f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
+                messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
                 stream = client_yi.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
                 output = ""
                 for response in stream:
-                    if
+                    if response.token.text not in ["<|im_end|>", "<|endoftext|>"]: # Ignore special tokens
                         output += response.token.text
                         yield output
         except:
-
+            # Handle the case where JSON parsing or function calling fails
+            messages = f"<|im_start|>system\nHi π, I am Nora,mini a helpful assistant.Ask me! I will do my best!!<|im_end|>"
             for msg in history:
-                messages += f"\n<|
-                messages += f"\n<|
-            messages+=f"\n<|
+                messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
+                messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
+            messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
            stream = client_llama.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
            output = ""
            for response in stream:
-                if
+                if response.token.text not in ["<|eot_id|>", "<|im_end|>"]: # Ignore special tokens
                    output += response.token.text
                    yield output
 
@@ -205,6 +207,9 @@ demo = gr.ChatInterface(
     textbox=gr.MultimodalTextbox(),
     multimodal=True,
     concurrency_limit=200,
-    cache_examples=False,
+    cache_examples=False,
+    css="footer{display:none !important}"
 )
+
+# Launch the Gradio app
 demo.launch()
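
The one pattern this commit applies in every streaming loop is worth isolating: accumulate decoded tokens and skip special tokens before yielding the running text to the UI. Below is a minimal standalone sketch of that pattern; the model name and the merged SPECIAL_TOKENS set are illustrative assumptions, not values taken from app.py, which uses separate per-model token lists.

# Sketch: filter special tokens out of a huggingface_hub text-generation stream.
# With stream=True and details=True, InferenceClient.text_generation yields
# chunks whose .token.text is the decoded token, so membership checks are cheap.
from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")  # illustrative model choice

SPECIAL_TOKENS = {"<|im_end|>", "<|endoftext|>", "<|eot_id|>"}  # assumed union of the lists above

def stream_reply(prompt):
    output = ""
    stream = client.text_generation(prompt, max_new_tokens=2000, do_sample=True,
                                    stream=True, details=True, return_full_text=False)
    for chunk in stream:
        if chunk.token.text not in SPECIAL_TOKENS:  # same guard the commit adds in each loop
            output += chunk.token.text
            yield output

This keeps chat-template stop markers such as <|im_end|> from leaking into the rendered chat, at the cost of assuming each special token arrives as a single decoded chunk.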
|