|
import torch |
|
from diffusers import DiffusionPipeline, StableDiffusionImageVariationPipeline |
|
from PIL import Image |
|
import numpy as np |
|
import cv2 |
|
|
|
class BootyShakerPipeline:
    """Video-generation wrapper around two diffusers pipelines.

    Provides text-to-video and image-to-video generation from the
    ``ChromiumPlutoniumAI/BootyShakerAI`` checkpoint, simple deterministic
    post-processing effects, and MP4 encoding via OpenCV.
    """

    def __init__(self):
        # Choose the device FIRST so both pipelines can be moved onto it.
        # (Previously the device was selected but never used, so the
        # pipelines silently ran on CPU even when CUDA was available.)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.txt2video_pipe = DiffusionPipeline.from_pretrained(
            "ChromiumPlutoniumAI/BootyShakerAI"
        ).to(self.device)
        self.img2video_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
            "ChromiumPlutoniumAI/BootyShakerAI"
        ).to(self.device)

    def generate_from_text(self, prompt, num_frames=16, fps=8):
        """Generate a video from a text prompt.

        Args:
            prompt: Text description of the desired video.
            num_frames: Number of frames to generate.
            fps: Frame rate of the encoded output file.

        Returns:
            Path to the written MP4 file.
        """
        video = self.txt2video_pipe(
            prompt,
            num_inference_steps=50,
            num_frames=num_frames,
        ).frames
        return self.frames_to_video(video, fps)

    def generate_from_image(self, image, num_frames=16, fps=8):
        """Generate a video from a source image (PIL image or file path).

        Args:
            image: A PIL image, or a path to an image file.
            num_frames: Number of frames to generate.
            fps: Frame rate of the encoded output file.

        Returns:
            Path to the written MP4 file.
        """
        if isinstance(image, str):
            # Force RGB so palette/RGBA/grayscale files don't surprise the
            # pipeline or the later RGB->BGR conversion.
            image = Image.open(image).convert("RGB")
        video = self.img2video_pipe(
            image,
            num_inference_steps=50,
            num_frames=num_frames,
        ).frames
        return self.frames_to_video(video, fps)

    def apply_alterations(self, video, style="bounce"):
        """Apply a named motion effect to a sequence of frames.

        Args:
            video: Sequence of frames (numpy arrays or PIL images, H x W x C).
            style: One of ``"bounce"``, ``"wave"``, ``"shake"``.

        Returns:
            List of numpy arrays, one per input frame, same shape as input.

        Raises:
            ValueError: If *style* is not a known effect name.
        """
        styles = {
            "bounce": self.bounce_effect,
            "wave": self.wave_effect,
            "shake": self.shake_effect,
        }
        try:
            effect = styles[style]
        except KeyError:
            raise ValueError(
                f"Unknown style {style!r}; expected one of {sorted(styles)}"
            ) from None
        return effect(video)

    @staticmethod
    def bounce_effect(frames):
        """Shift each frame vertically along one sine period (frame 0 unshifted)."""
        n = max(len(frames), 1)
        out = []
        for i, frame in enumerate(frames):
            arr = np.asarray(frame)
            # Amplitude is 10% of the frame height.
            offset = int(round(arr.shape[0] * 0.1 * np.sin(2 * np.pi * i / n)))
            out.append(np.roll(arr, offset, axis=0))
        return out

    @staticmethod
    def wave_effect(frames):
        """Warp each frame with a row-wise horizontal sine shift whose phase advances per frame."""
        n = max(len(frames), 1)
        out = []
        for i, frame in enumerate(frames):
            arr = np.asarray(frame).copy()
            height, width = arr.shape[0], arr.shape[1]
            phase = 2 * np.pi * i / n
            for row in range(height):
                shift = int(round(width * 0.05 * np.sin(2 * np.pi * row / height + phase)))
                arr[row] = np.roll(arr[row], shift, axis=0)
            out.append(arr)
        return out

    @staticmethod
    def shake_effect(frames):
        """Jitter frames by a small deterministic alternating offset (no RNG, reproducible)."""
        out = []
        for i, frame in enumerate(frames):
            arr = np.asarray(frame)
            dx = ((-1) ** i) * 2          # alternate left/right every frame
            dy = ((-1) ** (i // 2)) * 2   # alternate up/down every two frames
            out.append(np.roll(np.roll(arr, dy, axis=0), dx, axis=1))
        return out

    @staticmethod
    def frames_to_video(frames, fps):
        """Encode RGB frames to ``output.mp4``.

        Args:
            frames: Non-empty sequence of RGB frames (PIL images, or numpy
                arrays in uint8 or float [0, 1]).
                NOTE(review): newer diffusers versions return ``.frames`` as a
                list of lists (one per prompt) — confirm against the installed
                version and unwrap upstream if needed.
            fps: Frame rate for the encoded file.

        Returns:
            The output file path, ``"output.mp4"``.

        Raises:
            ValueError: If *frames* is empty.
        """
        if len(frames) == 0:
            raise ValueError("frames must be a non-empty sequence")

        def _to_bgr_uint8(frame):
            # Accept PIL images and float arrays; cv2 requires uint8 BGR.
            arr = np.asarray(frame)
            if arr.dtype != np.uint8:
                arr = (np.clip(arr, 0.0, 1.0) * 255).round().astype(np.uint8)
            return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)

        first = np.asarray(frames[0])
        height, width = first.shape[0], first.shape[1]
        output_file = "output.mp4"
        writer = cv2.VideoWriter(
            output_file,
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps,
            (width, height),
        )
        try:
            for frame in frames:
                writer.write(_to_bgr_uint8(frame))
        finally:
            # Always release the writer, even if a frame conversion fails,
            # so the file handle isn't leaked and the MP4 is finalized.
            writer.release()
        return output_file
|
|