Felguk committed on
Commit
2afded6
·
verified ·
1 Parent(s): 785289e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +294 -0
app.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import random
4
+ import threading
5
+ import time
6
+ import cv2
7
+ import tempfile
8
+ import imageio_ffmpeg
9
+ import gradio as gr
10
+ import torch
11
+ from PIL import Image
12
+ from transformers import pipeline, AutoProcessor, MusicgenForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
13
+ import torchaudio
14
+ import numpy as np
15
+ from datetime import datetime, timedelta
16
+ from CogVideoX.pipeline_rgba import CogVideoXPipeline
17
+ from CogVideoX.rgba_utils import *
18
+ from diffusers import CogVideoXDPMScheduler
19
+ from diffusers.utils import export_to_video
20
+ import moviepy.editor as mp
21
+ import gc
22
+ from io import BytesIO
23
+ import base64
24
+ import requests
25
+ from mistralai import Mistral
26
+
27
# Pick the compute device: the GPU when PyTorch reports CUDA, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# MusicGen (small checkpoint) for music generation: the text processor and the
# generation model are both loaded by checkpoint id via from_pretrained.
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
33
+
34
# Registry of selectable chatbot models.
# Key: label shown in the UI dropdown; value: model id passed to from_pretrained.
CHATBOT_MODELS = {
    "DialoGPT (Medium)": "microsoft/DialoGPT-medium",
    "BlenderBot (Small)": "facebook/blenderbot_small-90M",
    "GPT-Neo (125M)": "EleutherAI/gpt-neo-125M",
    # Add more models here
}
41
+
42
+ # Initialize chatbot
43
# Pipelines that have already been built, keyed by the display name used in
# CHATBOT_MODELS. Loading a tokenizer and model is expensive, so each one is
# built at most once per process instead of once per chat message.
_CHATBOT_PIPELINES = {}


def load_chatbot_model(model_name):
    """Return a text-generation pipeline for the chatbot named *model_name*.

    Args:
        model_name: A key of ``CHATBOT_MODELS`` (the label shown in the UI).

    Returns:
        A ``transformers`` pipeline. Calling it with a prompt string yields a
        list of dicts carrying a ``'generated_text'`` entry, which is the
        shape ``chatbot_interaction`` indexes into.

    Raises:
        ValueError: If *model_name* is not a key of ``CHATBOT_MODELS``.
    """
    if model_name not in CHATBOT_MODELS:
        raise ValueError(f"Model {model_name} not found.")

    cached = _CHATBOT_PIPELINES.get(model_name)
    if cached is not None:
        return cached

    model_path = CHATBOT_MODELS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    # "text-generation" (not "conversational"): the caller passes a plain
    # string and reads result[0]['generated_text'], which is the
    # text-generation contract. The "conversational" task expects Conversation
    # objects and is no longer available in recent transformers releases.
    chatbot = pipeline("text-generation", model=model, tokenizer=tokenizer)
    _CHATBOT_PIPELINES[model_name] = chatbot
    return chatbot
51
+
52
# Load CogVideoX-5B model for video generation.
# First fetch the TransPixar RGBA LoRA weights into a local directory.
# NOTE(review): hf_hub_download is not imported explicitly in this file; it is
# presumably provided by `from CogVideoX.rgba_utils import *` - confirm.
hf_hub_download(
    repo_id="wileewang/TransPixar",
    filename="cogvideox_rgba_lora.safetensors",
    local_dir="model_cogvideox_rgba_lora",
)

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5B", torch_dtype=torch.bfloat16)
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.scheduler = CogVideoXDPMScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)

