File size: 2,648 Bytes
c2a8649 6031dc7 1888310 6031dc7 9af81fd c2a8649 6031dc7 9af81fd 1888310 6031dc7 1888310 6031dc7 234658e 6031dc7 1888310 6031dc7 c2a8649 9af81fd 6031dc7 1888310 9af81fd c2a8649 1888310 6031dc7 c2a8649 6031dc7 1888310 6031dc7 1888310 c2a8649 6031dc7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
# Load a Stable Diffusion pipeline on GPU when available, otherwise on CPU.
# The most recently loaded pipeline is remembered here so that repeated
# requests for the same model IDs do not reload the weights on every call.
_PIPELINE_CACHE = {}


def load_model(base_model_id, adapter_model_id=None):
    """Load a Stable Diffusion pipeline and optionally apply an adapter.

    Args:
        base_model_id: Model ID or path handed to
            ``StableDiffusionPipeline.from_pretrained``.
        adapter_model_id: Optional ID or path of adapter weights handed to
            ``pipe.unet.load_attn_procs``. ``None``, an empty string, or a
            whitespace-only string means "no adapter".

    Returns:
        A ``(pipe, info)`` tuple: the pipeline moved to the selected device,
        and a status string naming the device and, if one was applied, the
        adapter.
    """
    # Text inputs may deliver "" or stray whitespace; treat both as "no adapter".
    adapter_model_id = (adapter_model_id or "").strip()

    # Half precision on GPU, full precision on CPU.
    if torch.cuda.is_available():
        device = "cuda"
        dtype = torch.float16
        info = "Running on GPU (CUDA) 🔥"
    else:
        device = "cpu"
        dtype = torch.float32
        info = "Running on CPU 🥶"

    if adapter_model_id:
        info += f" with Adapter Model: {adapter_model_id}"

    cache_key = (base_model_id, adapter_model_id, device)
    cached_pipe = _PIPELINE_CACHE.get(cache_key)
    if cached_pipe is not None:
        return cached_pipe, info

    # Forget the previous pipeline before loading a different one so the
    # cache never keeps more than one pipeline alive.
    _PIPELINE_CACHE.clear()

    pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=dtype)
    pipe = pipe.to(device)

    if adapter_model_id:
        # The adapter weights are applied directly to the UNet; no separate
        # pipeline object is built for the adapter.
        pipe.unet.load_attn_procs(adapter_model_id)

    _PIPELINE_CACHE[cache_key] = pipe
    return pipe, info
# Text-to-image entry point: resolves the pipeline for the given model IDs
# and renders one image for the prompt.
def generate_image(base_model_id, adapter_model_id, prompt):
    """Generate a single image for ``prompt``.

    Args:
        base_model_id: Base model ID forwarded to ``load_model``.
        adapter_model_id: Adapter model ID forwarded to ``load_model``.
        prompt: Text passed to the pipeline call.

    Returns:
        A ``(image, info)`` tuple: the first image of the pipeline output and
        the status string returned by ``load_model``.
    """
    pipeline, status = load_model(base_model_id, adapter_model_id)
    result = pipeline(prompt)
    return result.images[0], status
# Probe for CUDA once at import time; `info` is the banner text shown in the UI.
device = "cuda" if torch.cuda.is_available() else "cpu"
info = {
    "cuda": "Running on GPU (CUDA) 🔥",
    "cpu": "Running on CPU 🥶",
}[device]
# Build the Gradio UI: inputs on the left, generated image and status on the right.
with gr.Blocks() as demo:
    gr.Markdown("## Custom Text-to-Image Generator with Adapter Support")
    # Banner with the GPU/CPU status detected at startup.
    gr.Markdown(f"**{info}**")

    with gr.Row():
        # Left column: model selection, prompt and the trigger button.
        with gr.Column():
            base_model_id = gr.Textbox(
                label="Enter Base Model ID (e.g., CompVis/stable-diffusion-v1-4)",
                placeholder="Base Model ID",
            )
            adapter_model_id = gr.Textbox(
                label="Enter Adapter Model ID (optional, e.g., nevreal/vMurderDrones-Lora)",
                placeholder="Adapter Model ID (optional)",
                value="",
            )
            prompt = gr.Textbox(
                label="Enter your prompt",
                placeholder="Describe the image you want to generate",
            )
            generate_btn = gr.Button("Generate Image")

        # Right column: the generated image plus the per-request status text
        # (device used and whether an adapter was applied).
        with gr.Column():
            output_image = gr.Image(label="Generated Image")
            device_info = gr.Markdown()

    # Clicking the button runs generate_image with the three text inputs and
    # routes its two return values to the image and status components.
    generate_btn.click(
        fn=generate_image,
        inputs=[base_model_id, adapter_model_id, prompt],
        outputs=[output_image, device_info],
    )

# Start the app.
demo.launch()
|