|
import gradio as gr
from huggingface_hub import InferenceClient
import json
import warnings
from PIL import Image
from bs4 import BeautifulSoup
import requests
import random
from transformers import LlavaProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
from threading import Thread
from gradio_client import Client

warnings.filterwarnings("ignore", category=UserWarning, module='gradio')
|
|
|
|
|
def image_gen(prompt):
    """Generate an image via the KingNish/Image-Gen-Pro Gradio Space."""
    client = Client("KingNish/Image-Gen-Pro")
    return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")
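# Usage sketch (assumption: the Space replies with a (status, image_path)
# tuple, which is why callers below read result[1] for the image):
#   result = image_gen("a watercolor fox riding a bicycle")
#   image_path = result[1]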
|
|
|
|
|
# Small multimodal model used locally for image Q&A (kept on CPU).
model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id)
model.to("cpu")
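# Optional sketch, assuming a CUDA-capable GPU is available (the line above
# pins the model to CPU because this script targets CPU-only Spaces):
#   import torch
#   if torch.cuda.is_available():
#       model.to("cuda")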
|
|
|
|
|
def llava(message, history):
    """Build LLaVA inputs from the latest image in the message or chat history."""
    image = None
    if message["files"]:
        image = message["files"][0]
    else:
        # Fall back to the most recent image uploaded earlier in the chat.
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
    if image is None:
        raise gr.Error("Please upload an image first.")

    txt = message["text"]

    gr.Info("Analyzing image")
    image = Image.open(image).convert("RGB")
    prompt = f"<|im_start|>user <image>\n{txt}<|im_end|><|im_start|>assistant"

    inputs = processor(prompt, image, return_tensors="pt")
    return inputs
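# Non-streaming sketch of how these inputs are consumed (the chat handler
# below streams tokens instead). Note that generate() returns the prompt ids
# followed by the answer ids:
#   output_ids = model.generate(**llava(message, history), max_new_tokens=256)
#   print(processor.decode(output_ids[0], skip_special_tokens=True))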
|
|
|
|
|
def extract_text_from_webpage(html_content):
    """Return the visible text of a page, dropping scripts, styles and page chrome."""
    soup = BeautifulSoup(html_content, 'html.parser')
    for tag in soup(["script", "style", "header", "footer"]):
        tag.extract()
    return soup.get_text(strip=True)
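# Example:
#   extract_text_from_webpage("<p>Hi</p><script>x()</script>")  # -> "Hi"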
|
|
|
|
|
def search(query):
    """Fetch the top Google results for `query` and scrape their visible text."""
    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"}
    all_results = []
    max_chars_per_page = 8000
    with requests.Session() as session:
        resp = session.get(
            url="https://www.google.com/search",
            headers=headers,
            params={"q": query, "num": 3, "udm": 14},
            timeout=5,
            verify=False,
        )
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "html.parser")
        result_block = soup.find_all("div", attrs={"class": "g"})
        for result in result_block:
            link = result.find("a", href=True)
            if link is None:
                continue
            link = link["href"]
            try:
                webpage = session.get(link, headers=headers, timeout=5, verify=False)
                webpage.raise_for_status()
                visible_text = extract_text_from_webpage(webpage.text)
                if len(visible_text) > max_chars_per_page:
                    visible_text = visible_text[:max_chars_per_page]
                all_results.append({"link": link, "text": visible_text})
            except requests.exceptions.RequestException:
                all_results.append({"link": link, "text": None})
    return all_results
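# Each entry looks like {"link": "https://...", "text": "visible text or None"};
# text is None when the page could not be fetched within the timeout.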
|
|
|
|
|
# Remote inference endpoints (a cached Hugging Face token is used if present).
# Note: despite its name, client_gemma actually serves Mistral-7B-Instruct.
client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
client_mixtral = InferenceClient("NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO")
client_llama = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
client_yi = InferenceClient("01-ai/Yi-1.5-34B-Chat")
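# These clients pick up a cached Hugging Face token (or HF_TOKEN) automatically;
# pass token="hf_..." to InferenceClient to override it. Assumption: the gated
# meta-llama endpoint requires an account that has accepted its license.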
|
|
|
|
|
def respond(message, history):
    func_caller = []
|
|
|
|
|
    # An attached file means an image question: answer it directly with LLaVA.
    if message["files"]:
        inputs = llava(message, history)
        streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

        # Run generation in a background thread so tokens stream to the UI.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        buffer = ""
        for new_text in streamer:
            if new_text not in ["<|im_end|>", "<|endoftext|>"]:
                buffer += new_text
                yield buffer
|
    else:
        # No file attached: ask a router model which tool to call.
        functions_metadata = [
            {"type": "function", "function": {"name": "web_search", "description": "Search a query on Google", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "web search query"}}, "required": ["query"]}}},
            {"type": "function", "function": {"name": "general_query", "description": "Answer a general user query", "parameters": {"type": "object", "properties": {"prompt": {"type": "string", "description": "A detailed prompt"}}, "required": ["prompt"]}}},
            {"type": "function", "function": {"name": "image_generation", "description": "Generate an image for the user", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "image generation prompt"}}, "required": ["query"]}}},
            {"type": "function", "function": {"name": "image_qna", "description": "Answer a user question about the uploaded image", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Question by user"}}, "required": ["query"]}}},
        ]

        for msg in history:
            func_caller.append({"role": "user", "content": f"{str(msg[0])}"})
            func_caller.append({"role": "assistant", "content": f"{str(msg[1])}"})

        message_text = message["text"]
        func_caller.append({"role": "user", "content": f'[SYSTEM]You are a helpful assistant. You have access to the following functions: \n {str(functions_metadata)}\n\nTo use these functions respond with:\n<functioncall> {{ "name": "function_name", "arguments": {{ "arg_1": "value_1", "arg_2": "value_2", ... }} }} </functioncall> [USER] {message_text}'})
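        # The router is expected to answer with something like:
        #   <functioncall> {"name": "web_search", "arguments": {"query": "..."}} </functioncall>
        # The code below trims that reply down to its JSON payload.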
|
|
|
        response = client_gemma.chat_completion(func_caller, max_tokens=200)
        response = str(response)

        # Extract the JSON payload between the first "{" and the closing tag.
        try:
            response = response[response.find("{"):response.rindex("</")]
        except ValueError:
            response = response[response.find("{"):response.rfind("}") + 1]
        response = response.replace("\\n", "").replace("\\'", "'").replace('\\"', '"').replace('\\', '')
        print(f"\n{response}")
|
|
|
        try:
            json_data = json.loads(str(response))
            if json_data["name"] == "web_search":
                query = json_data["arguments"]["query"]
                web_results = search(query)
                web2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results])
                messages = f"<|im_start|>system\nHi, I am Nora-mini, a helpful assistant. Ask me anything! I will do my best!<|im_end|>"
                for msg in history:
                    messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
                    messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
                messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>web_result\n{web2}<|im_end|>\n<|im_start|>assistant\n"
                stream = client_mixtral.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
                output = ""
                for response in stream:
                    if response.token.text not in ["<|im_end|>", "<|endoftext|>"]:
                        output += response.token.text
                        yield output
|
            elif json_data["name"] == "image_generation":
                query = json_data["arguments"]["query"]
                yield "Generating an image, please wait about 10 seconds..."
                try:
                    image = image_gen(f"{str(query)}")
                    yield gr.Image(image[1])
                except Exception:
                    # Fallback: call Stable Diffusion 3 directly. The random seed is
                    # sent as the negative prompt so repeated identical prompts do
                    # not hit a cached result.
                    client_sd3 = InferenceClient("stabilityai/stable-diffusion-3-medium-diffusers")
                    seed = random.randint(0, 999999)
                    image = client_sd3.text_to_image(query, negative_prompt=f"{seed}")
                    yield gr.Image(image)
|
            elif json_data["name"] == "image_qna":
                inputs = llava(message, history)
                streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
                generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)

                thread = Thread(target=model.generate, kwargs=generation_kwargs)
                thread.start()

                buffer = ""
                for new_text in streamer:
                    if new_text not in ["<|im_end|>", "<|endoftext|>"]:
                        buffer += new_text
                        yield buffer
|
            else:
                messages = f"<|im_start|>system\nHi, I am Nora-mini, a helpful assistant. Ask me anything! I will do my best!<|im_end|>"
                for msg in history:
                    messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
                    messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
                messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
                stream = client_yi.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
                output = ""
                for response in stream:
                    if response.token.text not in ["<|im_end|>", "<|endoftext|>"]:
                        output += response.token.text
                        yield output
|
        except Exception:
            # Router reply was not valid JSON: fall back to a plain chat answer.
            messages = f"<|im_start|>system\nHi, I am Nora-mini, a helpful assistant. Ask me anything! I will do my best!<|im_end|>"
            for msg in history:
                messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
                messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
            messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
            stream = client_llama.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
            output = ""
            for response in stream:
                if response.token.text not in ["<|eot_id|>", "<|im_end|>"]:
                    output += response.token.text
                    yield output
|
|
|
|
|
# Multimodal chat UI: plain text plus optional image uploads.
demo = gr.ChatInterface(
    fn=respond,
    chatbot=gr.Chatbot(layout="panel"),
    textbox=gr.MultimodalTextbox(),
    multimodal=True,
    concurrency_limit=200,
    cache_examples=False,
    css="footer{display:none !important}",
)
|
|
|
|
|
demo.launch() |