# Sequence length handed to the RGBA attention patch, derived from a
# 480x720 frame size and 13 frames (the same frame count requested at
# generation time). The extra "// 2" per spatial axis and the leading factor
# of 2 presumably reflect the transformer patch size and the paired RGB/alpha
# streams - TODO confirm against CogVideoX.rgba_utils.
_latent_rows = 480 // pipe.vae_scale_factor_spatial // 2
_latent_cols = 720 // pipe.vae_scale_factor_spatial // 2
_latent_frames = (13 - 1) // pipe.vae_scale_factor_temporal + 1
seq_length = 2 * (_latent_rows * _latent_cols * _latent_frames)

prepare_for_rgba_inference(
    pipe.transformer,
    rgba_weights_path="model_cogvideox_rgba_lora/cogvideox_rgba_lora.safetensors",
    device=device,
    dtype=torch.bfloat16,
    text_length=226,
    seq_length=seq_length,
)
71
+
72
# Create output directories (no error if they already exist).
for _directory in ("./output", "./gradio_tmp"):
    os.makedirs(_directory, exist_ok=True)
75
+
76
+ # Music generation function using Facebook's MusicGen
77
def generate_music_function(prompt, length, genre, custom_genre, lyrics):
    """Generate a music clip with MusicGen and return the path of the WAV file.

    Args:
        prompt: Free-text description of the music.
        length: Requested clip length in seconds; 50 new tokens are generated
            per requested second.
        genre: Genre picked from the dropdown.
        custom_genre: Optional free-text genre. When non-blank it takes
            precedence over *genre*.
        lyrics: Optional lyrics text. It is appended to the text prompt only
            when non-blank.

    Returns:
        Path of a newly written ``.wav`` file containing the generated audio.
    """
    # A whitespace-only custom genre must not override the dropdown choice.
    custom = (custom_genre or "").strip()
    selected_genre = custom if custom else genre

    input_text = f"{prompt}. Genre: {selected_genre}."
    lyrics_text = (lyrics or "").strip()
    if lyrics_text:
        # Lyrics are optional; avoid feeding a dangling "Lyrics:" to the model.
        input_text += f" Lyrics: {lyrics_text}"

    inputs = processor(
        text=[input_text],
        padding=True,
        return_tensors="pt",
    )
    max_new_tokens = max(1, int(length * 50))
    audio_values = musicgen_model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Write each result to its own file: a single fixed filename would be
    # overwritten when two requests run at the same time. The file is kept
    # (delete=False) because the caller needs it after this function returns.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
        output_file = handle.name

    sampling_rate = musicgen_model.config.audio_encoder.sampling_rate
    torchaudio.save(output_file, audio_values[0].cpu(), sampling_rate)
    return output_file
90
+
91
+ # Chatbot interaction function
92
def chatbot_interaction(user_input, history, model_name):
    """Append one user/bot exchange to the chat history.

    Args:
        user_input: The message typed by the user.
        history: List of ``(user_message, bot_reply)`` tuples, or ``None``.
        model_name: Key of ``CHATBOT_MODELS`` selecting the model to use.

    Returns:
        A ``(history, history)`` tuple holding the updated history twice: once
        for the chat display and once for the stored state. When *user_input*
        is blank the history is returned unchanged and no model is loaded.

    Raises:
        ValueError: Propagated from ``load_chatbot_model`` for an unknown
            *model_name*.
    """
    # Work on a copy so the list passed in by the caller is never mutated.
    updated_history = list(history) if history else []

    message = (user_input or "").strip()
    if not message:
        # Nothing to answer; skip the (expensive) model call entirely.
        return updated_history, updated_history

    chatbot_pipeline = load_chatbot_model(model_name)
    response = chatbot_pipeline(message)[0]['generated_text']

    # Generation pipelines commonly return the prompt followed by the
    # continuation; show only the continuation when there is one.
    if isinstance(response, str) and response.startswith(message):
        response = response[len(message):].strip() or response

    updated_history.append((message, response))
    return updated_history, updated_history
