# Image-Gen / app.py
from diffusers import (
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    AutoencoderKL,
    UNet2DConditionModel,
)
from transformers import CLIPTextModel, CLIPTokenizer
import torch
import gradio as gr
import spaces
import os
import requests
import hashlib
import re

# Default LoRA for fallback
DEFAULT_LORA = "OedoSoldier/detail-tweaker-lora"
LORA_CACHE_DIR = "lora_cache"

def download_lora(url):
    """Download a LoRA file from a Civitai URL and cache it locally."""
    # Create the cache directory if it doesn't exist
    os.makedirs(LORA_CACHE_DIR, exist_ok=True)

    # Derive a stable filename from the URL
    url_hash = hashlib.md5(url.encode()).hexdigest()
    local_path = os.path.join(LORA_CACHE_DIR, f"{url_hash}.safetensors")

    # If the file is already cached, reuse it
    if os.path.exists(local_path):
        print("\n********** LoRA Already Exists **********\n")
        return local_path

    # Download the file
    try:
        response = requests.get(url, stream=True, timeout=60)
        response.raise_for_status()

        # Total file size, if the server reports one
        total_size = int(response.headers.get('content-length', 0))

        # Stream the download to disk
        with open(local_path, 'wb') as f:
            if total_size == 0:
                f.write(response.content)
            else:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)

        print("\n********** LoRA Download Successful **********\n")
        return local_path
    except Exception as e:
        print(f"\nError downloading LoRA: {e}\n")
        return None

def is_civitai_url(url):
    """Check whether the URL is a valid Civitai download URL."""
    return bool(re.match(r'https?://civitai\.com/api/download/models/\d+', url))
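
# Illustrative checks (the model IDs below are made up):
#   is_civitai_url("https://civitai.com/api/download/models/123456")  -> True
#   is_civitai_url("https://civitai.com/models/123456/some-model")    -> False  (model page, not a download link)
#   is_civitai_url("OedoSoldier/detail-tweaker-lora")                 -> False  (handled as a Hub path instead)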

@spaces.GPU
def generate_image(prompt, negative_prompt, lora_url, num_inference_steps=30, guidance_scale=7.0,
                   model="Real6.0", num_images=1, width=512, height=512):
    # Map the UI model choice to its Hugging Face repo ID
    if model == "Real5.0":
        model_id = "SG161222/Realistic_Vision_V5.0_noVAE"
    elif model == "Real5.1":
        model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
    else:
        model_id = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
    # Initialize models, moving each component to the GPU
    vae = AutoencoderKL.from_pretrained(
        model_id,
        subfolder="vae"
    ).to("cuda")
    text_encoder = CLIPTextModel.from_pretrained(
        model_id,
        subfolder="text_encoder"
    ).to("cuda")
    tokenizer = CLIPTokenizer.from_pretrained(
        model_id,
        subfolder="tokenizer"
    )
    unet = UNet2DConditionModel.from_pretrained(
        model_id,
        subfolder="unet"
    ).to("cuda")
    # Pass the explicitly loaded components (including the UNet, which was
    # previously loaded but never handed to the pipeline) into the pipeline
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        unet=unet
    ).to("cuda")
    # Load LoRA weights, falling back to the default LoRA on any failure
    try:
        if lora_url and lora_url.strip():
            if is_civitai_url(lora_url):
                # Download and load a Civitai LoRA
                lora_path = download_lora(lora_url)
                if lora_path:
                    pipe.load_lora_weights(lora_path)
                    print("\n********** URL LoRA Loaded **********\n")
                else:
                    pipe.load_lora_weights(DEFAULT_LORA)
                    print("\n********** Default LoRA Loaded **********\n")
            # If it's a HuggingFace repo path
            elif '/' in lora_url and not lora_url.startswith('http'):
                pipe.load_lora_weights(lora_url)
                print("\n********** Hub LoRA Loaded **********\n")
            else:
                pipe.load_lora_weights(DEFAULT_LORA)
                print("\n********** Default LoRA Loaded **********\n")
        else:
            pipe.load_lora_weights(DEFAULT_LORA)
    except Exception as e:
        print(f"\nError loading LoRA weights: {e}\n")
        pipe.load_lora_weights(DEFAULT_LORA)
    # Bypass the safety checker for Real6.0 (returns images unfiltered)
    if model == "Real6.0":
        pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))

    # DPM++ multistep with Karras sigmas, a common quality/speed trade-off for SD 1.5 checkpoints
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config,
        algorithm_type="dpmsolver++",
        use_karras_sigmas=True
    )
    # Encode the prompts manually so the pipeline receives precomputed embeddings
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    ).to("cuda")
    negative_text_inputs = tokenizer(
        negative_prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    ).to("cuda")
    prompt_embeds = text_encoder(text_inputs.input_ids)[0]
    negative_prompt_embeds = text_encoder(negative_text_inputs.input_ids)[0]
    # Generate the image
    result = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        cross_attention_kwargs={"scale": 1},
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        width=width,
        height=height,
        num_images_per_prompt=num_images
    )
    return result.images
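
# A minimal smoke test, kept as a comment because this module runs at import
# time in the Space. The argument values are hypothetical, and a CUDA GPU is
# assumed:
#
#   images = generate_image(
#       prompt="portrait photo of a woman, natural light",
#       negative_prompt="cartoon, blurry, low quality",
#       lora_url=DEFAULT_LORA,
#       num_inference_steps=30,
#       guidance_scale=7.0,
#       model="Real6.0",
#       num_images=1,
#       width=512,
#       height=512,
#   )
#   images[0].save("sample.png")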

def clean_lora_cache():
    """Clean the LoRA cache directory."""
    if os.path.exists(LORA_CACHE_DIR):
        for file in os.listdir(LORA_CACHE_DIR):
            file_path = os.path.join(LORA_CACHE_DIR, file)
            try:
                if os.path.isfile(file_path):
                    os.unlink(file_path)
            except Exception as e:
                print(f"Error deleting {file_path}: {e}")

title = """<h1 align="center">ProFaker</h1>"""
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.HTML(title)
    with gr.Row():
        with gr.Column():
            # Input components
            prompt = gr.Textbox(
                label="Prompt",
                info="Enter your image description here...",
                lines=3
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                info="Enter what you don't want in the image...",
                lines=3
            )
            lora_input = gr.Textbox(
                label="LoRA URL/Path",
                info="Enter a Civitai download URL or HuggingFace path (e.g., 'username/model-name')",
                value=DEFAULT_LORA
            )
            clear_cache = gr.Button("Clear LoRA Cache")
            generate_button = gr.Button("Generate Image")
            with gr.Accordion("Advanced Options", open=False):
                model = gr.Dropdown(
                    choices=["Real6.0", "Real5.1", "Real5.0"],
                    value="Real6.0",
                    label="Model",
                )
                num_images = gr.Slider(
                    minimum=1,
                    maximum=4,
                    value=1,
                    step=1,
                    label="Number of Images to Generate"
                )
                width = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Image Width"
                )
                height = gr.Slider(
                    minimum=256,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Image Height"
                )
                steps_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=30,
                    step=1,
                    label="Number of Steps"
                )
                guidance_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=7.0,
                    step=0.5,
                    label="Guidance Scale"
                )
        with gr.Column():
            # Output component
            gallery = gr.Gallery(
                label="Generated Images",
                show_label=True,
                elem_id="gallery",
                columns=2,
                rows=2
            )
    # Connect the interface to the generation function
    generate_button.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, lora_input, steps_slider, guidance_slider,
                model, num_images, width, height],
        outputs=gallery
    )

    # Connect the clear-cache button
    clear_cache.click(fn=clean_lora_cache)

demo.queue(max_size=10).launch(share=False)
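
# Note: running this file outside a Hugging Face Space assumes a CUDA GPU and
# the `spaces` package (whose @spaces.GPU decorator is designed to have no
# effect outside ZeroGPU environments). A hypothetical local run:
#   pip install torch diffusers transformers gradio spaces requests
#   python app.py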