# DrawNGuess / app.py
# Author: LIU, Zichen
# Commit: "Initial Commit" (0e84795), 1.57 kB
import gradio as gr
import random
import torch
import numpy as np
from PIL import Image, ImageOps
from fastapi import FastAPI, Request
from MagicQuill import folder_paths
from MagicQuill.llava_new import LLaVAModel
from huggingface_hub import snapshot_download
# Download the MagicQuill model repository into ./models at startup.
# NOTE(review): presumably LLaVAModel() loads its weights from this
# directory, which is why the download must happen first — confirm.
snapshot_download(
    local_dir="models",
    repo_type="model",
    repo_id="LiuZichen/MagicQuill-models",
)
# Single module-level model instance; guess() calls its .process() method.
llavaModel = LLaVAModel()
def numpy_to_tensor(numpy_array):
    """Convert a NumPy image array into a float tensor with a leading batch axis.

    The array is wrapped with ``torch.from_numpy``, cast to float32, given a
    new axis at position 0, and divided by 255. For 8-bit pixel data
    (presumably what the Gradio image inputs supply) this maps values into
    the range [0, 1].
    """
    as_float = torch.from_numpy(numpy_array).float()
    # [None] prepends a size-1 dimension, same as unsqueeze(0).
    return as_float[None] / 255.0
def guess(original_image_tensor, add_color_image_tensor, add_edge_mask):
    """Ask the LLaVA model what the user's drawing depicts.

    Despite the ``_tensor`` suffix, the arguments arrive as NumPy arrays from
    the Gradio image inputs; each is converted with ``numpy_to_tensor`` before
    being passed to ``llavaModel.process``.

    Args:
        original_image_tensor: array from the "Original Image" input.
        add_color_image_tensor: array from the "Colored Image" input.
        add_edge_mask: array from the "Edge Mask" input (``image_mode="L"``).

    Returns:
        The model's non-empty answers (up to two) joined with ", ", or an
        empty string when neither answer is non-empty.

    Raises:
        gr.Error: if any of the three images was not provided.
    """
    images = (original_image_tensor, add_color_image_tensor, add_edge_mask)
    if any(image is None for image in images):
        # Gradio passes None for an empty image input; report it to the user
        # instead of failing with a TypeError inside torch.from_numpy.
        raise gr.Error(
            "Please provide the original image, the colored image and the edge mask."
        )

    # The first value returned by process() (a description) is not used here.
    _description, ans1, ans2 = llavaModel.process(
        numpy_to_tensor(original_image_tensor),
        numpy_to_tensor(add_color_image_tensor),
        numpy_to_tensor(add_edge_mask),
    )
    # Keep only answers that are neither None nor the empty string.
    return ", ".join(ans for ans in (ans1, ans2) if ans)
# Simplified Gradio interface, following the official example format.
demo = gr.Interface(
    fn=guess,
    inputs=[
        gr.Image(label="Original Image"),
        gr.Image(label="Colored Image"),
        gr.Image(image_mode="L", label="Edge Mask"),
    ],
    outputs=gr.Textbox(label="Prediction Output"),
)
# Enable the request queue (at most 40 pending events, status pushed every
# 0.1 s), then start the server with at most 4 worker threads.
demo.queue(max_size=40, status_update_rate=0.1).launch(max_threads=4)