import base64
import gc
import math
import os
import random
import tempfile
import threading
import time
from datetime import datetime, timedelta
from io import BytesIO

import cv2
import gradio as gr
import imageio_ffmpeg
import moviepy.editor as mp
import numpy as np
import requests
import torch
import torchaudio
from diffusers import CogVideoXDPMScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from mistralai import Mistral
from PIL import Image
from transformers import AutoProcessor, MusicgenForConditionalGeneration, pipeline

from CogVideoX.pipeline_rgba import CogVideoXPipeline
from CogVideoX.rgba_utils import *  # supplies prepare_for_rgba_inference and decode_latents

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
# MusicgenForConditionalGeneration (not MusicgenForCausalLM) is the class whose
# generate() returns audio waveforms; the checkpoint's own config already
# carries the audio-encoder settings (32 kHz, mono), so no overrides are needed.
music_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
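

# Chat models selectable from the "Chatbot" tab dropdown.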
CHATBOT_MODELS = {
    "DialoGPT (Medium)": "microsoft/DialoGPT-medium",
    "BlenderBot (Small)": "facebook/blenderbot_small-90M",
    "GPT-Neo (125M)": "EleutherAI/gpt-neo-125M",
}


_chatbot_pipelines = {}


def load_chatbot_model(model_name):
    """Build (and cache) a generation pipeline for the selected chat model."""
    if model_name not in CHATBOT_MODELS:
        raise ValueError(f"Model {model_name} not found.")
    if model_name not in _chatbot_pipelines:
        model_path = CHATBOT_MODELS[model_name]
        # The "conversational" task was removed from transformers; BlenderBot is
        # an encoder-decoder model and needs the text2text-generation task,
        # while the other choices are causal LMs.
        task = "text2text-generation" if "blenderbot" in model_path else "text-generation"
        _chatbot_pipelines[model_name] = pipeline(task, model=model_path)
    return _chatbot_pipelines[model_name]
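

# TransPixar: CogVideoX-5B patched with an RGBA LoRA so that one prompt yields
# matched RGB and alpha-channel video streams.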
hf_hub_download(
    repo_id="wileewang/TransPixar",
    filename="cogvideox_rgba_lora.safetensors",
    local_dir="model_cogvideox_rgba_lora",
)
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5B", torch_dtype=torch.bfloat16)
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

# Latent sequence length for a 480x720, 13-frame clip, doubled because the
# RGB and alpha streams are denoised jointly.
seq_length = 2 * (
    (480 // pipe.vae_scale_factor_spatial // 2)
    * (720 // pipe.vae_scale_factor_spatial // 2)
    * ((13 - 1) // pipe.vae_scale_factor_temporal + 1)
)

prepare_for_rgba_inference(
    pipe.transformer,
    rgba_weights_path="model_cogvideox_rgba_lora/cogvideox_rgba_lora.safetensors",
    device=device,
    dtype=torch.bfloat16,
    text_length=226,
    seq_length=seq_length,
)

os.makedirs("./output", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)
def generate_music_function(prompt, length, genre, custom_genre, lyrics):
    selected_genre = custom_genre if custom_genre else genre
    input_text = f"{prompt}. Genre: {selected_genre}."
    if lyrics:
        input_text += f" Lyrics: {lyrics}"
    inputs = processor(
        text=[input_text],
        padding=True,
        return_tensors="pt",
    )
    audio_values = music_model.generate(**inputs, max_new_tokens=int(length * 50))
    output_file = "generated_music.wav"
    sampling_rate = music_model.config.audio_encoder.sampling_rate
    torchaudio.save(output_file, audio_values[0].cpu(), sampling_rate)
    return output_file
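

# One chat turn: generate a reply with the selected model and append the
# (user, bot) pair to the visible history.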
def chatbot_interaction(user_input, history, model_name):
    chatbot_pipeline = load_chatbot_model(model_name)
    response = chatbot_pipeline(user_input, max_new_tokens=128)[0]["generated_text"]
    history.append((user_input, response))
    return history, history
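

# Text-to-RGBA video: run the LoRA-patched pipeline in latent space, split the
# latents into RGB and alpha halves, decode both, and save two MP4 previews.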
def generate_video_function(prompt, seed_value):
    if seed_value == -1:
        seed_value = random.randint(0, 2**8 - 1)
    pipe.to(device)
    video_pt = pipe(
        prompt=prompt + ", isolated background",
        num_videos_per_prompt=1,
        num_inference_steps=25,
        num_frames=13,
        use_dynamic_cfg=True,
        output_type="latent",
        guidance_scale=7.0,
        generator=torch.Generator(device=device).manual_seed(int(seed_value)),
    ).frames
    # The joint model stacks RGB and alpha latents along the frame axis.
    latents_rgb, latents_alpha = video_pt.chunk(2, dim=1)
    frames_rgb = decode_latents(pipe, latents_rgb)
    frames_alpha = decode_latents(pipe, latents_alpha)
    # Pool the decoded alpha to a single matte and premultiply the RGB frames.
    pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True)
    frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1)
    premultiplied_rgb = frames_rgb * frames_alpha_pooled
    rgb_video_path = save_video(premultiplied_rgb[0], fps=8, prefix="rgb")
    alpha_video_path = save_video(frames_alpha_pooled[0], fps=8, prefix="alpha")
    # Move the pipeline off the GPU and reclaim memory between requests.
    pipe.to("cpu")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return rgb_video_path, alpha_video_path, seed_value
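

# Write a frame stack to ./output with a timestamped filename.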
def save_video(tensor, fps=8, prefix="rgb"):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"./output/{prefix}_{timestamp}.mp4"
    export_to_video(tensor, video_path, fps=fps)
    return video_path
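

# The IC Light and Text3D tabs delegate to code injected at deploy time
# through environment variables (EXEC and APP respectively).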
def ic_light_tool():
    code = os.getenv("EXEC")
    # Guard against an unset variable; exec(None) would raise a TypeError.
    if code:
        exec(code)
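

# Image captioning for the "Image to Flux Prompt" tab, backed by Pixtral via
# the Mistral API.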
api_key = os.getenv("MISTRAL_API_KEY")
mistral_client = Mistral(api_key=api_key)


def encode_image(image_path):
    """Resize the image to a 512 px height and return it as a base64-encoded JPEG."""
    try:
        image = Image.open(image_path).convert("RGB")
        # Scale to a fixed height while preserving the aspect ratio.
        base_height = 512
        h_percent = base_height / float(image.size[1])
        w_size = int(float(image.size[0]) * h_percent)
        image = image.resize((w_size, base_height), Image.LANCZOS)
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except FileNotFoundError:
        print(f"Error: The file {image_path} was not found.")
        return None
    except Exception as e:
        print(f"Error: {e}")
        return None


def feifeichat(image):
    """Stream a detailed Pixtral description of the uploaded image."""
    try:
        model = "pixtral-large-2411"
        base64_image = encode_image(image)
        if base64_image is None:
            yield "Please upload a photo"
            return
        messages = [{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Please provide a detailed description of this photo",
                },
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{base64_image}",
                },
            ],
        }]

        partial_message = ""
        for chunk in mistral_client.chat.stream(model=model, messages=messages):
            if chunk.data.choices[0].delta.content is not None:
                partial_message += chunk.data.choices[0].delta.content
                yield partial_message
    except Exception as e:
        print(f"Error: {e}")
        # A `return <value>` is discarded inside a generator, so yield the
        # fallback message instead.
        yield "Please upload a photo"


def text3d_tool():
    code = os.environ.get("APP")
    # Same env-var delegation pattern as ic_light_tool, using APP.
    if code:
        exec(code)
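

# Gradio UI: one tab per tool.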
with gr.Blocks(theme="gstaff/sketch") as demo:
    with gr.Row(equal_height=True):
        gr.Markdown("# Multi-Tool Interface: Chatbot, Music, Transpixar, IC Light, Image to Flux Prompt, and Text3D")

    with gr.Tab("Chatbot"):
        chatbot_state = gr.State([])
        chatbot_model = gr.Dropdown(
            choices=list(CHATBOT_MODELS.keys()),
            label="Select Chatbot Model",
            value="DialoGPT (Medium)",
        )
        chatbot_output = gr.Chatbot()
        chatbot_input = gr.Textbox(label="Your Message")
        chatbot_button = gr.Button("Send")
        chatbot_button.click(
            chatbot_interaction,
            inputs=[chatbot_input, chatbot_state, chatbot_model],
            outputs=[chatbot_output, chatbot_state],
        )

    with gr.Tab("Music Generation"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Enter a prompt for music generation", placeholder="e.g., A joyful melody for a sunny day")
                length = gr.Slider(minimum=1, maximum=10, value=5, label="Length (seconds)")
                genre = gr.Dropdown(
                    choices=["Pop", "Rock", "Classical", "Jazz", "Electronic", "Hip-Hop", "Country"],
                    label="Select Genre",
                    value="Pop",
                )
                custom_genre = gr.Textbox(label="Or enter a custom genre", placeholder="e.g., Reggae, K-Pop, etc.")
                lyrics = gr.Textbox(label="Enter lyrics (optional)", placeholder="e.g., La la la...")
                generate_music_button = gr.Button("Generate Music")
            with gr.Column():
                music_output = gr.Audio(label="Generated Music")
        generate_music_button.click(
            generate_music_function,
            inputs=[prompt, length, genre, custom_genre, lyrics],
            outputs=music_output,
        )

    with gr.Tab("Transpixar"):
        with gr.Row():
            with gr.Column():
                video_prompt = gr.Textbox(label="Enter a prompt for video generation", placeholder="e.g., A futuristic cityscape at night")
                seed_value = gr.Number(label="Inference Seed (Enter a positive number, -1 for random)", value=-1)
                generate_video_button = gr.Button("Generate Video")
            with gr.Column():
                rgb_video_output = gr.Video(label="Generated RGB Video", width=720, height=480)
                alpha_video_output = gr.Video(label="Generated Alpha Video", width=720, height=480)
                seed_text = gr.Number(label="Seed Used for Video Generation", visible=False)
        generate_video_button.click(
            generate_video_function,
            inputs=[video_prompt, seed_value],
            outputs=[rgb_video_output, alpha_video_output, seed_text],
        )

    with gr.Tab("IC Light"):
        gr.Markdown("### IC Light Tool")
        ic_light_button = gr.Button("Run IC Light")
        ic_light_output = gr.Textbox(label="IC Light Output", interactive=False)
        ic_light_button.click(
            ic_light_tool,
            outputs=ic_light_output,
        )

    with gr.Tab("Image to Flux Prompt"):
        gr.Markdown("### Image to Flux Prompt")
        input_img = gr.Image(label="Input Picture", height=320, type="filepath")
        submit_btn = gr.Button(value="Submit")
        output_text = gr.Textbox(label="Flux Prompt")
        submit_btn.click(feifeichat, [input_img], [output_text])

    with gr.Tab("Text3D"):
        gr.Markdown("### Text3D Tool")
        text3d_button = gr.Button("Run Text3D")
        text3d_output = gr.Textbox(label="Text3D Output", interactive=False)
        text3d_button.click(
            text3d_tool,
            outputs=text3d_output,
        )

demo.launch()