# Import necessary libraries | |
from diffusers import StableDiffusionPipeline # Import StableDiffusionPipeline for image generation | |
import torch # Import PyTorch for deep learning operations | |
import gradio as gr # Import Gradio for creating a web interface | |
# Configuration shared by the model loader and the generation function.
class CFG:
    """Tunable settings for loading the pipeline and generating images."""

    # Hugging Face Hub repository the pipeline weights are loaded from.
    image_gen_model_id = "stabilityai/stable-diffusion-2"
    # Number of denoising steps the pipeline runs per image.
    image_gen_steps = 35
    # Guidance strength passed to every pipeline call.
    image_gen_guidance_scale = 9
    # (width, height) the generated image is resized to before display.
    image_gen_size = (400, 400)
# Load the Stable Diffusion pipeline once at startup and move it to the best
# available device. Half precision is used only on CUDA; PyTorch's fp16
# support on CPU is limited, so CPU runs fall back to float32.
image_gen_device = "cuda" if torch.cuda.is_available() else "cpu"
image_gen_dtype = torch.float16 if image_gen_device == "cuda" else torch.float32
image_gen_model = StableDiffusionPipeline.from_pretrained(
    CFG.image_gen_model_id,
    # Fetch the half-precision weights; torch_dtype decides the precision the
    # modules actually run in. NOTE(review): newer diffusers releases
    # deprecate revision="fp16" in favour of variant="fp16" — confirm the
    # installed version before switching.
    revision="fp16",
    torch_dtype=image_gen_dtype,
    # guidance_scale is deliberately not passed here: it is an argument of
    # the pipeline call (see generate_image), not a loading option.
).to(image_gen_device)
# Text-to-image entry point used by the Gradio interface.
def generate_image(prompt):
    """Generate one image for ``prompt`` and scale it for display.

    Runs the module-level pipeline ``image_gen_model`` with the step count
    and guidance scale from ``CFG``, keeps the first image of the returned
    batch, and resizes it to ``CFG.image_gen_size``.

    Args:
        prompt: Text description forwarded unchanged to the pipeline.

    Returns:
        The resized first image from the pipeline output (a PIL image with
        the pipeline's default output type).
    """
    pipeline_output = image_gen_model(
        prompt,
        guidance_scale=CFG.image_gen_guidance_scale,
        num_inference_steps=CFG.image_gen_steps,
    )
    first_image = pipeline_output.images[0]
    return first_image.resize(CFG.image_gen_size)
# Web UI: one text box feeds generate_image, and the returned image is shown
# in an image component.
iface = gr.Interface(
    title="StableDiffusion Image Generation",
    description="Generate images from text prompts using StableDiffusion model.",
    fn=generate_image,
    inputs="text",
    outputs="image",
    # Run only when the user submits, not on every input change: each call
    # executes the full pipeline (CFG.image_gen_steps inference steps).
    live=False,
)
# Start the Gradio server only when this file is executed as a script.
# With debug=True, launch() blocks the main thread, so launching at import
# time would hang any module that merely imports this one.
if __name__ == "__main__":
    iface.launch(debug=True)