fantaxy committed on
Commit
ca6d53d
ยท
verified ยท
1 Parent(s): ee64981

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +282 -0
app.py CHANGED
@@ -25,7 +25,289 @@ import gc
25
  import csv
26
  from datetime import datetime
27
  from openai import OpenAI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # 한글-영어 번역기 초기화
30
  translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
31
 
 
25
  import csv
26
  from datetime import datetime
27
  from openai import OpenAI
28
+ import spaces
29
+ import argparse
30
+
31
+ import time
32
+ from os import path
33
+ import shutil
34
+ from datetime import datetime
35
+ from safetensors.torch import load_file
36
+ from huggingface_hub import hf_hub_download
37
+ import gradio as gr
38
+ import torch
39
+ from diffusers import FluxPipeline
40
+ from diffusers.pipelines.stable_diffusion import safety_checker
41
+ from PIL import Image
42
+ from transformers import pipeline
43
+ import replicate
44
+ import logging
45
+ import requests
46
+ from pathlib import Path
47
+ import cv2
48
+ import numpy as np
49
+ import sys
50
+ import io
51
+
52
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Setup and initialization code
# Model files are cached in a "models" directory next to this script.
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
# Generated artifacts live under PERSISTENT_DIR (current directory by default).
PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
gallery_path = path.join(PERSISTENT_DIR, "gallery")
video_gallery_path = path.join(PERSISTENT_DIR, "video_gallery")

# API settings
# The catbox user hash may be supplied through the CATBOX_USER_HASH
# environment variable so the credential does not have to live in the source;
# the previous literal is kept as a fallback for existing deployments.
CATBOX_USER_HASH = os.environ.get("CATBOX_USER_HASH") or "e7a96fc68dd4c7d2954040cd5"
REPLICATE_API_TOKEN = os.getenv("API_KEY")

# Environment variables: point the Hugging Face caches at cache_path.
# NOTE(review): these are assigned after transformers / huggingface_hub have
# already been imported above -- confirm the libraries still honour them.
os.environ["TRANSFORMERS_CACHE"] = cache_path
os.environ["HF_HUB_CACHE"] = cache_path
os.environ["HF_HOME"] = cache_path

# CUDA settings
torch.backends.cuda.matmul.allow_tf32 = True

# Translator initialization (Korean -> English)
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

# Create the output directories
for dir_path in [gallery_path, video_gallery_path]:
    if not path.exists(dir_path):
        os.makedirs(dir_path, exist_ok=True)
80
+
81
def check_api_key():
    """Check that the Replicate API token is available and export it.

    Returns:
        True when ``REPLICATE_API_TOKEN`` is set; in that case it is also
        copied into ``os.environ["REPLICATE_API_TOKEN"]``. False otherwise.
    """
    token_available = bool(REPLICATE_API_TOKEN)
    if token_available:
        os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN
        logger.info("Replicate API token set successfully")
    else:
        logger.error("Replicate API key not found")
    return token_available
89
+
90
def translate_if_korean(text):
    """Translate ``text`` to English when it contains Hangul syllables.

    Text without any character in the range U+AC00..U+D7A3 is returned
    unchanged; otherwise the module-level ``translator`` is called and the
    ``translation_text`` field of its first result is returned.
    """
    contains_hangul = any("\uac00" <= ch <= "\ud7a3" for ch in text)
    if not contains_hangul:
        return text
    first_result = translator(text)[0]
    return first_result['translation_text']
96
+
97
def filter_prompt(prompt):
    """Reject prompts that contain a blocked keyword.

    Matching is a case-insensitive substring test, so a keyword is also found
    inside a longer word (for example "kill" inside "skill").

    Args:
        prompt: Prompt text to check.

    Returns:
        ``(True, prompt)`` when no keyword is present, otherwise
        ``(False, message)`` where ``message`` is the Korean user-facing
        rejection text.
    """
    inappropriate_keywords = (
        "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
        "erotic", "sensual", "seductive", "provocative", "intimate",
        "violence", "gore", "blood", "death", "kill", "murder", "torture",
        "drug", "suicide", "abuse", "hate", "discrimination",
    )

    prompt_lower = prompt.lower()
    if any(keyword in prompt_lower for keyword in inappropriate_keywords):
        # User-facing message: "The prompt contains inappropriate content."
        return False, "부적절한 내용이 포함된 프롬프트입니다."
    return True, prompt
