import math
import os
import random
import threading
import time
import cv2
import tempfile
import imageio_ffmpeg
import gradio as gr
import torch
from PIL import Image
from transformers import pipeline, AutoProcessor, MusicgenForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
import torchaudio
import numpy as np
from datetime import datetime, timedelta
from CogVideoX.pipeline_rgba import CogVideoXPipeline
from CogVideoX.rgba_utils import *
from diffusers import CogVideoXDPMScheduler
from diffusers.utils import export_to_video
import moviepy.editor as mp
import gc
from io import BytesIO
import base64
import requests
from mistralai import Mistral
from huggingface_hub import hf_hub_download

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load MusicGen model for music generation (the conditional-generation class
# accepts text inputs and returns audio waveforms directly)
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")

# Chatbot models
CHATBOT_MODELS = {
    "DialoGPT (Medium)": "microsoft/DialoGPT-medium",
    "BlenderBot (Small)": "facebook/blenderbot_small-90M",
    "GPT-Neo (125M)": "EleutherAI/gpt-neo-125M",
    # Add more models here
}

# Initialize chatbot
def load_chatbot_model(model_name):
    if model_name not in CHATBOT_MODELS:
        raise ValueError(f"Model {model_name} not found.")
    model_path = CHATBOT_MODELS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # The task is inferred from the checkpoint: text-generation for DialoGPT
    # and GPT-Neo, text2text-generation for BlenderBot (a seq2seq model).
    # Both tasks return [{"generated_text": ...}].
    return pipeline(model=model_path, tokenizer=tokenizer)
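
# Optional sketch (not part of the original flow): chatbot_interaction below
# reloads the selected checkpoint on every message. A memoized wrapper such as
# this hypothetical load_chatbot_model_cached keeps one pipeline per model
# name in memory, at the cost of holding every loaded checkpoint in RAM.
_chatbot_cache = {}

def load_chatbot_model_cached(model_name):
    # Reuse an already-constructed pipeline when the same model is picked again
    if model_name not in _chatbot_cache:
        _chatbot_cache[model_name] = load_chatbot_model(model_name)
    return _chatbot_cache[model_name]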
# Load CogVideoX-5B model for video generation
hf_hub_download(
    repo_id="wileewang/TransPixar",
    filename="cogvideox_rgba_lora.safetensors",
    local_dir="model_cogvideox_rgba_lora",
)
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5B", torch_dtype=torch.bfloat16)
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

# Token count for the joint RGB + alpha sequence: two streams of patchified
# spatial tokens across the temporally compressed frames (13 frames at 480x720)
seq_length = 2 * (
    (480 // pipe.vae_scale_factor_spatial // 2)
    * (720 // pipe.vae_scale_factor_spatial // 2)
    * ((13 - 1) // pipe.vae_scale_factor_temporal + 1)
)

prepare_for_rgba_inference(
    pipe.transformer,
    rgba_weights_path="model_cogvideox_rgba_lora/cogvideox_rgba_lora.safetensors",
    device=device,
    dtype=torch.bfloat16,
    text_length=226,
    seq_length=seq_length,
)

# Create output directories
os.makedirs("./output", exist_ok=True)
os.makedirs("./gradio_tmp", exist_ok=True)

# Music generation function using Facebook's MusicGen
def generate_music_function(prompt, length, genre, custom_genre, lyrics):
    selected_genre = custom_genre if custom_genre else genre
    input_text = f"{prompt}. Genre: {selected_genre}. Lyrics: {lyrics}"
    inputs = processor(
        text=[input_text],
        padding=True,
        return_tensors="pt",
    )
    # MusicGen produces roughly 50 audio tokens per second of output
    audio_values = model.generate(**inputs, max_new_tokens=int(length * 50))
    output_file = "generated_music.wav"
    sampling_rate = model.config.audio_encoder.sampling_rate
    torchaudio.save(output_file, audio_values[0].cpu(), sampling_rate)
    return output_file

# Chatbot interaction function
def chatbot_interaction(user_input, history, model_name):
    chatbot_pipeline = load_chatbot_model(model_name)
    response = chatbot_pipeline(user_input)[0]["generated_text"]
    history.append((user_input, response))
    return history, history

# CogVideoX-5B video generation function
def generate_video_function(prompt, seed_value):
    if seed_value == -1:
        seed_value = random.randint(0, 2**8 - 1)
    pipe.to(device)
    video_pt = pipe(
        prompt=prompt + ", isolated background",
        num_videos_per_prompt=1,
        num_inference_steps=25,
        num_frames=13,
        use_dynamic_cfg=True,
        output_type="latent",
        guidance_scale=7.0,
        generator=torch.Generator(device=device).manual_seed(int(seed_value)),
    ).frames
    # The latents carry the RGB and alpha streams stacked along dim 1
    latents_rgb, latents_alpha = video_pt.chunk(2, dim=1)
    frames_rgb = decode_latents(pipe, latents_rgb)
    frames_alpha = decode_latents(pipe, latents_alpha)
    # Collapse the decoded alpha to one channel and premultiply it into the RGB frames
    pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True)
    frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1)
    premultiplied_rgb = frames_rgb * frames_alpha_pooled
    rgb_video_path = save_video(premultiplied_rgb[0], fps=8, prefix="rgb")
    alpha_video_path = save_video(frames_alpha_pooled[0], fps=8, prefix="alpha")
    pipe.to("cpu")
    gc.collect()
    return rgb_video_path, alpha_video_path, seed_value

# Utility function to save video
def save_video(tensor, fps=8, prefix="rgb"):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    video_path = f"./output/{prefix}_{timestamp}.mp4"
    export_to_video(tensor, video_path, fps=fps)
    return video_path

# IC Light tool function
def ic_light_tool():
    # Execute the IC Light tool code supplied via the EXEC environment variable
    exec(os.getenv("EXEC"))

# Image to Flux Prompt functionality
api_key = os.getenv("MISTRAL_API_KEY")
Mistralclient = Mistral(api_key=api_key)

def encode_image(image_path):
    """Encode the image to base64."""
    try:
        # Open the image file
        image = Image.open(image_path).convert("RGB")

        # Resize the image to a height of 512 while maintaining the aspect ratio
        base_height = 512
        h_percent = base_height / float(image.size[1])
        w_size = int(float(image.size[0]) * h_percent)
        image = image.resize((w_size, base_height), Image.LANCZOS)

        # Convert the image to a base64-encoded JPEG byte stream
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except FileNotFoundError:
        print(f"Error: The file {image_path} was not found.")
        return None
    except Exception as e:
        print(f"Error: {e}")
        return None

def feifeichat(image):
    try:
        base64_image = encode_image(image)
        messages = [{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Please provide a detailed description of this photo",
                },
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{base64_image}",
                },
            ],
        }]

        # Stream the response and yield partial text so the textbox updates live
        partial_message = ""
        for chunk in Mistralclient.chat.stream(model="pixtral-large-2411", messages=messages):
            if chunk.data.choices[0].delta.content is not None:
                partial_message += chunk.data.choices[0].delta.content
                yield partial_message
    except Exception as e:
        print(f"Error: {e}")
        # feifeichat is a generator, so the fallback message must be yielded
        yield "Please upload a photo"
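
# Optional sketch: a non-streaming variant of the same request, assuming the
# mistralai SDK's chat.complete endpoint accepts the identical message format.
# feifeichat_once is a hypothetical helper and is not wired into the UI below.
def feifeichat_once(image):
    base64_image = encode_image(image)
    response = Mistralclient.chat.complete(
        model="pixtral-large-2411",
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "Please provide a detailed description of this photo"},
                {"type": "image_url", "image_url": f"data:image/jpeg;base64,{base64_image}"},
            ],
        }],
    )
    # Return the whole description at once instead of streaming chunks
    return response.choices[0].message.content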
# Text3D tool function
def text3d_tool():
    # Execute the Text3D tool code supplied via the APP environment variable
    exec(os.environ.get("APP"))

# Gradio interface with custom theme and an equal-height header row
with gr.Blocks(theme="gstaff/sketch") as demo:
    with gr.Row(equal_height=True):
        gr.Markdown("# Multi-Tool Interface: Chatbot, Music, Transpixar, IC Light, Image to Flux Prompt, and Text3D")

    # Chatbot Tab
    with gr.Tab("Chatbot"):
        chatbot_state = gr.State([])
        chatbot_model = gr.Dropdown(
            choices=list(CHATBOT_MODELS.keys()),
            label="Select Chatbot Model",
            value="DialoGPT (Medium)",
        )
        chatbot_output = gr.Chatbot()
        chatbot_input = gr.Textbox(label="Your Message")
        chatbot_button = gr.Button("Send")

        chatbot_button.click(
            chatbot_interaction,
            inputs=[chatbot_input, chatbot_state, chatbot_model],
            outputs=[chatbot_output, chatbot_state],
        )

    # Music Generation Tab
    with gr.Tab("Music Generation"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Enter a prompt for music generation", placeholder="e.g., A joyful melody for a sunny day")
                length = gr.Slider(minimum=1, maximum=10, value=5, label="Length (seconds)")
                genre = gr.Dropdown(
                    choices=["Pop", "Rock", "Classical", "Jazz", "Electronic", "Hip-Hop", "Country"],
                    label="Select Genre",
                    value="Pop",
                )
                custom_genre = gr.Textbox(label="Or enter a custom genre", placeholder="e.g., Reggae, K-Pop, etc.")
                lyrics = gr.Textbox(label="Enter lyrics (optional)", placeholder="e.g., La la la...")
                generate_music_button = gr.Button("Generate Music")
            with gr.Column():
                music_output = gr.Audio(label="Generated Music")

        generate_music_button.click(
            generate_music_function,
            inputs=[prompt, length, genre, custom_genre, lyrics],
            outputs=music_output,
        )

    # Transpixar Tab (formerly Video Generation)
    with gr.Tab("Transpixar"):
        with gr.Row():
            with gr.Column():
                video_prompt = gr.Textbox(label="Enter a prompt for video generation", placeholder="e.g., A futuristic cityscape at night")
                seed_value = gr.Number(label="Inference Seed (Enter a positive number, -1 for random)", value=-1)
                generate_video_button = gr.Button("Generate Video")
            with gr.Column():
                rgb_video_output = gr.Video(label="Generated RGB Video", width=720, height=480)
                alpha_video_output = gr.Video(label="Generated Alpha Video", width=720, height=480)
                seed_text = gr.Number(label="Seed Used for Video Generation", visible=False)

        generate_video_button.click(
            generate_video_function,
            inputs=[video_prompt, seed_value],
            outputs=[rgb_video_output, alpha_video_output, seed_text],
        )

    # IC Light Tab
    with gr.Tab("IC Light"):
        gr.Markdown("### IC Light Tool")
        ic_light_button = gr.Button("Run IC Light")
        ic_light_output = gr.Textbox(label="IC Light Output", interactive=False)

        ic_light_button.click(
            ic_light_tool,
            outputs=ic_light_output,
        )

    # Image to Flux Prompt Tab
    with gr.Tab("Image to Flux Prompt"):
        gr.Markdown("### Image to Flux Prompt")
        input_img = gr.Image(label="Input Picture", height=320, type="filepath")
        submit_btn = gr.Button(value="Submit")
        output_text = gr.Textbox(label="Flux Prompt")

        submit_btn.click(feifeichat, [input_img], [output_text])

    # Text3D Tab
    with gr.Tab("Text3D"):
        gr.Markdown("### Text3D Tool")
        text3d_button = gr.Button("Run Text3D")
        text3d_output = gr.Textbox(label="Text3D Output", interactive=False)

        text3d_button.click(
            text3d_tool,
            outputs=text3d_output,
        )

# Launch the Gradio app
demo.launch()
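
# Usage note (assumptions about the deployment environment): MISTRAL_API_KEY
# must be set before this script runs, since the Mistral client is created at
# module level, and the IC Light / Text3D tabs exec() the code held in the
# EXEC and APP environment variables at click time. By default Gradio serves
# the app at http://127.0.0.1:7860.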