import spaces
from functools import lru_cache
import argparse
import csv
import gc
import io
import json
import logging
import os
import random
import shutil
import sys
import tempfile
import time
from datetime import datetime
from os import path
from pathlib import Path

import cv2
import numpy as np
import requests
import replicate
import torch
import safetensors.torch
from safetensors.torch import load_file
from PIL import Image
from openai import OpenAI

import gradio as gr
from gradio_toggle import Toggle
from huggingface_hub import snapshot_download, hf_hub_download  # hf_hub_download is needed for the Hyper-SD LoRA below
from transformers import CLIPProcessor, CLIPModel, pipeline, T5EncoderModel, T5Tokenizer
from diffusers import FluxPipeline
from diffusers.pipelines.stable_diffusion import safety_checker

from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from xora.models.transformers.transformer3d import Transformer3DModel
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
from xora.schedulers.rf import RectifiedFlowScheduler
from xora.pipelines.pipeline_xora_video import XoraVideoPipeline
from xora.utils.conditioning_method import ConditioningMethod
# Initialize the Korean-to-English translator
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.deterministic = False
torch.backends.cuda.preferred_blas_library = "cublas"
torch.set_float32_matmul_precision("highest")
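# These flags disable TF32 and reduced-precision matmul paths so float32 work runs
# at full precision. Note that the "second code" section further down re-enables
# TF32 matmuls, which overrides the first flag above at import time.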
MAX_SEED = np.iinfo(np.int32).max

# Load Hugging Face token if needed
hf_token = os.getenv("HF_TOKEN")
openai_api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=openai_api_key)

system_prompt_t2v_path = "assets/system_prompt_t2v.txt"
with open(system_prompt_t2v_path, "r") as f:
    system_prompt_t2v = f.read()

# Set model download directory within Hugging Face Spaces
model_path = "asset"
commit_hash = 'c7c8ad4c2ddba847b94e8bfaefbd30bd8669fafc'
if not os.path.exists(model_path):
    snapshot_download("Lightricks/LTX-Video", revision=commit_hash, local_dir=model_path, repo_type="model", token=hf_token)

# Global paths for the pipeline components
vae_dir = Path(model_path) / "vae"
unet_dir = Path(model_path) / "unet"
scheduler_dir = Path(model_path) / "scheduler"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use the shared `device` rather than a hard-coded "cuda:0" so the module also imports on CPU-only machines
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path).to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", cache_dir=model_path)
def process_prompt(prompt):
    # Check whether the prompt contains Hangul characters
    if any(ord('가') <= ord(char) <= ord('힣') for char in prompt):
        # Translate Korean to English
        translated = translator(prompt)[0]['translation_text']
        return translated
    return prompt
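# Example (hypothetical inputs): a Korean prompt such as "바다 위의 일몰" is returned
# as its English translation, while an English prompt passes through unchanged.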
def compute_clip_embedding(text=None):
    inputs = clip_processor(text=text, return_tensors="pt", padding=True).to(device)
    outputs = clip_model.get_text_features(**inputs)
    embedding = outputs.detach().cpu().numpy().flatten().tolist()
    return embedding
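# Returns a plain Python list (512 floats for CLIP ViT-B/32) so the embedding is
# JSON-serializable. Note this helper is not referenced elsewhere in this file.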
def load_vae(vae_dir):
    vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
    vae_config_path = vae_dir / "config.json"
    with open(vae_config_path, "r") as f:
        vae_config = json.load(f)
    vae = CausalVideoAutoencoder.from_config(vae_config)
    vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
    vae.load_state_dict(vae_state_dict)
    return vae.to(device).to(torch.bfloat16)

def load_unet(unet_dir):
    unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
    unet_config_path = unet_dir / "config.json"
    transformer_config = Transformer3DModel.load_config(unet_config_path)
    transformer = Transformer3DModel.from_config(transformer_config)
    unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
    transformer.load_state_dict(unet_state_dict, strict=True)
    return transformer.to(device).to(torch.bfloat16)

def load_scheduler(scheduler_dir):
    scheduler_config_path = scheduler_dir / "scheduler_config.json"
    scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
    return RectifiedFlowScheduler.from_config(scheduler_config)
# Preset options for resolution and frame configuration. Two entries were
# labeled "720x720" but actually stored 768x768 dimensions; they are relabeled
# (and the resulting exact duplicate removed) so labels match the real values.
preset_options = [
    {"label": "1216x704, 41 frames", "width": 1216, "height": 704, "num_frames": 41},
    {"label": "1088x704, 49 frames", "width": 1088, "height": 704, "num_frames": 49},
    {"label": "1056x640, 57 frames", "width": 1056, "height": 640, "num_frames": 57},
    {"label": "448x448, 100 frames", "width": 448, "height": 448, "num_frames": 100},
    {"label": "448x448, 200 frames", "width": 448, "height": 448, "num_frames": 200},
    {"label": "448x448, 300 frames", "width": 448, "height": 448, "num_frames": 300},
    {"label": "640x640, 80 frames", "width": 640, "height": 640, "num_frames": 80},
    {"label": "640x640, 120 frames", "width": 640, "height": 640, "num_frames": 120},
    {"label": "768x768, 64 frames", "width": 768, "height": 768, "num_frames": 64},
    {"label": "768x768, 90 frames", "width": 768, "height": 768, "num_frames": 90},
    {"label": "768x768, 100 frames", "width": 768, "height": 768, "num_frames": 100},
    {"label": "768x512, 97 frames", "width": 768, "height": 512, "num_frames": 97},
    {"label": "512x512, 160 frames", "width": 512, "height": 512, "num_frames": 160},
    {"label": "512x512, 200 frames", "width": 512, "height": 512, "num_frames": 200},
]
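# Labels double as the dropdown's lookup keys in preset_changed(), so each label
# must be unique and must describe the width/height actually stored with it.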
def preset_changed(preset):
    if preset != "Custom":
        selected = next(item for item in preset_options if item["label"] == preset)
        return (
            selected["height"],
            selected["width"],
            selected["num_frames"],
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
        )
    else:
        return (
            None,
            None,
            None,
            gr.update(visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
        )
# Load models
vae = load_vae(vae_dir)
unet = load_unet(unet_dir)
scheduler = load_scheduler(scheduler_dir)
patchifier = SymmetricPatchifier(patch_size=1)
text_encoder = T5EncoderModel.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder").to(device)
tokenizer = T5Tokenizer.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer")

pipeline_video = XoraVideoPipeline(
    transformer=unet,
    patchifier=patchifier,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    scheduler=scheduler,
    vae=vae,
).to(device)
def enhance_prompt_if_enabled(prompt, enhance_toggle):
    if not enhance_toggle:
        print("Enhance toggle is off, Prompt: ", prompt)
        return prompt
    messages = [
        {"role": "system", "content": system_prompt_t2v},
        {"role": "user", "content": prompt},
    ]
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",  # "gpt-4-mini" is not a valid OpenAI model name
            messages=messages,
            max_tokens=200,
        )
        print("Enhanced Prompt: ", response.choices[0].message.content.strip())
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error: {e}")
        return prompt
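# On any API failure the original prompt is returned unchanged, so generation
# still proceeds even when prompt enhancement is unavailable.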
def generate_video_from_text_90(
    prompt="",
    enhance_prompt_toggle=False,
    negative_prompt="",
    frame_rate=25,
    seed=None,  # drawn per call below; a random.randint(...) default would be frozen at import time
    num_inference_steps=30,
    guidance_scale=3.2,
    height=768,
    width=768,
    num_frames=60,
    progress=gr.Progress(),
):
    if seed is None:
        seed = random.randint(0, MAX_SEED)

    # Preprocess prompts (Korean -> English)
    prompt = process_prompt(prompt)
    negative_prompt = process_prompt(negative_prompt)

    if len(prompt.strip()) < 50:
        raise gr.Error(
            "Prompt must be at least 50 characters long. Please provide more details for the best results.",
            duration=5,
        )

    prompt = enhance_prompt_if_enabled(prompt, enhance_prompt_toggle)
    sample = {
        "prompt": prompt,
        "prompt_attention_mask": None,
        "negative_prompt": negative_prompt,
        "negative_prompt_attention_mask": None,
        "media_items": None,
    }
    generator = torch.Generator(device=device).manual_seed(seed)

    def gradio_progress_callback(self, step, timestep, kwargs):
        progress((step + 1) / num_inference_steps)
    try:
        with torch.no_grad():
            images = pipeline_video(
                num_inference_steps=num_inference_steps,
                num_images_per_prompt=1,
                guidance_scale=guidance_scale,
                generator=generator,
                output_type="pt",
                height=height,
                width=width,
                num_frames=num_frames,
                frame_rate=frame_rate,
                **sample,
                is_video=True,
                vae_per_channel_normalize=True,
                conditioning_method=ConditioningMethod.UNCONDITIONAL,
                mixed_precision=True,
                callback_on_step_end=gradio_progress_callback,
            ).images
    except Exception as e:
        raise gr.Error(
            f"An error occurred while generating the video. Please try again. Error: {e}",
            duration=5,
        )
    finally:
        torch.cuda.empty_cache()
        gc.collect()
    # NamedTemporaryFile avoids the race condition of the deprecated tempfile.mktemp
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        output_path = tmp.name
    video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
    video_np = (video_np * 255).astype(np.uint8)
    height, width = video_np.shape[1:3]
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), frame_rate, (width, height))
    for frame in video_np[..., ::-1]:  # flip RGB -> BGR for OpenCV
        out.write(frame)
    out.release()
    del images
    del video_np
    torch.cuda.empty_cache()
    return output_path
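# Minimal sketch of a direct (non-UI) call, assuming the models above loaded;
# the argument values are illustrative, not recommendations:
#
#   video_path = generate_video_from_text_90(
#       prompt="A long, detailed scene description of at least fifty characters ...",
#       negative_prompt="low quality, distorted",
#       seed=42, num_inference_steps=30, guidance_scale=3.2,
#       height=512, width=512, num_frames=60, frame_rate=25,
#   )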
def create_advanced_options():
    with gr.Accordion("Step 4: Advanced Options (Optional)", open=False):
        seed = gr.Slider(label="4.1 Seed", minimum=0, maximum=1000000, step=1, value=646373)
        inference_steps = gr.Slider(label="4.2 Inference Steps", minimum=5, maximum=150, step=5, value=40)
        guidance_scale = gr.Slider(label="4.3 Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=4.2)
        height_slider = gr.Slider(
            label="4.4 Height",
            minimum=256,
            maximum=1024,
            step=64,
            value=768,
            visible=False,
        )
        width_slider = gr.Slider(
            label="4.5 Width",
            minimum=256,
            maximum=1024,
            step=64,
            value=768,
            visible=False,
        )
        num_frames_slider = gr.Slider(
            label="4.6 Number of Frames",  # was mislabeled "4.5", duplicating the width slider's number
            minimum=1,
            maximum=500,
            step=1,
            value=60,
            visible=False,
        )
    return [
        seed,
        inference_steps,
        guidance_scale,
        height_slider,
        width_slider,
        num_frames_slider,
    ]
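# The list order here must match the parameter order of generate_video_from_text_90
# (seed, steps, guidance, height, width, num_frames), because the click handler
# below splats these components directly into its inputs.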
###############################################
# Second code integration starts here
###############################################
# (This section's imports were verbatim duplicates of those at the top of the
# file and have been consolidated there.)
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Setup and initialization code
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
gallery_path = path.join(PERSISTENT_DIR, "gallery")
video_gallery_path = path.join(PERSISTENT_DIR, "video_gallery")

# API configuration
CATBOX_USER_HASH = "e7a96fc68dd4c7d2954040cd5"
REPLICATE_API_TOKEN = os.getenv("API_KEY")

# Environment variable configuration
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

# CUDA configuration: this re-enables TF32 matmuls, overriding the
# allow_tf32 = False setting from the first section
torch.backends.cuda.matmul.allow_tf32 = True

# Reuse the translator loaded above rather than loading the same model a second time
translator2 = translator

# Create the gallery directories
for dir_path in [gallery_path, video_gallery_path]:
    if not path.exists(dir_path):
        os.makedirs(dir_path, exist_ok=True)
def check_api_key():
    """Verify and set the Replicate API key."""
    if not REPLICATE_API_TOKEN:
        logger.error("Replicate API key not found")
        return False
    os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN
    logger.info("Replicate API token set successfully")
    return True
def translate_if_korean(text):
    """Translate the text to English if it contains Korean characters."""
    # 0xAC00-0xD7A3 is the Unicode range for precomposed Hangul syllables
    if any(ord(char) >= 0xAC00 and ord(char) <= 0xD7A3 for char in text):
        translation = translator2(text)[0]['translation_text']
        return translation
    return text
def filter_prompt(prompt):
    inappropriate_keywords = [
        "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
        "erotic", "sensual", "seductive", "provocative", "intimate",
        "violence", "gore", "blood", "death", "kill", "murder", "torture",
        "drug", "suicide", "abuse", "hate", "discrimination"
    ]
    prompt_lower = prompt.lower()
    for keyword in inappropriate_keywords:
        if keyword in prompt_lower:
            return False, "The prompt contains inappropriate content."
    return True, prompt
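# Returns an (is_safe, text) pair: (True, original prompt) when clean, or
# (False, refusal message) when a blocked keyword is found. The substring match
# is deliberately simple and will also flag innocuous words that happen to
# contain a blocked keyword (e.g. "kill" inside "skillful").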
def process_prompt_for_sd(prompt):
    """Preprocess the prompt (translation plus content filtering)."""
    translated_prompt = translate_if_korean(prompt)
    is_safe, filtered_prompt = filter_prompt(translated_prompt)
    return is_safe, filtered_prompt
class timer:
    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        self.start = time.time()
        print(f"{self.method} starts")

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.time()
        print(f"{self.method} took {str(round(end - self.start, 2))}s")
# Model initialization
if not path.exists(cache_path):
    os.makedirs(cache_path, exist_ok=True)

pipe_sd = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe_sd.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
pipe_sd.fuse_lora(lora_scale=0.125)
pipe_sd.to(device="cuda", dtype=torch.bfloat16)
pipe_sd.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
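# The Hyper-SD LoRA distills FLUX.1-dev down to roughly 8 inference steps, and
# 0.125 appears to be the fuse scale the Hyper-SD release suggests for its
# 8-step FLUX LoRA. Note that FluxPipeline does not invoke a safety_checker
# itself, so the attribute assigned above is effectively inert unless called
# manually.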
def upload_to_catbox(image_path):
    """Upload an image via the catbox.moe API and return its URL."""
    try:
        logger.info(f"Preparing to upload image: {image_path}")
        url = "https://catbox.moe/user/api.php"
        file_extension = Path(image_path).suffix.lower()
        if file_extension not in ['.jpg', '.jpeg', '.png', '.gif']:
            logger.error(f"Unsupported file type: {file_extension}")
            return None

        # Read the file up front so the handle is closed promptly
        with open(image_path, 'rb') as fh:
            file_bytes = fh.read()
        files = {
            'fileToUpload': (
                os.path.basename(image_path),
                file_bytes,
                'image/jpeg' if file_extension in ['.jpg', '.jpeg'] else 'image/png'
            )
        }
        data = {
            'reqtype': 'fileupload',
            'userhash': CATBOX_USER_HASH
        }
        response = requests.post(url, files=files, data=data)
        # catbox returns the bare file URL as plain text on success
        if response.status_code == 200 and response.text.startswith('http'):
            image_url = response.text
            logger.info(f"Image uploaded successfully: {image_url}")
            return image_url
        else:
            raise Exception(f"Upload failed: {response.text}")
    except Exception as e:
        logger.error(f"Image upload error: {str(e)}")
        return None
def add_watermark(video_path):
    """Add a watermark to the video using OpenCV."""
    try:
        cap = cv2.VideoCapture(video_path)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))

        text = "GiniGEN.AI"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = height * 0.05 / 30  # scale the text with the frame height
        thickness = 2
        color = (255, 255, 255)

        (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
        margin = int(height * 0.02)
        x_pos = width - text_width - margin  # bottom-right corner
        y_pos = height - margin

        output_path = "watermarked_output.mp4"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
            out.write(frame)

        cap.release()
        out.release()
        return output_path
    except Exception as e:
        logger.error(f"Error adding watermark: {str(e)}")
        return video_path
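# On any failure the unwatermarked video path is returned, so a watermarking
# problem never blocks video delivery.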
def generate_video(image, prompt):
    logger.info("Starting video generation")
    try:
        if not check_api_key():
            return "Replicate API key not properly configured"
        if not image:
            logger.error("No image provided")
            return "Please upload an image"

        image_url = upload_to_catbox(image)
        if not image_url:
            return "Failed to upload image"

        input_data = {
            "prompt": prompt,
            "first_frame_image": image_url
        }
        try:
            replicate.Client(api_token=REPLICATE_API_TOKEN)
            output = replicate.run(
                "minimax/video-01-live",
                input=input_data
            )
            temp_file = "temp_output.mp4"
            if hasattr(output, 'read'):
                with open(temp_file, "wb") as file:
                    file.write(output.read())
            elif isinstance(output, str):
                response = requests.get(output)
                with open(temp_file, "wb") as file:
                    file.write(response.content)
            final_video = add_watermark(temp_file)
            return final_video
        except Exception as api_error:
            logger.error(f"API call failed: {str(api_error)}")
            return f"API call failed: {str(api_error)}"
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return f"Unexpected error: {str(e)}"
def save_image(image):
    """Save the generated image in PNG format and return the path."""
    try:
        if not os.path.exists(gallery_path):
            os.makedirs(gallery_path, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        random_suffix = os.urandom(4).hex()
        filename = f"generated_{timestamp}_{random_suffix}.png"
        filepath = os.path.join(gallery_path, filename)

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # PNG is lossless, so a JPEG-style `quality` argument does not apply
        image.save(filepath, format='PNG', optimize=True)
        logger.info(f"Image saved successfully as PNG: {filepath}")
        return filepath
    except Exception as e:
        logger.error(f"Error in save_image: {str(e)}")
        return None
def load_gallery():
    """Load all images from the gallery directory, newest first."""
    try:
        os.makedirs(gallery_path, exist_ok=True)
        image_files = []
        for f in os.listdir(gallery_path):
            if f.lower().endswith(('.png', '.jpg', '.jpeg')):
                full_path = os.path.join(gallery_path, f)
                image_files.append((full_path, os.path.getmtime(full_path)))
        image_files.sort(key=lambda x: x[1], reverse=True)
        return [f[0] for f in image_files]
    except Exception as e:
        logger.error(f"Error loading gallery: {str(e)}")  # use the module logger for consistency
        return []
# CSS style definitions
css = """
[previous CSS code kept as-is]
"""
def get_random_seed():
    return torch.randint(0, 1000000, (1,)).item()
###############################################
# Gradio UI integration starts here
###############################################
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    gr.HTML('<div class="title">AI Image & Video Generator</div>')

    with gr.Tabs():
        with gr.Tab("Image Generation"):
            with gr.Row():
                with gr.Column(scale=3):
                    img_prompt = gr.Textbox(
                        label="Image Description",
                        placeholder="Enter an image description... (Korean input supported)",
                        lines=3
                    )
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            height = gr.Slider(
                                label="Height",
                                minimum=256,
                                maximum=1152,
                                step=64,
                                value=1024
                            )
                            width = gr.Slider(
                                label="Width",
                                minimum=256,
                                maximum=1152,
                                step=64,
                                value=1024
                            )
                        with gr.Row():
                            steps = gr.Slider(
                                label="Inference Steps",
                                minimum=6,
                                maximum=25,
                                step=1,
                                value=8
                            )
                            scales = gr.Slider(
                                label="Guidance Scale",
                                minimum=0.0,
                                maximum=5.0,
                                step=0.1,
                                value=3.5
                            )
                        seed = gr.Number(
                            label="Seed",
                            value=get_random_seed(),
                            precision=0
                        )
                        randomize_seed = gr.Button("🎲 Randomize Seed", elem_classes=["generate-btn"])
                    generate_btn = gr.Button(
                        "✨ Generate Image",
                        elem_classes=["generate-btn"]
                    )
                with gr.Column(scale=4):
                    img_output = gr.Image(
                        label="Generated Image",
                        type="pil",
                        format="png"
                    )
                    img_gallery = gr.Gallery(
                        label="Image Gallery",
                        show_label=True,
                        elem_id="gallery",
                        columns=[4],
                        rows=[2],
                        height="auto",
                        object_fit="cover"
                    )
                    img_gallery.value = load_gallery()
with gr.Tab("Video Generation"): | |
with gr.Row(): | |
with gr.Column(scale=3): | |
video_prompt = gr.Textbox( | |
label="Video Description", | |
placeholder="๋น๋์ค ์ค๋ช ์ ์ ๋ ฅํ์ธ์... (ํ๊ธ ์ ๋ ฅ ๊ฐ๋ฅ)", | |
lines=3 | |
) | |
upload_image = gr.Image( | |
type="filepath", | |
label="Upload First Frame Image" | |
) | |
video_generate_btn = gr.Button( | |
"๐ฌ Generate Video", | |
elem_classes=["generate-btn"] | |
) | |
with gr.Column(scale=4): | |
video_output = gr.Video(label="Generated Video") | |
video_gallery = gr.Gallery( | |
label="Video Gallery", | |
show_label=True, | |
columns=[4], | |
rows=[2], | |
height="auto", | |
object_fit="cover" | |
) | |
        # Integrate the first code's txt2vid UI as an additional tab
        with gr.Tab("Text-to-Video Generation"):
            with gr.Column():
                txt2vid_prompt = gr.Textbox(
                    label="Step 1: Enter Your Prompt (Korean or English)",
                    placeholder="Describe the video you want to generate (minimum 50 characters)...",
                    value="A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene.",
                    lines=5,
                )
                txt2vid_enhance_toggle = Toggle(
                    label="Enhance Prompt",
                    value=False,
                    interactive=True,
                )
                txt2vid_negative_prompt = gr.Textbox(
                    label="Step 2: Enter Negative Prompt",
                    placeholder="Describe elements you do not want in the video...",
                    value="low quality, worst quality, deformed, distorted, damaged, motion blur, motion artifacts, fused fingers, incorrect anatomy, strange hands, ugly",
                    lines=2,
                )
                txt2vid_preset = gr.Dropdown(
                    choices=[p["label"] for p in preset_options],
                    value="512x512, 160 frames",
                    label="Step 3.1: Choose Resolution Preset",
                )
                txt2vid_frame_rate = gr.Slider(
                    label="Step 3.2: Frame Rate",
                    minimum=6,
                    maximum=60,
                    step=1,
                    value=20,
                )
                txt2vid_advanced = create_advanced_options()
                txt2vid_generate = gr.Button(
                    "Step 5: Generate Video",
                    variant="primary",
                    size="lg",
                )
                txt2vid_output = gr.Video(label="Generated Output")

    txt2vid_preset.change(
        fn=preset_changed,
        inputs=[txt2vid_preset],
        outputs=txt2vid_advanced[3:],
    )
    txt2vid_generate.click(
        fn=generate_video_from_text_90,
        inputs=[
            txt2vid_prompt,
            txt2vid_enhance_toggle,
            txt2vid_negative_prompt,
            txt2vid_frame_rate,
            *txt2vid_advanced,
        ],
        outputs=txt2vid_output,
        concurrency_limit=1,
        concurrency_id="generate_video",
        queue=True,
    )
    def process_and_save_image(height, width, steps, scales, prompt, seed):
        is_safe, translated_prompt = process_prompt_for_sd(prompt)
        if not is_safe:
            gr.Warning("The prompt contains inappropriate content.")
            return None, load_gallery()

        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
            try:
                generated_image = pipe_sd(
                    prompt=[translated_prompt],
                    generator=torch.Generator().manual_seed(int(seed)),
                    num_inference_steps=int(steps),
                    guidance_scale=float(scales),
                    height=int(height),
                    width=int(width),
                    max_sequence_length=256
                ).images[0]

                if not isinstance(generated_image, Image.Image):
                    generated_image = Image.fromarray(generated_image)
                if generated_image.mode != 'RGB':
                    generated_image = generated_image.convert('RGB')

                saved_path = save_image(generated_image)
                if saved_path is None:
                    logger.warning("Failed to save generated image")
                    return None, load_gallery()
                # Return the PIL image directly; the original encode-to-PNG-and-reopen
                # round trip produced an identical image and has been dropped.
                return generated_image, load_gallery()
            except Exception as e:
                logger.error(f"Error in image generation: {str(e)}")
                return None, load_gallery()
    def process_and_generate_video(image, prompt):
        is_safe, translated_prompt = process_prompt_for_sd(prompt)
        if not is_safe:
            gr.Warning("The prompt contains inappropriate content.")
            return None
        return generate_video(image, translated_prompt)

    def update_seed():
        return get_random_seed()
    generate_btn.click(
        process_and_save_image,
        inputs=[height, width, steps, scales, img_prompt, seed],
        outputs=[img_output, img_gallery]
    )
    video_generate_btn.click(
        process_and_generate_video,
        inputs=[upload_image, video_prompt],
        outputs=video_output
    )
    randomize_seed.click(
        update_seed,
        outputs=[seed]
    )
    # A second handler on the same button refreshes the seed field after each
    # generation, so repeated clicks do not reuse the previous seed.
    generate_btn.click(
        update_seed,
        outputs=[seed]
    )
if __name__ == "__main__":
    demo.queue(max_size=64, default_concurrency_limit=1, api_open=False).launch(share=True, show_api=False, allowed_paths=[PERSISTENT_DIR])