# "Spaces: Sleeping" — Hugging Face Space status banner captured by the page
# scrape; not part of the program.
import json
import uuid

import gradio as gr
import requests
from gradio_client import Client
from huggingface_hub import InferenceClient
from PIL import Image
# ===================== Core logic module =====================
# Initialize the model-serving clients used throughout the app.
# NOTE(review): despite the variable name, `client_gemma` points at a Mistral
# model — confirm the intended model/variable pairing.
client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
client_mixtral = InferenceClient("NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO")
client_llama = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
client_yi = InferenceClient("01-ai/Yi-1.5-34B-Chat")
# ---------- Service status-check module ----------
def check_service_status():
    """
    Check the availability of each model-serving client.

    Returns:
        dict[str, bool]: mapping of service name to availability. A service
        whose check raises is reported as unavailable instead of crashing
        the whole status sweep.
    """
    clients = {
        "Gemma": client_gemma,
        "Mixtral": client_mixtral,
        "Llama": client_llama,
        "Yi": client_yi,
    }
    services = {}
    for name, client in clients.items():
        try:
            # NOTE(review): `is_available()` is assumed to exist on these
            # clients; huggingface_hub's InferenceClient exposes
            # `get_model_status()` instead — confirm against the installed
            # huggingface_hub version.
            services[name] = client.is_available()
        except Exception:
            # A network/API failure for one service must not break the
            # entire status check.
            services[name] = False
    return services
def get_service_status_markdown():
    """
    Format service availability as a Markdown string for display.

    Returns:
        str: one line per service, e.g. "Gemma: 🟢 可用".
    """
    statuses = check_service_status()
    # Return a plain string rather than a gr.Markdown component: the caller
    # (build_interface) wraps this value in gr.Markdown itself, and wrapping
    # a component inside another component is a bug. gr.Markdown accepts a
    # string value, so this stays backward compatible.
    return "\n".join(
        f"{service}: {'🟢 可用' if available else '🔴 不可用'}"
        for service, available in statuses.items()
    )
# ---------- Image generation module ----------
def image_gen(prompt):
    """
    Generate an image from a text prompt via the KingNish/Image-Gen-Pro Space.

    Args:
        prompt (str): text description of the desired image.

    Returns:
        The prediction result from the remote Space API.
    """
    # `Client` is gradio_client.Client (the original referenced this name
    # without importing it, raising NameError at call time).
    client = Client("KingNish/Image-Gen-Pro")
    return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")
# ---------- Text & image Q&A module ----------
def process_llava_input(message, history, processor):
    """
    Build LLaVA model inputs from a chat message and history.

    Args:
        message (dict): multimodal message with "text" and "files" keys.
        history (list): chat history; image turns appear as tuple entries.
        processor: LLaVA processor turning (prompt, image) into tensors.

    Returns:
        Model-ready inputs from `processor(..., return_tensors="pt")`.

    Raises:
        ValueError: if no image is supplied in the message or found in history.
    """
    image = None
    if message["files"]:
        image = message["files"][0]  # a freshly uploaded image takes precedence
    else:
        # Fall back to the most recent image in history: later entries
        # overwrite earlier ones, so the last image wins.
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
    if image is None:
        # The original crashed inside Image.open(None) with an opaque error;
        # fail early with an actionable message instead.
        raise ValueError("No image provided: upload an image or include one in the chat history.")
    txt = message["text"]
    image = Image.open(image).convert("RGB")
    prompt = f"<|im_start|>user <image>\n{txt}<|im_end|><|im_start|>assistant"
    inputs = processor(prompt, image, return_tensors="pt")
    return inputs
def llava_answer(inputs, model):
    """
    Run the LLaVA model on prepared inputs and return its answer.

    Args:
        inputs: model-ready inputs (as built by process_llava_input).
        model: object exposing `generate_answer(inputs)`.

    Returns:
        The model's generated answer.
    """
    # Delegate directly; the generation details live inside the model object.
    return model.generate_answer(inputs)