97
+
98
+ # CogVideoX-5B video generation function
99
def generate_video_function(prompt, seed_value):
    """Generate an RGB video and a matching alpha video for *prompt*.

    Args:
        prompt: Text prompt; ``", isolated background"`` is appended to it.
        seed_value: Seed for the random generator. ``-1`` (or ``None``) picks
            a random seed.

    Returns:
        ``(rgb_video_path, alpha_video_path, seed_value)`` where the paths
        point at the saved ``.mp4`` files and ``seed_value`` is the integer
        seed that was actually used.
    """
    if seed_value is None or seed_value == -1:
        # NOTE(review): only 256 distinct random seeds are possible here;
        # widen the range if more variety is wanted.
        seed_value = random.randint(0, 2**8 - 1)
    seed_value = int(seed_value)

    pipe.to(device)
    try:
        video_pt = pipe(
            prompt=prompt + ", isolated background",
            num_videos_per_prompt=1,
            num_inference_steps=25,
            num_frames=13,
            use_dynamic_cfg=True,
            output_type="latent",
            guidance_scale=7.0,
            generator=torch.Generator(device=device).manual_seed(seed_value),
        ).frames

        # The latent output is split in two halves along dim 1; the first is
        # decoded as the RGB frames and the second as the alpha frames.
        latents_rgb, latents_alpha = video_pt.chunk(2, dim=1)
        frames_rgb = decode_latents(pipe, latents_rgb)
        frames_alpha = decode_latents(pipe, latents_alpha)

        # Collapse alpha to its per-pixel maximum over the last axis, then
        # repeat it three times so it can be multiplied with the RGB frames.
        pooled_alpha = np.max(frames_alpha, axis=-1, keepdims=True)
        frames_alpha_pooled = np.repeat(pooled_alpha, 3, axis=-1)
        premultiplied_rgb = frames_rgb * frames_alpha_pooled

        rgb_video_path = save_video(premultiplied_rgb[0], fps=8, prefix='rgb')
        alpha_video_path = save_video(frames_alpha_pooled[0], fps=8, prefix='alpha')
    finally:
        # Always move the pipeline back to the CPU and release memory, even
        # when generation fails, so a failed request does not keep it on the GPU.
        pipe.to("cpu")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return rgb_video_path, alpha_video_path, seed_value
124
+
125
+ # Utility function to save video
126
def save_video(tensor, fps=8, prefix='rgb'):
    """Write *tensor* as an ``.mp4`` under ``./output`` and return its path.

    Args:
        tensor: Frames to export; passed unchanged to ``export_to_video``.
        fps: Frames per second of the written video.
        prefix: Filename prefix placed before the timestamp.

    Returns:
        Path of the written file,
        ``./output/<prefix>_<YYYYmmdd_HHMMSS_microseconds>.mp4``.
    """
    # Make the function usable even if the directory was removed at runtime.
    os.makedirs("./output", exist_ok=True)
    # Microseconds are part of the name so that two videos saved with the
    # same prefix within one second do not overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    video_path = f"./output/{prefix}_{timestamp}.mp4"
    export_to_video(tensor, video_path, fps=fps)
    return video_path
131
+
132
+ # IC Light tool function
133
def ic_light_tool():
    """Run the IC Light tool whose Python source is stored in the ``EXEC`` env var.

    Returns:
        A status message for the output textbox: a confirmation after the
        code has run, or an explanation when ``EXEC`` is empty or unset.

    Raises:
        Exception: Anything raised by the executed code propagates unchanged.
    """
    snippet = os.getenv('EXEC')
    if not snippet:
        # exec(None) would raise an opaque TypeError; report the cause instead.
        return "IC Light is not configured: the EXEC environment variable is empty or not set."

    # SECURITY: this executes arbitrary Python taken from the environment with
    # the full privileges of this process. Only deploy where the EXEC variable
    # is controlled by a trusted party.
    exec(snippet)
    return "IC Light tool executed."
