# Hugging Face Space status: Running on T4
import gradio as gr | |
import torch | |
import cv2 | |
import numpy as np | |
import mediapipe as mp | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionControlNetInpaintPipeline | |
from transformers import AutoTokenizer | |
import base64 | |
import requests | |
import json | |
from rembg import remove | |
from scipy import ndimage | |
from moviepy.editor import ImageSequenceClip | |
from tqdm import tqdm | |
import os | |
import shutil | |
import time | |
from huggingface_hub import snapshot_download | |
import subprocess | |
import sys | |
def download_liveportrait():
    """
    Clone the LivePortrait repository and prepare its dependencies.

    On a fresh clone this installs LivePortrait's ``requirements.txt`` into the
    running interpreter and builds/installs the MultiScaleDeformableAttention
    extension found under ``src/utils/dependencies/XPose/models/UniPose/ops``.
    If ``./LivePortrait`` already exists, the clone and build steps are skipped.

    Side effect: the process working directory is always left at the
    LivePortrait checkout, so the relative ``./pretrained_weights`` directory
    used by ``download_huggingface_resources()`` (called next at import time)
    lands inside the checkout. The module switches back to the project root
    afterwards.

    Raises:
        subprocess.CalledProcessError: if git, pip or the extension build fails.
        Exception: any other setup error; it is printed and re-raised.
    """
    liveportrait_path = os.path.abspath("./LivePortrait")
    ops_path = os.path.join(
        liveportrait_path, "src", "utils", "dependencies", "XPose", "models", "UniPose", "ops"
    )
    try:
        if not os.path.exists(liveportrait_path):
            print("Cloning LivePortrait repository...")
            subprocess.run(
                ["git", "clone", "https://github.com/KwaiVGI/LivePortrait.git", liveportrait_path],
                check=True,
            )
            print("Installing LivePortrait dependencies...")
            # Use the running interpreter's pip so packages land where this app imports from.
            subprocess.run(
                [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
                cwd=liveportrait_path,
                check=True,
            )
            print("Building MultiScaleDeformableAttention...")
            subprocess.run([sys.executable, "setup.py", "build"], cwd=ops_path, check=True)
            subprocess.run([sys.executable, "setup.py", "install"], cwd=ops_path, check=True)
            print("LivePortrait setup completed")
        # Absolute path computed up front; the original derived it after chdir-ing
        # into the ops directory, which produced a doubled, non-existent path.
        if ops_path not in sys.path:
            sys.path.append(ops_path)
        # Leave the process inside the checkout (see docstring) regardless of
        # whether it was freshly cloned, so the weight location is consistent.
        os.chdir(liveportrait_path)
    except Exception as e:
        print("Failed to initialize LivePortrait:", e)
        raise
# Runs at import time, before any model is loaded. After a fresh clone it may
# leave the working directory inside ./LivePortrait (reset further below).
download_liveportrait()
def download_huggingface_resources():
    """
    Download the KwaiVGI/LivePortrait pretrained weights from the Hugging Face Hub.

    Files are written to ``./pretrained_weights`` relative to the *current
    working directory*. At module level this runs right after
    ``download_liveportrait()``, which may leave the working directory inside
    the LivePortrait checkout.

    Uses ``huggingface_hub.snapshot_download`` in-process instead of shelling
    out to ``huggingface-cli``, so no CLI executable has to be on PATH. The
    ``ignore_patterns`` list mirrors the CLI ``--exclude`` arguments previously used.

    Raises:
        Exception: any download error; it is printed and re-raised.
    """
    local_dir = "./pretrained_weights"
    try:
        os.makedirs(local_dir, exist_ok=True)
        print(f"Downloading KwaiVGI/LivePortrait to: {local_dir}")
        snapshot_download(
            repo_id="KwaiVGI/LivePortrait",
            local_dir=local_dir,
            # NOTE(review): "docs" matches only an entry named exactly "docs";
            # files inside a docs/ folder would need "docs/*" — confirm intent.
            ignore_patterns=["*.git*", "README.md", "docs"],
        )
        print("Resources successfully downloaded to:", local_dir)
    except Exception as e:
        print("General error in downloading resources:", e)
        raise
# Also runs at import time; its "./pretrained_weights" target is relative to the
# current working directory, which download_liveportrait() may have changed.
download_huggingface_resources()
def get_project_root():
    """Return the absolute path of the directory containing this source file."""
    this_file = os.path.abspath(__file__)
    return os.path.dirname(this_file)
# Ensure working directory is project root (download_liveportrait() may have
# left it inside the LivePortrait checkout).
os.chdir(get_project_root())
# Initialize the necessary models and components
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils
# Select the device first so the weight dtype can follow it: half precision is
# only used on CUDA; the CPU fallback loads float32 weights instead of float16.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dtype = torch.float16 if device.type == 'cuda' else torch.float32
# NOTE(review): the runwayml/* model repos below were reportedly removed from the
# Hugging Face Hub in 2024; if loading fails, the stable-diffusion-v1-5/* mirrors
# are the usual replacement — confirm before switching.
# Load ControlNet model (OpenPose conditioning); shared by both pipelines below.
controlnet = ControlNetModel.from_pretrained('lllyasviel/sd-controlnet-openpose', torch_dtype=model_dtype)
# Load Stable Diffusion model with ControlNet
pipe_controlnet = StableDiffusionControlNetPipeline.from_pretrained(
    'runwayml/stable-diffusion-v1-5',
    controlnet=controlnet,
    torch_dtype=model_dtype,
)
# Load Inpaint Controlnet
pipe_inpaint_controlnet = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    controlnet=controlnet,
    torch_dtype=model_dtype,
)
# Move both pipelines to the selected device and enable attention slicing.
pipe_controlnet.to(device)
pipe_controlnet.enable_attention_slicing()
pipe_inpaint_controlnet.to(device)
pipe_inpaint_controlnet.enable_attention_slicing()
def resize_to_multiple_of_64(width, height):
    """
    Round ``width`` and ``height`` down to the nearest multiple of 64.

    Values below 64 are clamped to 64 instead of collapsing to 0, which would
    be an unusable generation size (this happens for very elongated images,
    where the short side scaled from 768 drops below 64).

    Returns:
        tuple[int, int]: the adjusted (width, height).
    """
    return max(64, (width // 64) * 64), max(64, (height // 64) * 64)
def expand_mask(mask, kernel_size):
    """
    Grow a mask by binary dilation with a ``kernel_size`` x ``kernel_size`` square.

    Every non-zero pixel of ``mask`` counts as foreground. Returns a PIL image
    whose pixels are 255 inside the dilated region and 0 elsewhere.
    """
    footprint = np.ones((kernel_size, kernel_size), dtype=np.uint8)
    dilated = ndimage.binary_dilation(np.array(mask), structure=footprint)
    return Image.fromarray(dilated.astype(np.uint8) * 255)
def crop_face_to_square(image_rgb, padding_ratio=0.2, height_multiplier=1.2):
    """
    Detect a face and crop a region that starts just above it and extends downward.

    Despite the name, the crop is generally not square. Its width is the face
    size enlarged by ``padding_ratio``; its height is that width times
    ``height_multiplier``, so more of the body below the face is kept. The crop
    is horizontally centred on the face, starts 10% of the padded size above
    the top of the face box, and is clipped to the image bounds.

    When several faces are detected, the largest one is used.

    Args:
        image_rgb: RGB image array (rows = height, columns = width).
        padding_ratio: fractional padding added to the face size for the crop width.
        height_multiplier: crop height as a multiple of the crop width.

    Returns:
        The cropped image (a slice of ``image_rgb``), or None if no face is found.
    """
    face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
    gray_image = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
    faces = face_cascade.detectMultiScale(gray_image, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))
    if len(faces) == 0:
        print("No face detected.")
        return None
    # The order of detections carries no meaning; the largest box is the most
    # likely subject (and less likely to be a false positive).
    x, y, w, h = max(faces, key=lambda box: box[2] * box[3])
    face_x_center = x + w // 2
    face_y_top = y
    face_side_length = max(w, h)
    padded_side_length = int(face_side_length * (1 + padding_ratio))
    cropped_width = padded_side_length
    cropped_height = int(padded_side_length * height_multiplier)
    top_left_x = max(face_x_center - cropped_width // 2, 0)
    # Start slightly above the detected face box so the top of the head is kept.
    top_margin = int(padded_side_length * 0.1)
    top_left_y = max(face_y_top - top_margin, 0)
    bottom_right_x = min(top_left_x + cropped_width, image_rgb.shape[1])
    bottom_right_y = min(top_left_y + cropped_height, image_rgb.shape[0])
    return image_rgb[top_left_y:bottom_right_y, top_left_x:bottom_right_x]
# Negative prompt shared by every Stable Diffusion call in this section.
_NEGATIVE_PROMPT = (
    "multiple heads, two heads, double head, triple head, extra limbs, extra arms, extra legs, "
    "duplicate faces, multiple faces, mutated anatomy, deformed, disfigured, malformed, "
    "extra ears, fused ears, blurred, low resolution, cartoonish, watermark, text, logo, "
    "poorly drawn, distorted, floating limbs, out-of-frame"
)
# Prompt used when the GPT request fails or returns no usable text.
_FALLBACK_PROMPT = "A majestic animal"
# Seconds to wait for the OpenAI API before falling back to _FALLBACK_PROMPT.
_GPT_TIMEOUT_SECONDS = 60
# Instruction sent to GPT by the single-animal variants.
_SINGLE_ANIMAL_INSTRUCTION = (
    "Based on the provided image, think of one spirit animal that is right for the person, "
    "and answer in the following format: An ultra-realistic, highly detailed photograph of a "
    "single {animal} with facial features characterized by {description}, standing upright in a "
    "human-like pose, looking directly at the camera, against a solid, neutral background. "
    "Generate one sentence without any other responses or numbering."
)


def _multiple_animals_instruction(num_images):
    """Return the GPT instruction asking for ``num_images`` spirit-animal prompts."""
    return (
        "Based on the provided image, think of " + str(num_images) + " different spirit animals "
        "that are right for the person, and answer in the following format for each: "
        "An ultra-realistic, highly detailed photograph of a {animal} with facial features "
        "characterized by {description}, standing upright in a human-like pose, looking directly "
        "at the camera, against a solid, neutral background. Generate these sentences without any "
        "other responses or numbering. For the animal choose between owl, bear, fox, koala, lion, dog"
    )


def _load_image_rgb(image_path):
    """Read ``image_path`` with OpenCV and return it as RGB, or None if it cannot be read."""
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        print(f"Could not read image: {image_path}")
        return None
    return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)


def _crop_around_face(image_rgb, **crop_kwargs):
    """Return crop_face_to_square()'s crop, or the full image when no face crop is available."""
    cropped = crop_face_to_square(image_rgb, **crop_kwargs)
    if cropped is None or cropped.size == 0:
        print("Using the full image instead of a face crop.")
        return image_rgb
    # The crop is a slice of the original array; make it contiguous for the
    # MediaPipe / OpenCV / PIL calls that consume it.
    return np.ascontiguousarray(cropped)


def _generation_size(width, height, long_side=768):
    """Scale (width, height) so the longer side is ``long_side``, then snap to multiples of 64."""
    aspect_ratio = width / height
    if aspect_ratio > 1:
        gen_width, gen_height = long_side, int(long_side / aspect_ratio)
    else:
        gen_width, gen_height = int(long_side * aspect_ratio), long_side
    return resize_to_multiple_of_64(gen_width, gen_height)


def _pose_control_image(image_rgb, gen_width, gen_height):
    """
    Draw the MediaPipe pose skeleton of ``image_rgb`` as white lines on black.

    Only connections whose two landmarks both have visibility > 0.5 are drawn.
    The drawing is resized to (gen_width, gen_height) and returned as a PIL
    image (the ControlNet conditioning input), or None if no pose is detected.
    """
    with mp_pose.Pose(static_image_mode=True) as pose:
        results = pose.process(image_rgb)
    if not results.pose_landmarks:
        print("No pose detected.")
        return None
    skeleton = np.zeros_like(image_rgb)
    img_height, img_width = skeleton.shape[:2]
    landmarks = results.pose_landmarks.landmark
    for start_idx, end_idx in mp_pose.POSE_CONNECTIONS:
        start, end = landmarks[start_idx], landmarks[end_idx]
        if start.visibility > 0.5 and end.visibility > 0.5:
            start_xy = (int(start.x * img_width), int(start.y * img_height))
            end_xy = (int(end.x * img_width), int(end.y * img_height))
            cv2.line(skeleton, start_xy, end_xy, (255, 255, 255), 2)
    resized = cv2.resize(skeleton, (gen_width, gen_height), interpolation=cv2.INTER_LANCZOS4)
    return Image.fromarray(resized)


def _request_gpt_text(image_rgb, instruction, max_tokens):
    """
    Send ``image_rgb`` and ``instruction`` to the OpenAI chat completions API.

    The API key is read from the GPT_KEY environment variable. Returns the
    stripped text of the first choice, or None if the request fails or the
    response contains no usable text.
    """
    # OpenCV encodes assuming BGR order (the file converts BGR->RGB after
    # imread), so convert back; otherwise the model sees swapped colours.
    ok, jpeg = cv2.imencode('.jpg', cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
    if not ok:
        print("Could not encode the image for the prompt request.")
        return None
    base64_image = base64.b64encode(jpeg).decode()
    api_key = os.getenv("GPT_KEY")
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {
        "model": "gpt-4o-mini",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": instruction},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
                ],
            }
        ],
        "max_tokens": max_tokens,
    }
    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
            timeout=_GPT_TIMEOUT_SECONDS,
        )
        response_json = response.json()
    except (requests.RequestException, ValueError) as e:
        print("Prompt request failed:", e)
        return None
    choices = response_json.get("choices") if isinstance(response_json, dict) else None
    if not choices:
        print("Prompt request returned no choices.")
        return None
    content = (choices[0].get("message") or {}).get("content")
    if isinstance(content, str) and content.strip():
        return content.strip()
    return None


