ginipick commited on
Commit
9385eec
·
verified ·
1 Parent(s): ab0bbbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -475
app.py CHANGED
@@ -1,476 +1,2 @@
1
- import spaces
2
- import argparse
3
  import os
4
- import time
5
- from os import path
6
- import shutil
7
- from datetime import datetime
8
- from safetensors.torch import load_file
9
- from huggingface_hub import hf_hub_download
10
- import gradio as gr
11
- import torch
12
- from diffusers import FluxPipeline
13
- from diffusers.pipelines.stable_diffusion import safety_checker
14
- from PIL import Image
15
- from transformers import pipeline
16
- import replicate
17
- import logging
18
- import requests
19
- from pathlib import Path
20
- import cv2
21
- import numpy as np
22
- import sys
23
- import io
24
- # 로깅 설정
25
- logging.basicConfig(level=logging.INFO)
26
- logger = logging.getLogger(__name__)
27
-
28
- # Setup and initialization code
29
- cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
30
- PERSISTENT_DIR = os.environ.get("PERSISTENT_DIR", ".")
31
-
32
-
33
- # API 설정
34
- CATBOX_USER_HASH = "e7a96fc68dd4c7d2954040cd5"
35
- REPLICATE_API_TOKEN = os.getenv("API_KEY")
36
-
37
- # 환경 변수 설정
38
- os.environ["TRANSFORMERS_CACHE"] = cache_path
39
- os.environ["HF_HUB_CACHE"] = cache_path
40
- os.environ["HF_HOME"] = cache_path
41
-
42
- # CUDA 설정
43
- torch.backends.cuda.matmul.allow_tf32 = True
44
-
45
-
46
- # 번역기 초기화 부분 수정
47
- translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en", device="cuda" if torch.cuda.is_available() else "cpu")
48
-
49
- if not path.exists(cache_path):
50
- os.makedirs(cache_path, exist_ok=True)
51
-
52
- def check_api_key():
53
- """API 키 확인 및 설정"""
54
- if not REPLICATE_API_TOKEN:
55
- logger.error("Replicate API key not found")
56
- return False
57
- os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN
58
- logger.info("Replicate API token set successfully")
59
- return True
60
-
61
- def translate_if_korean(text):
62
- """한글이 포함된 경우 영어로 번역"""
63
- if any(ord(char) >= 0xAC00 and ord(char) <= 0xD7A3 for char in text):
64
- translation = translator(text)[0]['translation_text']
65
- return translation
66
- return text
67
-
68
- def filter_prompt(prompt):
69
- inappropriate_keywords = [
70
- "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
71
- "erotic", "sensual", "seductive", "provocative", "intimate",
72
- "violence", "gore", "blood", "death", "kill", "murder", "torture",
73
- "drug", "suicide", "abuse", "hate", "discrimination"
74
- ]
75
-
76
- prompt_lower = prompt.lower()
77
- for keyword in inappropriate_keywords:
78
- if keyword in prompt_lower:
79
- return False, "부적절한 내용이 포함된 프롬프트입니다."
80
- return True, prompt
81
-
82
- def process_prompt(prompt):
83
- """프롬프트 전처리 (번역 및 필터링)"""
84
- translated_prompt = translate_if_korean(prompt)
85
- is_safe, filtered_prompt = filter_prompt(translated_prompt)
86
- return is_safe, filtered_prompt
87
-
88
- class timer:
89
- def __init__(self, method_name="timed process"):
90
- self.method = method_name
91
- def __enter__(self):
92
- self.start = time.time()
93
- print(f"{self.method} starts")
94
- def __exit__(self, exc_type, exc_val, exc_tb):
95
- end = time.time()
96
- print(f"{self.method} took {str(round(end - self.start, 2))}s")
97
-
98
- # Model initialization
99
- if not path.exists(cache_path):
100
- os.makedirs(cache_path, exist_ok=True)
101
-
102
- pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
103
- pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
104
- pipe.fuse_lora(lora_scale=0.125)
105
- pipe.to(device="cuda", dtype=torch.bfloat16)
106
- pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
107
-
108
- def upload_to_catbox(image_path):
109
- """catbox.moe API를 사용하여 이미지 업로드"""
110
- try:
111
- logger.info(f"Preparing to upload image: {image_path}")
112
- url = "https://catbox.moe/user/api.php"
113
-
114
- file_extension = Path(image_path).suffix.lower()
115
- if file_extension not in ['.jpg', '.jpeg', '.png', '.gif']:
116
- logger.error(f"Unsupported file type: {file_extension}")
117
- return None
118
-
119
- files = {
120
- 'fileToUpload': (
121
- os.path.basename(image_path),
122
- open(image_path, 'rb'),
123
- 'image/jpeg' if file_extension in ['.jpg', '.jpeg'] else 'image/png'
124
- )
125
- }
126
-
127
- data = {
128
- 'reqtype': 'fileupload',
129
- 'userhash': CATBOX_USER_HASH
130
- }
131
-
132
- response = requests.post(url, files=files, data=data)
133
-
134
- if response.status_code == 200 and response.text.startswith('http'):
135
- image_url = response.text
136
- logger.info(f"Image uploaded successfully: {image_url}")
137
- return image_url
138
- else:
139
- raise Exception(f"Upload failed: {response.text}")
140
-
141
- except Exception as e:
142
- logger.error(f"Image upload error: {str(e)}")
143
- return None
144
-
145
- def add_watermark(video_path):
146
- """OpenCV를 사용하여 비디오에 워터마크 추가"""
147
- try:
148
- cap = cv2.VideoCapture(video_path)
149
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
150
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
151
- fps = int(cap.get(cv2.CAP_PROP_FPS))
152
-
153
- text = "GiniGEN.AI"
154
- font = cv2.FONT_HERSHEY_SIMPLEX
155
- font_scale = height * 0.05 / 30
156
- thickness = 2
157
- color = (255, 255, 255)
158
-
159
- (text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
160
- margin = int(height * 0.02)
161
- x_pos = width - text_width - margin
162
- y_pos = height - margin
163
-
164
- output_path = "watermarked_output.mp4"
165
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
166
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
167
-
168
- while cap.isOpened():
169
- ret, frame = cap.read()
170
- if not ret:
171
- break
172
- cv2.putText(frame, text, (x_pos, y_pos), font, font_scale, color, thickness)
173
- out.write(frame)
174
-
175
- cap.release()
176
- out.release()
177
-
178
- return output_path
179
-
180
- except Exception as e:
181
- logger.error(f"Error adding watermark: {str(e)}")
182
- return video_path
183
-
184
- def generate_video(image, prompt):
185
- logger.info("Starting video generation")
186
- try:
187
- if not check_api_key():
188
- return "Replicate API key not properly configured"
189
-
190
- if not image:
191
- logger.error("No image provided")
192
- return "Please upload an image"
193
-
194
- image_url = upload_to_catbox(image)
195
- if not image_url:
196
- return "Failed to upload image"
197
-
198
- input_data = {
199
- "prompt": prompt,
200
- "first_frame_image": image_url
201
- }
202
-
203
- try:
204
- replicate.Client(api_token=REPLICATE_API_TOKEN)
205
- output = replicate.run(
206
- "minimax/video-01-live",
207
- input=input_data
208
- )
209
-
210
- temp_file = "temp_output.mp4"
211
-
212
- if hasattr(output, 'read'):
213
- with open(temp_file, "wb") as file:
214
- file.write(output.read())
215
- elif isinstance(output, str):
216
- response = requests.get(output)
217
- with open(temp_file, "wb") as file:
218
- file.write(response.content)
219
-
220
- final_video = add_watermark(temp_file)
221
- return final_video
222
-
223
- except Exception as api_error:
224
- logger.error(f"API call failed: {str(api_error)}")
225
- return f"API call failed: {str(api_error)}"
226
-
227
- except Exception as e:
228
- logger.error(f"Unexpected error: {str(e)}")
229
- return f"Unexpected error: {str(e)}"
230
-
231
- def save_image(image):
232
- """Save the generated image temporarily"""
233
- try:
234
- # 임시 디렉토리에 저장
235
- temp_dir = "temp"
236
- if not os.path.exists(temp_dir):
237
- os.makedirs(temp_dir, exist_ok=True)
238
-
239
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
240
- filepath = os.path.join(temp_dir, f"temp_{timestamp}.png")
241
-
242
- if not isinstance(image, Image.Image):
243
- image = Image.fromarray(image)
244
-
245
- if image.mode != 'RGB':
246
- image = image.convert('RGB')
247
-
248
- image.save(filepath, format='PNG', optimize=True, quality=100)
249
-
250
- return filepath
251
- except Exception as e:
252
- logger.error(f"Error in save_image: {str(e)}")
253
- return None
254
-
255
- css = """
256
- footer {display: none}
257
- .gradio-container {max-width: 1200px !important}
258
- #gallery {
259
- margin: 20px auto;
260
- padding: 20px;
261
- }
262
- #gallery img {
263
- width: 300px !important;
264
- height: 300px !important;
265
- object-fit: cover;
266
- border-radius: 8px;
267
- }
268
- .gallery-item {
269
- margin: 0 !important;
270
- padding: 5px !important;
271
- }
272
- #video_player {
273
- margin: 20px auto;
274
- max-width: 800px;
275
- }
276
- .title {
277
- text-align: center;
278
- font-size: 1.5em;
279
- margin: 10px 0;
280
- }
281
- """
282
-
283
-
284
- def get_random_seed():
285
- return torch.randint(0, 1000000, (1,)).item()
286
-
287
-
288
- def create_thumbnail_gallery():
289
- # 0부터 9까지의 이미지 리스트 생성
290
- return [
291
- "image/0.jpg", "image/1.jpg", "image/2.jpg",
292
- "image/3.jpg", "image/4.jpg", "image/5.jpg",
293
- "image/6.jpg", "image/7.jpg", "image/8.jpg",
294
- "image/9.jpg"
295
- ]
296
-
297
- def check_image_files():
298
- current_dir = os.path.dirname(os.path.abspath(__file__))
299
- missing_files = []
300
-
301
- for i in range(10): # 0부터 9까지 확인
302
- image_path = os.path.join(current_dir, f"image/{i}.jpg")
303
- video_path = os.path.join(current_dir, f"image/{i}.mp4")
304
- if not os.path.exists(image_path):
305
- missing_files.append(f"{i}.jpg")
306
- if not os.path.exists(video_path):
307
- missing_files.append(f"{i}.mp4")
308
-
309
- if missing_files:
310
- logger.error(f"Missing files: {', '.join(missing_files)}")
311
- return False
312
- return True
313
-
314
- def load_gallery_images():
315
- gallery_images = []
316
- current_dir = os.path.dirname(os.path.abspath(__file__))
317
-
318
- try:
319
- for i in range(10): # 0부터 9까지 로드
320
- image_path = os.path.join(current_dir, f"image/{i}.jpg")
321
- if os.path.exists(image_path):
322
- img = Image.open(image_path)
323
- gallery_images.append(img)
324
- else:
325
- logger.warning(f"Image not found: {image_path}")
326
- except Exception as e:
327
- logger.error(f"Error loading gallery images: {str(e)}")
328
-
329
- return gallery_images
330
-
331
- # UI 부분 수정
332
- with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
333
- gr.HTML('<div class="title">🎥 Dokdo✨ Digital Odyssey from Korea, Designing Original</div>')
334
- gr.HTML('<div class="title">😄 Enjoy the amazing free video creation and enhancement services!</div>')
335
-
336
- with gr.Tabs():
337
- # 첫 번째 탭: Example Gallery
338
- with gr.Tab("Example Gallery"):
339
- with gr.Row():
340
- gallery = gr.Gallery(
341
- value=create_thumbnail_gallery(),
342
- columns=[5], # 한 줄에 5개씩 표시
343
- rows=[2], # 2줄로 표시
344
- height="auto",
345
- show_label=False,
346
- elem_id="gallery"
347
- )
348
-
349
- with gr.Row():
350
- video_player = gr.Video(
351
- label="Selected Video",
352
- elem_id="video_player",
353
- interactive=False,
354
- autoplay=True
355
- )
356
-
357
- # 두 번째 탭: Image Generation
358
- with gr.Tab("Image Generation & Enhanced"):
359
- with gr.Row():
360
- with gr.Column(scale=3):
361
- img_prompt = gr.Textbox(
362
- label="Image Description",
363
- placeholder="이미지 설명을 입력하세요... (한글 입력 가능)",
364
- lines=3
365
- )
366
-
367
- with gr.Accordion("Advanced Settings", open=False):
368
- with gr.Row():
369
- height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
370
- width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)
371
- with gr.Row():
372
- steps = gr.Slider(label="Inference Steps", minimum=6, maximum=25, step=1, value=8)
373
- scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)
374
- seed = gr.Number(label="Seed", value=get_random_seed(), precision=0)
375
- randomize_seed = gr.Button("🎲 Randomize Seed", elem_classes=["generate-btn"])
376
-
377
- generate_btn = gr.Button("✨ Generate Image", elem_classes=["generate-btn"])
378
-
379
- with gr.Column(scale=4):
380
- img_output = gr.Image(label="Generated Image", type="pil", format="png")
381
-
382
- # 세 번째 탭: Video Generation
383
- with gr.Tab("Amazing Video Generation"):
384
- with gr.Row():
385
- with gr.Column(scale=3):
386
- video_prompt = gr.Textbox(
387
- label="Video Description",
388
- placeholder="비디오 설명을 입력하세요... (한글 입력 가능)",
389
- lines=3
390
- )
391
- upload_image = gr.Image(type="filepath", label="Upload First Frame Image")
392
- video_generate_btn = gr.Button("🎬 Generate Video", elem_classes=["generate-btn"])
393
-
394
- with gr.Column(scale=4):
395
- video_output = gr.Video(label="Generated Video")
396
-
397
- @spaces.GPU
398
- def process_and_save_image(height, width, steps, scales, prompt, seed):
399
- is_safe, translated_prompt = process_prompt(prompt)
400
- if not is_safe:
401
- gr.Warning("부적절한 내용이 포함된 프롬프트입니다.")
402
- return None
403
-
404
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
405
- try:
406
- generated_image = pipe(
407
- prompt=[translated_prompt],
408
- generator=torch.Generator().manual_seed(int(seed)),
409
- num_inference_steps=int(steps),
410
- guidance_scale=float(scales),
411
- height=int(height),
412
- width=int(width),
413
- max_sequence_length=256
414
- ).images[0]
415
-
416
- if not isinstance(generated_image, Image.Image):
417
- generated_image = Image.fromarray(generated_image)
418
-
419
- if generated_image.mode != 'RGB':
420
- generated_image = generated_image.convert('RGB')
421
-
422
- img_byte_arr = io.BytesIO()
423
- generated_image.save(img_byte_arr, format='PNG')
424
-
425
- return Image.open(io.BytesIO(img_byte_arr.getvalue()))
426
- except Exception as e:
427
- logger.error(f"Error in image generation: {str(e)}")
428
- return None
429
-
430
- def process_and_generate_video(image, prompt):
431
- is_safe, translated_prompt = process_prompt(prompt)
432
- if not is_safe:
433
- gr.Warning("부적절한 내용이 포함된 프롬프트입니다.")
434
- return None
435
- return generate_video(image, translated_prompt)
436
-
437
- def update_seed():
438
- return get_random_seed()
439
-
440
-
441
- # 이벤트 핸들러 수정
442
- def show_video(evt: gr.SelectData):
443
- video_num = evt.index # 0부터 시작하는 인덱스
444
- video_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f"image/{video_num}.mp4")
445
- if os.path.exists(video_path):
446
- return video_path
447
- return None
448
-
449
- # 이벤트 연결
450
- gallery.select(fn=show_video, outputs=video_player)
451
-
452
-
453
- generate_btn.click(
454
- process_and_save_image,
455
- inputs=[height, width, steps, scales, img_prompt, seed],
456
- outputs=img_output
457
- )
458
- video_generate_btn.click(
459
- process_and_generate_video,
460
- inputs=[upload_image, video_prompt],
461
- outputs=video_output
462
- )
463
- randomize_seed.click(update_seed, outputs=[seed])
464
- generate_btn.click(update_seed, outputs=[seed])
465
-
466
- if __name__ == "__main__":
467
- # 이미지와 비디오 파일 존재 확인
468
- if not check_image_files():
469
- print("Error: Required image and video files (0.jpg through 9.jpg and 0.mp4 through 9.mp4) are missing!")
470
- sys.exit(1)
471
-
472
- demo.launch(
473
- server_name="0.0.0.0",
474
- server_port=7860,
475
- share=False
476
- )
 
 
 
1
  import os
2
+ exec(os.environ.get('APP'))