137
+
138
# Image to Flux Prompt functionality: build the Mistral API client from the
# key stored in the MISTRAL_API_KEY environment variable (None when unset).
api_key = os.environ.get("MISTRAL_API_KEY")
Mistralclient = Mistral(api_key=api_key)
141
+
142
def encode_image(image_path):
    """Load an image, resize it to 512 px high, and return it as base64 JPEG text.

    Args:
        image_path: Path of the image file to encode.

    Returns:
        The base64-encoded JPEG bytes as a UTF-8 string, or ``None`` when no
        path is given, the file is missing, or the image cannot be processed
        (the reason is printed).
    """
    if not image_path:
        print("Error: no image path was provided.")
        return None

    try:
        # The context manager closes the underlying file handle; convert()
        # returns an independent in-memory copy that outlives it.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        # Resize the image to a height of 512 while maintaining the aspect
        # ratio. The width is clamped to at least 1 so extremely tall, narrow
        # images do not produce an invalid zero-width size.
        base_height = 512
        h_percent = base_height / float(image.size[1])
        w_size = max(1, int(float(image.size[0]) * h_percent))
        image = image.resize((w_size, base_height), Image.LANCZOS)

        # Serialize as JPEG into memory, then base64-encode the bytes.
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except FileNotFoundError:
        print(f"Error: The file {image_path} was not found.")
        return None
    except Exception as e:  # Best-effort helper: report and signal failure with None
        print(f"Error: {e}")
        return None
166
+
167
def feifeichat(image):
    """Stream a detailed description of *image* generated by a Pixtral model.

    This is a generator: each yielded value is the full description received
    so far, so the UI can update the textbox progressively.

    Args:
        image: Path of the uploaded image file, or ``None`` when nothing was
            uploaded.

    Yields:
        The accumulated description text, or ``"Please upload a photo"`` when
        no usable image is available or the request fails.
    """
    fallback = "Please upload a photo"

    if not image:
        # A generator's `return value` never reaches the consumer, so the
        # message has to be yielded to be displayed.
        yield fallback
        return

    try:
        base64_image = encode_image(image)
        if base64_image is None:
            # Encoding failed; do not send "base64,None" to the API.
            yield fallback
            return

        model = "pixtral-large-2411"
        # Define the messages for the chat
        # NOTE(review): "stream" sits inside the message dict rather than
        # being a request option; kept as in the original - confirm it is
        # intended.
        messages = [{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Please provide a detailed description of this photo"
                },
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{base64_image}"
                },
            ],
            "stream": False,
        }]

        partial_message = ""
        for chunk in Mistralclient.chat.stream(model=model, messages=messages):
            delta = chunk.data.choices[0].delta.content
            if delta is not None:
                partial_message += delta
                yield partial_message
    except Exception as e:  # Top-level UI boundary: report and show the fallback
        print(f"Error: {e}")
        yield fallback
197
+
198
+ # Text3D tool function
199
def text3d_tool():
    """Run the Text3D tool whose Python source is stored in the ``APP`` env var.

    Returns:
        A status message for the output textbox: a confirmation after the
        code has run, or an explanation when ``APP`` is empty or unset.

    Raises:
        Exception: Anything raised by the executed code propagates unchanged.
    """
    snippet = os.environ.get('APP')
    if not snippet:
        # exec(None) would raise an opaque TypeError; report the cause instead.
        return "Text3D is not configured: the APP environment variable is empty or not set."

    # SECURITY: this executes arbitrary Python taken from the environment with
    # the full privileges of this process. Only deploy where the APP variable
    # is controlled by a trusted party.
    exec(snippet)
    return "Text3D tool executed."
