# Hugging Face Spaces app — multi-model assistant (text chat, image
# generation, image Q&A).  The "Spaces: / Sleeping" lines were page chrome
# captured by the scrape, not source code.
import json
import uuid

import gradio as gr
import requests
from bs4 import BeautifulSoup
from huggingface_hub import InferenceClient
from PIL import Image
# ===================== Core logic modules =====================
# Shared inference clients, one per backing chat model.  Created once at
# import time and reused by check_service_status() / the chat handlers.
# NOTE(review): variable names and models disagree — `client_gemma` actually
# points at a Mistral model; confirm which model each slot is meant to hold.
client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
client_mixtral = InferenceClient("NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO")
client_llama = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
client_yi = InferenceClient("01-ai/Yi-1.5-34B-Chat")
# ---------- 服务状态检查模块 ---------- | |
def check_service_status():
    """
    Probe every backing model service and report availability.

    Returns:
        dict: service display name -> bool (True when the service answered).
    """
    clients = {
        "Gemma": client_gemma,
        "Mixtral": client_mixtral,
        "Llama": client_llama,
        "Yi": client_yi,
    }
    return {name: check_inference_client(client) for name, client in clients.items()}
def check_inference_client(client):
    """
    Send a minimal request to verify that a model client is usable.

    Args:
        client: an ``InferenceClient``-like object exposing ``text_generation``.

    Returns:
        bool: True when the probe succeeds and yields a non-empty result,
        False on any failure (network, auth, cold/loading model, ...).
    """
    try:
        # BUG FIX: huggingface_hub.InferenceClient has no ``predict`` method
        # (that is the gradio_client API), so the old probe always raised
        # AttributeError and every service reported unavailable.  A 1-token
        # text_generation call is the cheapest supported liveness probe.
        response = client.text_generation("ping", max_new_tokens=1)
        return bool(response)
    except Exception:
        # Best-effort status check: any error simply means "unavailable".
        return False
def get_service_status_markdown():
    """
    Render the service availability map as display-ready Markdown text.

    Returns:
        str: one "name: <emoji> label" line per service.
    """
    lines = []
    for service, available in check_service_status().items():
        marker = '🟢 可用' if available else '🔴 不可用'
        lines.append(f"{service}: {marker}")
    # Return a plain string (not a gr.Markdown object) so callers can embed
    # it directly in any Gradio component.
    return "\n".join(lines)
# ---------- 图像生成模块 ---------- | |
def image_gen(prompt):
    """
    Generate an image from a text prompt.

    Args:
        prompt (str): description of the image to generate.

    Returns:
        PIL.Image.Image: the generated image (gr.Image accepts PIL images
        directly as output values).
    """
    # BUG FIX: huggingface_hub.InferenceClient has no ``predict``/``api_name``
    # interface (that is gradio_client.Client), and the old code then called
    # ``.get("image")`` on a non-dict.  ``text_to_image`` is the supported
    # call and returns a PIL image directly.
    client = InferenceClient("KingNish/Image-Gen-Pro")
    return client.text_to_image(prompt)
# ---------- 文本和图像问答模块 ---------- | |
def process_llava_input(message, history, processor):
    """
    Build LLaVA model inputs from a chat message and its history.

    Args:
        message (dict): multimodal message with "text" and "files" keys.
        history (list): prior turns; an image turn stores a tuple in slot 0.
        processor: LLaVA processor mapping (prompt, image) to tensors.

    Returns:
        Tensor inputs ready for ``model.generate`` (``return_tensors="pt"``).

    Raises:
        ValueError: if no image is present in the message or the history.
    """
    image = None
    if message["files"]:
        # Prefer a freshly uploaded file over anything in the history.
        image = message["files"][0]
    else:
        # Fall back to the most recent image found in the history
        # (same result as the original forward overwrite-scan).
        for hist in reversed(history):
            if isinstance(hist[0], tuple):
                image = hist[0][0]
                break
    if image is None:
        # BUG FIX: previously fell through to Image.open(None), raising a
        # cryptic TypeError; fail with an explicit, actionable error instead.
        raise ValueError("no image found in message or history")
    txt = message["text"]
    image = Image.open(image).convert("RGB")
    prompt = f"<|im_start|>user <image>\n{txt}<|im_end|><|im_start|>assistant"
    inputs = processor(prompt, image, return_tensors="pt")
    return inputs