def _expanded_subject_mask(image_rgb, gen_width, gen_height):
    """Alpha channel of rembg's background removal, dilated by min(gen_width, gen_height) // 15."""
    cutout = remove(Image.fromarray(image_rgb))
    initial_mask = cutout.split()[-1].convert('L')
    return expand_mask(initial_mask, min(gen_width, gen_height) // 15)


def _generate_images(prompts, pose_pil, gen_width, gen_height, source_image=None, mask_image=None):
    """
    Run one generation per prompt and return the first image of each run.

    Without ``mask_image`` this calls ``pipe_controlnet`` with the pose image
    only; with it, ``pipe_inpaint_controlnet`` receives ``source_image``,
    ``mask_image`` and the pose image as ``control_image``.
    """
    generated_images = []
    with torch.no_grad(), torch.autocast(device_type=device.type):
        for prompt in prompts:
            if mask_image is None:
                output = pipe_controlnet(
                    prompt=prompt,
                    negative_prompt=_NEGATIVE_PROMPT,
                    num_inference_steps=20,
                    image=pose_pil,
                    guidance_scale=5,
                    width=gen_width,
                    height=gen_height,
                )
            else:
                output = pipe_inpaint_controlnet(
                    prompt=prompt,
                    negative_prompt=_NEGATIVE_PROMPT,
                    num_inference_steps=20,
                    image=source_image,
                    mask_image=mask_image,
                    control_image=pose_pil,
                    width=gen_width,
                    height=gen_height,
                    guidance_scale=5,
                )
            generated_images.append(output.images[0])
    return generated_images


def spirit_animal_baseline(image_path, num_images=4):
    """
    Generate ``num_images`` images of one GPT-chosen spirit animal in the person's pose.

    The image is cropped around the detected face (falling back to the full
    image), and only the pose skeleton conditions generation, so the original
    background is not kept.

    Returns:
        (prompt, images): the prompt used and a list of PIL images, or an error
        message and an empty list when the image cannot be read or no pose is found.
    """
    image_rgb = _load_image_rgb(image_path)
    if image_rgb is None:
        return "Could not read the input image.", []
    image_rgb = _crop_around_face(image_rgb)
    gen_width, gen_height = _generation_size(image_rgb.shape[1], image_rgb.shape[0])
    pose_pil = _pose_control_image(image_rgb, gen_width, gen_height)
    if pose_pil is None:
        return "No pose detected.", []
    prompt = _request_gpt_text(image_rgb, _SINGLE_ANIMAL_INSTRUCTION, max_tokens=100) or _FALLBACK_PROMPT
    generated_images = _generate_images([prompt] * num_images, pose_pil, gen_width, gen_height)
    return prompt, generated_images


def spirit_animal_with_background(image_path, num_images=4):
    """
    Like ``spirit_animal_baseline`` but keeps the photo's background.

    The full image (no face crop) is passed to the inpainting pipeline together
    with a dilated rembg subject mask, so only the masked subject area is
    regenerated.

    Returns:
        (prompt, images), or an error message and an empty list as in
        ``spirit_animal_baseline``.
    """
    image_rgb = _load_image_rgb(image_path)
    if image_rgb is None:
        return "Could not read the input image.", []
    gen_width, gen_height = _generation_size(image_rgb.shape[1], image_rgb.shape[0])
    pose_pil = _pose_control_image(image_rgb, gen_width, gen_height)
    if pose_pil is None:
        return "No pose detected.", []
    prompt = _request_gpt_text(image_rgb, _SINGLE_ANIMAL_INSTRUCTION, max_tokens=100) or _FALLBACK_PROMPT
    mask = _expanded_subject_mask(image_rgb, gen_width, gen_height)
    generated_images = _generate_images(
        [prompt] * num_images,
        pose_pil,
        gen_width,
        gen_height,
        source_image=Image.fromarray(image_rgb),
        mask_image=mask,
    )
    return prompt, generated_images


def generate_multiple_animals(image_path, keep_background=True, num_images=4, height_multiplier=1.5):
    """
    Generate one image for each of up to ``num_images`` different GPT-chosen animals.

    The image is cropped around the face with ``height_multiplier`` (falling
    back to the full image). GPT's answer is split into sentences, one prompt
    per sentence, capped at ``num_images``; if no prompt is obtained,
    ``num_images`` copies of a generic fallback prompt are used.

    Args:
        image_path: path of the input photo.
        keep_background: inpaint over the photo (True) or generate from the pose only.
        num_images: number of animals / images requested.
        height_multiplier: forwarded to ``crop_face_to_square``.

    Returns:
        (formatted_prompts, images): a numbered, newline-separated prompt list and
        the generated PIL images, or an error message and an empty list.
    """
    image_rgb = _load_image_rgb(image_path)
    if image_rgb is None:
        return "Could not read the input image.", []
    image_rgb = _crop_around_face(image_rgb, height_multiplier=height_multiplier)
    gen_width, gen_height = _generation_size(image_rgb.shape[1], image_rgb.shape[0])
    # Detect the pose before calling GPT so no API request is spent on unusable images.
    pose_pil = _pose_control_image(image_rgb, gen_width, gen_height)
    if pose_pil is None:
        return "No pose detected.", []
    content = _request_gpt_text(image_rgb, _multiple_animals_instruction(num_images), max_tokens=500)
    prompts = [part.strip() for part in (content or "").split('.') if part.strip()][:num_images]
    if not prompts:
        prompts = [_FALLBACK_PROMPT] * num_images
    formatted_prompts = "\n".join(f"{i + 1}. {prompt}" for i, prompt in enumerate(prompts))
    if keep_background:
        generated_images = _generate_images(
            prompts,
            pose_pil,
            gen_width,
            gen_height,
            source_image=Image.fromarray(image_rgb),
            mask_image=_expanded_subject_mask(image_rgb, gen_width, gen_height),
        )
    else:
        generated_images = _generate_images(prompts, pose_pil, gen_width, gen_height)
    return formatted_prompts, generated_images
def wait_for_file(file_path, timeout=500, poll_interval=0.5):
    """
    Wait for a file to be created, with a specified timeout.

    Args:
        file_path (str): The path of the file to wait for.
        timeout (float): Maximum time to wait in seconds.
        poll_interval (float): Seconds to sleep between existence checks.

    Returns:
        bool: True if the file exists (checked before every sleep), False if
        more than ``timeout`` seconds elapse first.
    """
    # A monotonic clock is unaffected by wall-clock adjustments, which could
    # otherwise shorten or extend the wait.
    start_time = time.monotonic()
    while not os.path.exists(file_path):
        if time.monotonic() - start_time > timeout:
            return False
        time.sleep(poll_interval)
    return True
def generate_spirit_animal_video(driving_video_path):
    """
    Animate a generated spirit animal with the motion of a driving video.

    Steps:
      1. Save the first frame of ``driving_video_path`` to ./first_frame.jpg.
      2. Generate one spirit-animal image from that frame with
         ``generate_multiple_animals`` (background preserved,
         ``height_multiplier=1``) and save it to ./animal.jpeg.
      3. Run LivePortrait's ``inference_animals.py`` with the image as source
         and the video as driver.

    Returns:
        The path of the generated video, or None on any failure (the reason is printed).
    """
    # NOTE(review): presumably LivePortrait names its output
    # "<source stem>--<driving stem>.mp4" under ./animations, which matches this
    # path only for ./animal.jpeg and ./uploaded_video_compressed.mp4 — confirm.
    output_path = "./animations/animal--uploaded_video_compressed.mp4"
    try:
        # Step 1: Extract the first frame
        cap = cv2.VideoCapture(driving_video_path)
        try:
            if not cap.isOpened():
                print("Error: Unable to open video.")
                return None
            ret, frame = cap.read()
        finally:
            cap.release()
        if not ret:
            print("Error: Unable to read the first frame.")
            return None
        first_frame_path = "./first_frame.jpg"
        cv2.imwrite(first_frame_path, frame)
        print(f"First frame saved to: {first_frame_path}")

        # Step 2: Generate the spirit animal image (PIL RGB -> OpenCV BGR for imwrite)
        _, input_image = generate_multiple_animals(first_frame_path, True, 1, height_multiplier=1)
        if not input_image:
            print("Error: Spirit animal generation failed.")
            return None
        spirit_animal_path = "./animal.jpeg"
        cv2.imwrite(spirit_animal_path, cv2.cvtColor(np.array(input_image[0]), cv2.COLOR_RGB2BGR))
        print(f"Spirit animal image saved to: {spirit_animal_path}")

        # Step 3: Run inference
        script_path = os.path.abspath("./LivePortrait/inference_animals.py")
        if not os.path.exists(script_path):
            print(f"Error: Inference script not found at {script_path}.")
            return None
        # Remove output left by a previous run so a stale video is never reported
        # as this run's result.
        if os.path.exists(output_path):
            os.remove(output_path)
        command = [
            sys.executable, script_path,
            "-s", spirit_animal_path,
            "-d", driving_video_path,
            "--driving_multiplier", "1.75",
            "--no_flag_stitching",
        ]
        print(f"Running command: {' '.join(command)}")
        result = subprocess.run(command)
        if result.returncode != 0:
            print(f"Error: Command failed with exit code {result.returncode}.")
            return None
        if not os.path.exists(output_path):
            print(f"Error: Expected output video not found at {output_path}.")
            return None
        print(f"Output video generated at: {output_path}")
        return output_path
    except Exception as e:
        # UI boundary: report the failure and let the caller hide the output.
        print(f"Error occurred: {e}")
        return None
def generate_spirit_animal(image, animal_type, background):
    """
    Dispatch the image tab's options to the matching generator.

    Args:
        image: path of the uploaded image, or None when nothing was uploaded.
        animal_type: "Single Animal" or "Multiple Animals".
        background: "Preserve Background" keeps the photo's background; any
            other value generates from the pose only.

    Returns:
        (prompt_text, images) from the selected generator, or a message and an
        empty list for a missing image or an unknown ``animal_type``.
    """
    if image is None:
        return "Please upload an image.", []
    keep_background = background == "Preserve Background"
    if animal_type == "Single Animal":
        if keep_background:
            return spirit_animal_with_background(image)
        return spirit_animal_baseline(image)
    if animal_type == "Multiple Animals":
        return generate_multiple_animals(image, keep_background=keep_background)
    return f"Unknown animal type: {animal_type}", []
def compress_video(input_path, output_path, target_size_mb):
    """
    Re-encode ``input_path`` to MP4 and shrink it towards ``target_size_mb``.

    The video is first rewritten frame by frame with OpenCV (codec 'mp4v') to
    ./temp_compressed.mp4. If that file is within the target size it is moved
    to ``output_path``. Otherwise ffmpeg re-encodes it with a video bitrate of
    target_bits / duration, so the result approximately meets the target.
    If ffmpeg cannot be run or fails, the uncompressed re-encode is used
    instead, so ``output_path`` exists on return.

    Raises:
        OSError: if ``input_path`` cannot be opened as a video.
    """
    target_size_bytes = target_size_mb * 1024 * 1024
    temp_output = "./temp_compressed.mp4"

    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        cap.release()
        raise OSError(f"Unable to open video: {input_path}")
    # Keep the fractional frame rate (e.g. 29.97): truncating it to an int would
    # change the playback speed. Some files report 0/NaN, so fall back to 25
    # (an arbitrary default) to keep the writer usable.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        fps = 25.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))
    frame_count = 0
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            writer.write(frame)
            frame_count += 1
    finally:
        cap.release()
        writer.release()

    current_size = os.path.getsize(temp_output)
    if current_size <= target_size_bytes or frame_count == 0:
        shutil.move(temp_output, output_path)
        return

    # A bitrate is bits per second, so spread the target size over the clip's
    # duration (the previous formula ignored duration entirely).
    duration_seconds = frame_count / fps
    bitrate = int(target_size_bytes * 8 / duration_seconds)
    command = ["ffmpeg", "-i", temp_output, "-b:v", str(bitrate), "-y", output_path]
    try:
        succeeded = subprocess.run(command).returncode == 0
    except OSError as e:
        print("Could not run ffmpeg:", e)
        succeeded = False
    if succeeded:
        os.remove(temp_output)
    else:
        print("ffmpeg compression failed; using the uncompressed re-encode.")
        shutil.move(temp_output, output_path)
def process_video(video_file):
    """
    Gradio handler for the video tab.

    Compresses the upload to ./uploaded_video_compressed.mp4, runs
    ``generate_spirit_animal_video`` on it and returns a ``gr.update`` that
    shows the generated video, or hides the output when there is no upload or
    generation fails.
    """
    if video_file is None:
        print("No video uploaded.")
        return gr.update(value=None, visible=False)
    compressed_path = "./uploaded_video_compressed.mp4"
    compress_video(video_file, compressed_path, target_size_mb=1)
    print(f"Compressed and moved video to: {compressed_path}")
    # Use the path reported by the generator: it returns None on failure, so a
    # known failure no longer waits out the timeout or serves a stale file.
    output_video_path = generate_spirit_animal_video(compressed_path)
    if output_video_path is None:
        print("Video generation failed.")
        return gr.update(value=None, visible=False)
    timeout = 1000  # Timeout in seconds
    if not wait_for_file(output_video_path, timeout=timeout):
        print("Timeout occurred while waiting for video generation.")
        return gr.update(value=None, visible=False)
    print(f"Output video is ready: {output_video_path}")
    return gr.update(value=output_video_path, visible=True)
# Custom stylesheet passed to gr.Blocks(css=css).
# NOTE(review): only #title-container is referenced in this file (by title_html);
# #intro-text, #prompt-output and .examples-container are not attached to any
# component via elem_id/elem_classes here, so those rules currently have no effect.
css = """
#title-container {
font-family: 'Arial', sans-serif;
color: #4a4a4a;
text-align: center;
margin-bottom: 20px;
}
#title-container h1 {
font-size: 2.5em;
font-weight: bold;
color: #ff9900;
}
#title-container h2 {
font-size: 1.2em;
color: #6c757d;
}
#intro-text {
font-size: 1em;
color: #6c757d;
margin: 50px;
text-align: center;
font-style: italic;
}
#prompt-output {
font-family: 'Courier New', monospace;
color: #5a5a5a;
font-size: 1.1em;
padding: 10px;
background-color: #f9f9f9;
border: 1px solid #ddd;
border-radius: 5px;
margin-top: 10px;
}
.examples-container {
display: flex;
flex-wrap: wrap;
gap: 10px;
justify-content: center;
align-items: flex-start;
}
"""
# Title and description
# HTML header rendered with gr.HTML; the #title-container id is styled by `css`.
title_html = """
<div id="title-container">
<h1>Spirit Animal Generator</h1>
<h2>Create your unique spirit animal with AI-assisted image generation.</h2>
</div>
"""
# Markdown shown under the title. Blank lines separate the blocks: a `---`
# directly under a paragraph would be parsed as a setext heading underline
# instead of a horizontal rule, and nested bullets need indentation to nest.
description_text = """
### Project Overview
Welcome to the Spirit Animal Generator! This tool leverages Stable Diffusion models to create unique visualizations of spirit animals from videos and images.

#### Key Features:
1. **Prompting**: [GPT Model](https://arxiv.org/abs/2305.10435) generates descriptive prompts for each media input.
2. **Image Creation**: [ControlNet Model](https://arxiv.org/abs/2302.05543) generates animal images with pose control.
3. **Video Transformation**: [LivePortrait Model](https://arxiv.org/abs/2407.03168) generates animal animations with the same facial expressions.

---

### How It Works:
1. **Upload Your Media**:
   - Images: Use clear, high-resolution photos for better results.
   - Videos: Ensure the file is in MP4 format.
2. **Customize Options**:
   - For images, select the type of animal and background settings.
3. **View Your Results**:
   - Images will produce customized visual art along with a generated prompt.
   - Videos will be transformed into animal animations.

Discover your spirit animal and let your imagination run wild!

---
"""
# Gradio UI: one tab generates spirit-animal images, the other animates a video.
with gr.Blocks(css=css) as demo:
    gr.HTML(title_html)
    gr.Markdown(description_text)
    with gr.Tabs():
        # Image tab: generate_spirit_animal(image, animal type, background option)
        # fills the prompt textbox and the gallery.
        with gr.Tab("Generate Spirit Animal Image"):
            gr.Markdown("Upload an image to generate a spirit animal.")
            with gr.Row():
                with gr.Column(scale=1):
                    # The handlers receive a file path, which the generators read with cv2.imread.
                    image_input = gr.Image(type="filepath", label="Upload an image")
                    # These choice strings must match the comparisons in generate_spirit_animal().
                    animal_type = gr.Radio(choices=["Single Animal", "Multiple Animals"], label="Animal Type", value="Single Animal")
                    background_option = gr.Radio(choices=["Preserve Background", "Don't Preserve Background"], label="Background Option", value="Preserve Background")
                    generate_image_button = gr.Button("Generate Image")
                    # NOTE(review): example paths are relative to the working
                    # directory (the project root) — confirm the files ship with the Space.
                    gr.Examples(
                        examples=["example1.jpg", "example2.jpg", "example3.jpg"],
                        inputs=image_input,
                        label="Example Images"
                    )
                with gr.Column(scale=1):
                    generated_prompt = gr.Textbox(label="Generated Prompt")
                    generated_gallery = gr.Gallery(label="Generated Images")
            generate_image_button.click(
                fn=generate_spirit_animal,
                inputs=[image_input, animal_type, background_option],
                outputs=[generated_prompt, generated_gallery],
            )
        # Video tab: process_video(video path) returns a gr.update for the output player.
        with gr.Tab("Generate Spirit Animal Video"):
            gr.Markdown("Upload a driving video to generate a spirit animal video.")
            with gr.Row():
                with gr.Column(scale=1):
                    video_input = gr.Video(label="Upload a driving video (MP4 format)")
                    generate_video_button = gr.Button("Generate Video")
                    gr.Examples(
                        examples=["video1.mp4", "video3.mp4", "video4.mp4"],
                        inputs=video_input,
                        label="Example Videos"
                    )
                with gr.Column(scale=1):
                    video_output = gr.Video(label="Generated Spirit Animal Video")
            generate_video_button.click(
                fn=process_video,
                inputs=video_input,
                outputs=video_output,
            )
# Starts the server at import time (no `if __name__ == "__main__"` guard);
# this script is the app entry point.
demo.launch()