203
+
204
# Gradio interface with custom theme and equal height row.
# One tab per tool; each tab wires its button to the matching handler above.
with gr.Blocks(theme='gstaff/sketch') as demo:
    # equal_height is passed to the constructor: the chained .style(...) call
    # is not available on current Gradio releases and fails at startup there.
    with gr.Row(equal_height=True):
        gr.Markdown("# Multi-Tool Interface: Chatbot, Music, Transpixar, IC Light, Image to Flux Prompt, and Text3D")

    # Chatbot Tab
    with gr.Tab("Chatbot"):
        chatbot_state = gr.State([])
        chatbot_model = gr.Dropdown(
            choices=list(CHATBOT_MODELS.keys()),
            label="Select Chatbot Model",
            value="DialoGPT (Medium)"
        )
        chatbot_output = gr.Chatbot()
        chatbot_input = gr.Textbox(label="Your Message")
        chatbot_button = gr.Button("Send")
        chatbot_button.click(
            chatbot_interaction,
            inputs=[chatbot_input, chatbot_state, chatbot_model],
            outputs=[chatbot_output, chatbot_state]
        )

    # Music Generation Tab
    with gr.Tab("Music Generation"):
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Enter a prompt for music generation", placeholder="e.g., A joyful melody for a sunny day")
                length = gr.Slider(minimum=1, maximum=10, value=5, label="Length (seconds)")
                genre = gr.Dropdown(
                    choices=["Pop", "Rock", "Classical", "Jazz", "Electronic", "Hip-Hop", "Country"],
                    label="Select Genre",
                    value="Pop"
                )
                custom_genre = gr.Textbox(label="Or enter a custom genre", placeholder="e.g., Reggae, K-Pop, etc.")
                lyrics = gr.Textbox(label="Enter lyrics (optional)", placeholder="e.g., La la la...")
                generate_music_button = gr.Button("Generate Music")
            with gr.Column():
                music_output = gr.Audio(label="Generated Music")
        generate_music_button.click(
            generate_music_function,
            inputs=[prompt, length, genre, custom_genre, lyrics],
            outputs=music_output
        )

    # Transpixar Tab (formerly Video Generation)
    with gr.Tab("Transpixar"):
        with gr.Row():
            with gr.Column():
                video_prompt = gr.Textbox(label="Enter a prompt for video generation", placeholder="e.g., A futuristic cityscape at night")
                seed_value = gr.Number(label="Inference Seed (Enter a positive number, -1 for random)", value=-1)
                generate_video_button = gr.Button("Generate Video")
            with gr.Column():
                rgb_video_output = gr.Video(label="Generated RGB Video", width=720, height=480)
                alpha_video_output = gr.Video(label="Generated Alpha Video", width=720, height=480)
                seed_text = gr.Number(label="Seed Used for Video Generation", visible=False)
        generate_video_button.click(
            generate_video_function,
            inputs=[video_prompt, seed_value],
            outputs=[rgb_video_output, alpha_video_output, seed_text]
        )

    # IC Light Tab
    with gr.Tab("IC Light"):
        gr.Markdown("### IC Light Tool")
        ic_light_button = gr.Button("Run IC Light")
        ic_light_output = gr.Textbox(label="IC Light Output", interactive=False)
        ic_light_button.click(
            ic_light_tool,
            outputs=ic_light_output
        )

    # Image to Flux Prompt Tab
    with gr.Tab("Image to Flux Prompt"):
        gr.Markdown("### Image to Flux Prompt")
        input_img = gr.Image(label="Input Picture", height=320, type="filepath")
        submit_btn = gr.Button(value="Submit")
        output_text = gr.Textbox(label="Flux Prompt")
        submit_btn.click(feifeichat, [input_img], [output_text])

    # Text3D Tab
    with gr.Tab("Text3D"):
        gr.Markdown("### Text3D Tool")
        text3d_button = gr.Button("Run Text3D")
        text3d_output = gr.Textbox(label="Text3D Output", interactive=False)
        text3d_button.click(
            text3d_tool,
            outputs=text3d_output
        )

# feifeichat streams its result with `yield`; generator handlers need the
# request queue on Gradio 3.x. On releases where the queue is already enabled
# by default this call is harmless.
demo.queue()

# Launch the Gradio app
demo.launch()