def llava_answer(inputs, model):
    """
    Run the LLaVA model on prepared inputs and extract the answer text.

    Args:
        inputs: model-ready inputs (e.g. from ``process_llava_input``).
        model: object exposing ``generate(**inputs)``.

    Returns:
        str: the generated answer text.
    """
    generated = model.generate(**inputs)
    # NOTE(review): assumes a pipeline-style return shape
    # [{"generated_text": ...}]; raw transformers generate() returns token
    # ids instead — confirm against the actual model wrapper used.
    answer = generated[0]["generated_text"]
    return answer
# ---------- 网络搜索模块 ---------- | |
def search(query):
    """
    Run a Google web search and scrape result titles and links.

    Args:
        query (str): the search terms.

    Returns:
        list[tuple[str, str]]: (title, url) pairs for the top results.
    """
    results = []
    params = {"q": query, "num": 3}
    with requests.Session() as session:
        page = session.get("https://www.google.com/search", params=params)
        soup = BeautifulSoup(page.text, "html.parser")
        # Each organic result sits in a div.g container; keep only entries
        # that expose both a heading and a link.
        for container in soup.select('div.g'):
            heading = container.select_one("h3")
            anchor = container.select_one("a")
            if heading and anchor:
                results.append((heading.get_text(), anchor["href"]))
    return results
# ---------- 回答生成模块 ---------- | |
def respond(message, history, client):
    """
    Generate a chat reply for ``message`` using the given model client.

    Args:
        message (str): the user's input text.
        history: prior conversation turns (accepted for interface
            compatibility; not currently used).
        client: ``huggingface_hub.InferenceClient`` for the chosen model.

    Returns:
        str: the generated reply text.
    """
    # BUG FIX: InferenceClient exposes text_generation(), not predict()
    # (predict belongs to gradio_client); the old call raised AttributeError
    # on every turn and then .get() on a non-dict.
    return client.text_generation(message, max_new_tokens=512)
# ===================== Gradio 界面构建 ===================== | |
def build_interface():
    """
    Build the Gradio UI: a service-status banner plus three tabs
    (text chat, image generation, image Q&A).

    Returns:
        gr.Blocks: the assembled, launchable demo.
    """
    with gr.Blocks() as demo:
        # Service status banner — computed once at build time, not live.
        gr.Markdown("# 服务状态")
        gr.Markdown(get_service_status_markdown())  # plain string, not gr.Markdown
        # Main multimodal tabs.
        with gr.Tab("文本聊天"):
            chat_textbox = gr.Textbox(label="输入你的问题", placeholder="输入文本...")
            chat_output = gr.Chatbot()
            chat_button = gr.Button("发送")
        with gr.Tab("图像生成"):
            image_prompt = gr.Textbox(label="图像提示词", placeholder="输入描述来生成图像")
            image_output = gr.Image()
            image_button = gr.Button("生成图像")
        with gr.Tab("图像问答"):
            image_upload = gr.Image(label="上传图像")
            image_question = gr.Textbox(label="提问", placeholder="输入关于图像的问题")
            answer_output = gr.Textbox(label="回答")
            answer_button = gr.Button("回答")
        # Button click wiring.
        # NOTE(review): respond() produces a single string but outputs=chat_output
        # is a Chatbot, which expects a list of (user, bot) message pairs —
        # the handler needs to append to the history and return the list.
        chat_button.click(lambda msg, hist: respond(msg, hist, client_gemma), inputs=[chat_textbox, chat_output], outputs=chat_output)
        image_button.click(image_gen, inputs=image_prompt, outputs=image_output)
        # NOTE(review): `processor` is never defined at module level (NameError
        # at click time), and llava_answer() requires a `model` argument that is
        # not passed here — this handler cannot work as written; the LLaVA
        # model/processor must be loaded and threaded through.
        answer_button.click(lambda msg, hist: llava_answer(process_llava_input(msg, hist, processor)), inputs=[image_upload, image_question], outputs=answer_output)
        gr.Markdown("### 说明")
        gr.Markdown("该助手支持文本聊天、图像生成和图像问答等功能。根据不同需求选择对应的选项卡使用。")
    return demo
# Script entry point: build the UI and start the Gradio server.
# The module-level `demo` name is also what the Spaces runtime looks for.
if __name__ == "__main__":
    demo = build_interface()
    demo.launch()