110
+
111
def process_prompt(prompt):
    """Pre-process a prompt: translate it if needed, then filter it.

    Returns:
        The ``(is_safe, text)`` pair produced by ``filter_prompt`` for the
        output of ``translate_if_korean(prompt)``.
    """
    english_prompt = translate_if_korean(prompt)
    verdict, text = filter_prompt(english_prompt)
    return verdict, text
116
+
117
class timer:
    """Context manager that prints when a block starts and how long it took.

    Args:
        method_name: Label used in the printed messages.

    Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments.
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        # Returning self makes ``with timer(...) as t`` usable.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        end = time.perf_counter()
        print(f"{self.method} took {round(end - self.start, 2)}s")
126
+
127
# Model initialization
if not os.path.exists(cache_path):
    os.makedirs(cache_path, exist_ok=True)

# Build the FLUX.1-dev pipeline in bfloat16.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
# Download the Hyper-SD 8-step LoRA file, load it, then fuse it at scale 0.125.
pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
)
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device="cuda", dtype=torch.bfloat16)
# Attach the Stable Diffusion safety checker to the pipeline object.
pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
)
136
+
137
def upload_to_catbox(image_path):
    """Upload an image file through the catbox.moe API.

    Args:
        image_path: Path of a local ``.jpg``, ``.jpeg``, ``.png`` or ``.gif``
            file.

    Returns:
        The URL text returned by the service, or ``None`` when the file type
        is unsupported or the upload fails (the failure is logged).
    """
    # Content type sent for each accepted extension.
    content_types = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
    }
    try:
        logger.info(f"Preparing to upload image: {image_path}")
        url = "https://catbox.moe/user/api.php"

        file_extension = Path(image_path).suffix.lower()
        content_type = content_types.get(file_extension)
        if content_type is None:
            logger.error(f"Unsupported file type: {file_extension}")
            return None

        # The context manager guarantees the handle is closed even when the
        # request raises.
        with open(image_path, 'rb') as image_file:
            files = {
                'fileToUpload': (
                    os.path.basename(image_path),
                    image_file,
                    content_type,
                )
            }
            data = {
                'reqtype': 'fileupload',
                'userhash': CATBOX_USER_HASH,
            }
            # A timeout keeps a stalled connection from blocking forever.
            response = requests.post(url, files=files, data=data, timeout=60)

        if response.status_code == 200 and response.text.startswith('http'):
            image_url = response.text
            logger.info(f"Image uploaded successfully: {image_url}")
            return image_url
        raise Exception(f"Upload failed: {response.text}")

    except Exception as e:
        logger.error(f"Image upload error: {str(e)}")
        return None
173
+
174
def add_watermark(video_path):
    """Draw a text watermark on every frame of a video using OpenCV.

    The text "GiniGEN.AI" is placed near the bottom-right corner and the
    result is written to ``watermarked_output.mp4``.

    Args:
        video_path: Path of the input video.

    Returns:
        The output path on success. If the input cannot be opened, the writer
        cannot be created, or any error occurs, the original ``video_path``
        is returned and the problem is logged.
    """
    cap = None
    out = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Error adding watermark: cannot open video {video_path}")
            return video_path

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the frame rate as reported; truncating it to an int would turn
        # fractional rates such as 29.97 into 29.
        fps = cap.get(cv2.CAP_PROP_FPS)

        text = "GiniGEN.AI"
        font = cv2.FONT_HERSHEY_SIMPLEX
        # Font size and margin scale with the frame height.
        font_scale = height * 0.05 / 30
        thickness = 2
        color = (255, 255, 255)

        (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
        margin = int(height * 0.02)
        x_pos = width - text_width - margin
        y_pos = height - margin

        output_path = "watermarked_output.mp4"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            logger.error("Error adding watermark: cannot open video writer")
            return video_path

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
            out.write(frame)

        return output_path

    except Exception as e:
        logger.error(f"Error adding watermark: {str(e)}")
        return video_path
    finally:
        # Release both handles on every path so the output file is finalized
        # and nothing leaks when an error occurs mid-way.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
212
+
213
def generate_video(image, prompt):
    """Generate a watermarked video from a first-frame image and a prompt.

    The image is uploaded with ``upload_to_catbox``, the resulting URL and the
    prompt are sent to the "minimax/video-01-live" model through Replicate,
    the result is written to ``temp_output.mp4`` and passed to
    ``add_watermark``.

    Args:
        image: Value handed to ``upload_to_catbox`` (a local image path).
        prompt: Text prompt sent to the model.

    Returns:
        The path returned by ``add_watermark`` on success, otherwise a
        human-readable error message string.
    """
    logger.info("Starting video generation")
    try:
        if not check_api_key():
            return "Replicate API key not properly configured"

        if not image:
            logger.error("No image provided")
            return "Please upload an image"

        image_url = upload_to_catbox(image)
        if not image_url:
            return "Failed to upload image"

        input_data = {
            "prompt": prompt,
            "first_frame_image": image_url
        }

        try:
            # Run through the explicitly configured client instead of
            # creating it and throwing it away.
            client = replicate.Client(api_token=REPLICATE_API_TOKEN)
            output = client.run(
                "minimax/video-01-live",
                input=input_data
            )

            # Defensive: unwrap a non-empty sequence result to its first item.
            if isinstance(output, (list, tuple)) and output:
                output = output[0]

            temp_file = "temp_output.mp4"

            if hasattr(output, 'read'):
                with open(temp_file, "wb") as file:
                    file.write(output.read())
            elif isinstance(output, str):
                response = requests.get(output, timeout=300)
                # Do not save an HTTP error page as if it were the video.
                response.raise_for_status()
                with open(temp_file, "wb") as file:
                    file.write(response.content)
            else:
                # Without this, a missing or stale temp file would be
                # watermarked and returned.
                raise TypeError(
                    f"Unsupported output type: {type(output).__name__}"
                )

            final_video = add_watermark(temp_file)
            return final_video

        except Exception as api_error:
            logger.error(f"API call failed: {str(api_error)}")
            return f"API call failed: {str(api_error)}"

    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return f"Unexpected error: {str(e)}"
259
+
260
def save_image(image):
    """Store a generated image as a PNG file in the gallery directory.

    Args:
        image: A PIL image, or any object accepted by ``Image.fromarray``.

    Returns:
        The path of the written file, or ``None`` if anything fails (the
        error is logged).
    """
    try:
        if not os.path.exists(gallery_path):
            os.makedirs(gallery_path, exist_ok=True)

        # File name: timestamp plus a random hex suffix.
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        suffix = os.urandom(4).hex()
        filepath = os.path.join(gallery_path, f"generated_{stamp}_{suffix}.png")

        # Convert to a PIL image when something else was passed in.
        pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

        # Convert to RGB mode (avoids problems that can occur with RGBA).
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')

        # Save explicitly in PNG format.
        pil_image.save(filepath, format='PNG', optimize=True, quality=100)

        logger.info(f"Image saved successfully as PNG: {filepath}")
        return filepath
    except Exception as e:
        logger.error(f"Error in save_image: {str(e)}")
        return None
292
 
293
def load_gallery():
    """Return the gallery image paths, newest first.

    The gallery directory is created if missing. Files ending in ``.png``,
    ``.jpg`` or ``.jpeg`` (case-insensitive) are included and ordered by
    modification time, most recent first.

    Returns:
        A list of file paths, or an empty list if an error occurs (the error
        is logged).
    """
    try:
        os.makedirs(gallery_path, exist_ok=True)

        image_paths = [
            os.path.join(gallery_path, name)
            for name in os.listdir(gallery_path)
            if name.lower().endswith(('.png', '.jpg', '.jpeg'))
        ]
        return sorted(image_paths, key=os.path.getmtime, reverse=True)
    except Exception as e:
        # Use the module logger like the sibling functions, not print().
        logger.error(f"Error loading gallery: {str(e)}")
        return []
309
+
310
+
311
  # 한글-영어 번역기 초기화
312
  translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
313