Fooocus_app / app.py
uruguayai's picture
Update app.py
6cbe8da verified
raw
history blame
2.52 kB
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from PIL import Image
import numpy as np
from torchvision import transforms
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Initialize the inpainting model
# Loading is best-effort: on any failure the error is printed and
# `inpaint_model` stays None so the rest of the module can still load.
inpaint_model = None
try:
    inpaint_model = DiffusionPipeline.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
    ).to(device)
except Exception as load_error:
    print(f"Error initializing model: {load_error}")
# Load a test image
def load_test_image(path="/absolute/path/to/your/test_image.png"):
    """Load the fallback test image from disk.

    Args:
        path: Location of the image file. The default is a placeholder
            absolute path; point it at a real file for the fallback to
            succeed.

    Returns:
        The decoded PIL image, or ``None`` if the file could not be opened
        or decoded (the error is printed).
    """
    try:
        # Image.open() is lazy: it only reads the header. Decode the pixel
        # data here so a corrupt or truncated file is reported by this
        # function instead of failing later in the caller, and let the
        # context manager release the underlying file handle.
        with Image.open(path) as test_image:
            test_image.load()
            return test_image
    except Exception as e:
        print(f"Error loading test image: {e}")
        return None
def process_image(prompt, image, style, upscale_factor, inpaint):
    """Normalize the incoming image to a PIL image and hand it back.

    Only ``image`` is used; ``prompt``, ``style``, ``upscale_factor`` and
    ``inpaint`` are accepted but not yet read (the processing step is a
    placeholder).

    Args:
        prompt: Text prompt (currently unused).
        image: A PIL image, NumPy array, torch tensor, or ``None``. When
            ``None``, the test image from ``load_test_image`` is used.
        style: Selected style name (currently unused).
        upscale_factor: Requested upscale factor (currently unused).
        inpaint: Whether inpainting is enabled (currently unused).

    Returns:
        A ``(image, error)`` pair: ``(PIL image, None)`` on success, or
        ``(None, message)`` when no image is available, the type is
        unsupported, or conversion raises.
    """
    source = image if image is not None else load_test_image()
    if source is None:
        return None, "No image received and failed to load test image."

    try:
        # Convert array/tensor inputs to PIL; pass PIL images through as-is.
        if isinstance(source, np.ndarray):
            pil_image = Image.fromarray(source)
        elif isinstance(source, torch.Tensor):
            to_pil = transforms.ToPILImage()
            pil_image = to_pil(source)
        elif isinstance(source, Image.Image):
            pil_image = source
        else:
            return None, f"Unsupported image format: {type(source)}."

        print(f"Received image: {pil_image.size}")

        # Placeholder for processing logic
        return pil_image, None
    except Exception as e:
        # Report the failure both on stdout and to the caller.
        error_message = f"Error in process_image function: {e}"
        print(error_message)
        return None, error_message
# Gradio UI: input controls, the output image and an error box, wired to
# process_image through the "Generate Image" button.
# NOTE(review): the nesting of the Row/Column contexts is reconstructed;
# the original indentation was not available. Confirm the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter your prompt")
            image_input = gr.Image(
                label="Image (for inpainting)",
                type="pil",
            )
            style_input = gr.Dropdown(
                choices=["Fooocus Style", "SAI Anime"],
                label="Select Style",
            )
            upscale_input = gr.Slider(
                minimum=1,
                maximum=4,
                step=1,
                label="Upscale Factor",
            )
            inpaint_input = gr.Checkbox(label="Enable Inpainting")
            output_image = gr.Image(label="Generated Image", type="pil")
            error_output = gr.Textbox(
                label="Error Details",
                lines=4,
                placeholder="Error details will appear here",
            )

    generate_button = gr.Button("Generate Image")

    # The event listener is registered inside the Blocks context.
    generate_button.click(
        fn=process_image,
        inputs=[
            prompt_input,
            image_input,
            style_input,
            upscale_input,
            inpaint_input,
        ],
        outputs=[output_image, error_output],
    )

demo.launch()