# ---------- Web search module ----------
def search(query):
    """
    Perform a best-effort web search.

    Args:
        query (str): search terms.

    Returns:
        list: (title, link) pairs; currently always empty because result
        parsing is not yet implemented (see TODO below).
    """
    search_results = []
    with requests.Session() as session:
        try:
            resp = session.get(
                "https://www.google.com/search",
                params={"q": query, "num": 3},
                # Google serves an interstitial/429 to requests' default UA;
                # a browser-like UA plus a timeout keeps this call best-effort
                # instead of hanging indefinitely (original had no timeout).
                headers={"User-Agent": "Mozilla/5.0"},
                timeout=10,
            )
            resp.raise_for_status()
        except requests.RequestException:
            # Network failure yields no results rather than crashing the app.
            return search_results
    # TODO: parse resp.text (e.g. with BeautifulSoup) into [(title, link), ...]
    return search_results
# ---------- Answer generation module ----------
def respond(message, history, client):
    """
    Generate a chat reply using the given inference client.

    Args:
        message (str): the user's message.
        history (list): prior (user, assistant) turns; currently unused.
        client: object exposing `chat_completion(messages=..., max_tokens=...)`
            (e.g. huggingface_hub.InferenceClient).

    Returns:
        str: the assistant's reply text.
    """
    # huggingface_hub.InferenceClient has no `.predict` method (that is the
    # gradio_client API) — the original raised AttributeError at call time.
    # Use the chat-completion endpoint instead.
    messages = [{"role": "user", "content": message}]
    response = client.chat_completion(messages=messages, max_tokens=512)
    return response.choices[0].message.content
# ===================== Gradio UI construction =====================
def build_interface():
    """
    Build the Gradio UI: a service-status banner plus tabs for text chat,
    image generation, and image Q&A.

    Returns:
        gr.Blocks: the assembled (un-launched) interface.
    """
    with gr.Blocks() as demo:
        # Service status banner
        gr.Markdown("# 服务状态")
        gr.Markdown(get_service_status_markdown())
        # Main multimodal tabs
        with gr.Tab("文本聊天"):
            chat_textbox = gr.Textbox(label="输入你的问题", placeholder="输入文本...")
            chat_output = gr.Chatbot()
            chat_button = gr.Button("发送")
        with gr.Tab("图像生成"):
            image_prompt = gr.Textbox(label="图像提示词", placeholder="输入描述来生成图像")
            image_output = gr.Image()
            image_button = gr.Button("生成图像")
        with gr.Tab("图像问答"):
            image_upload = gr.Image(label="上传图像")
            image_question = gr.Textbox(label="提问", placeholder="输入关于图像的问题")
            answer_output = gr.Textbox(label="回答")
            answer_button = gr.Button("回答")

        def _chat(msg, hist):
            # A Chatbot output expects the full list of (user, assistant)
            # pairs; the original returned a bare string into it.
            reply = respond(msg, hist, client_gemma)
            return (hist or []) + [(msg, reply)]

        def _image_qa(image, question):
            # The original lambda fed (image, question) straight into
            # process_llava_input's (message, history) slots and called
            # llava_answer without its required `model` argument.
            message = {"files": [image] if image is not None else [], "text": question}
            # NOTE(review): `processor` and `model` are not defined anywhere
            # in this file — they must be initialized at module level before
            # this tab can work; confirm where they are meant to come from.
            inputs = process_llava_input(message, [], processor)
            return llava_answer(inputs, model)

        # Wire up button click events
        chat_button.click(_chat, inputs=[chat_textbox, chat_output], outputs=chat_output)
        image_button.click(image_gen, inputs=image_prompt, outputs=image_output)
        answer_button.click(_image_qa, inputs=[image_upload, image_question], outputs=answer_output)
        gr.Markdown("### 说明")
        gr.Markdown("该助手支持文本聊天、图像生成和图像问答等功能。根据不同需求选择对应的选项卡使用。")
    return demo
# Launch the Gradio interface when run as a script.
if __name__ == "__main__":
    demo = build_interface()
    demo.launch()