# Fooocus_app / app.py
# Author: uruguayai
# Last commit: 6cbe8da (verified) - "Update app.py"
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from PIL import Image
import numpy as np
from torchvision import transforms
# Hugging Face Hub id of the SDXL inpainting pipeline loaded below.
_INPAINT_REPO_ID = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the inpainting pipeline once, at import time. Any failure is printed
# and ``inpaint_model`` is left as ``None`` so the rest of the app still starts.
try:
    inpaint_model = DiffusionPipeline.from_pretrained(_INPAINT_REPO_ID).to(device)
except Exception as exc:
    print(f"Error initializing model: {exc}")
    inpaint_model = None
# Load a test image
def load_test_image():
try:
# Provide the absolute path to a test image
return Image.open("/absolute/path/to/your/test_image.png")
except Exception as e:
print(f"Error loading test image: {e}")
return None
def process_image(prompt, image, style, upscale_factor, inpaint):
    """Validate the submitted image, normalise it to PIL and hand it back.

    This is still a placeholder: ``prompt``, ``style``, ``upscale_factor`` and
    ``inpaint`` are wired to the UI but not yet used, and no generation is
    performed.

    Args:
        prompt: Text prompt from the UI (currently unused).
        image: A ``PIL.Image.Image``, ``numpy.ndarray`` or ``torch.Tensor``.
            ``None`` triggers a fallback to ``load_test_image()``.
        style: Selected style name (currently unused).
        upscale_factor: Requested upscale factor (currently unused).
        inpaint: Whether inpainting was requested (currently unused).

    Returns:
        ``(pil_image, None)`` on success. ``(None, error_message)`` when no
        image is available, the input type is unsupported, or conversion raises.
    """
    if image is None:
        # Nothing was uploaded: try the fallback image before giving up.
        image = load_test_image()
        if image is None:
            return None, "No image received and failed to load test image."

    try:
        if isinstance(image, np.ndarray):
            pil_image = Image.fromarray(image)
        elif isinstance(image, torch.Tensor):
            pil_image = transforms.ToPILImage()(image)
        elif isinstance(image, Image.Image):
            pil_image = image
        else:
            return None, f"Unsupported image format: {type(image)}."

        print(f"Received image: {pil_image.size}")
        # TODO: the actual generation / inpainting / upscaling goes here.
        return pil_image, None
    except Exception as exc:
        # UI boundary: report the failure text in the error textbox instead
        # of raising out of the event handler.
        message = f"Error in process_image function: {exc}"
        print(message)
        return None, message
# Gradio front end: prompt, image and options in one column, followed by the
# generated image, an error textbox and a button that runs ``process_image``.
# NOTE(review): the original indentation was lost, so the nesting of the
# output widgets below is a reconstruction. Confirm against the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Enter your prompt")
            image_input = gr.Image(label="Image (for inpainting)", type="pil")
            style_input = gr.Dropdown(
                choices=["Fooocus Style", "SAI Anime"],
                label="Select Style",
            )
            upscale_input = gr.Slider(
                minimum=1,
                maximum=4,
                step=1,
                label="Upscale Factor",
            )
            inpaint_input = gr.Checkbox(label="Enable Inpainting")
            output_image = gr.Image(label="Generated Image", type="pil")
            error_output = gr.Textbox(
                label="Error Details",
                lines=4,
                placeholder="Error details will appear here",
            )
            generate_button = gr.Button("Generate Image")

    # Component order must match
    # process_image(prompt, image, style, upscale_factor, inpaint) -> (image, error).
    ui_inputs = [prompt_input, image_input, style_input, upscale_input, inpaint_input]
    ui_outputs = [output_image, error_output]
    generate_button.click(fn=process_image, inputs=ui_inputs, outputs=ui_outputs)

demo.launch()