import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gc
import os
import datetime
import time
import spaces
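# `spaces` provides the @spaces.GPU() decorator used below; on Hugging Face
# ZeroGPU Spaces it requests a GPU for the decorated function, and it is
# designed to be a no-op outside that environment.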
# --- ์„ค์ • ---
MODEL_ID = "naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B"
MAX_NEW_TOKENS = 512
CPU_THREAD_COUNT = 4 # ํ•„์š”์‹œ ์กฐ์ ˆ
# Hugging Face ํ† ํฐ ์„ค์ • - ํ™˜๊ฒฝ ๋ณ€์ˆ˜์—์„œ ๊ฐ€์ ธ์˜ค๊ธฐ
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
print("๊ฒฝ๊ณ : HF_TOKEN ํ™˜๊ฒฝ ๋ณ€์ˆ˜๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ๋น„๊ณต๊ฐœ ๋ชจ๋ธ์— ์ ‘๊ทผํ•  ์ˆ˜ ์—†์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")
# --- ์„ ํƒ ์‚ฌํ•ญ: CPU ์Šค๋ ˆ๋“œ ์„ค์ • ---
# torch.set_num_threads(CPU_THREAD_COUNT)
# os.environ["OMP_NUM_THREADS"] = str(CPU_THREAD_COUNT)
# os.environ["MKL_NUM_THREADS"] = str(CPU_THREAD_COUNT)
print("--- ํ™˜๊ฒฝ ์„ค์ • ---")
print(f"PyTorch ๋ฒ„์ „: {torch.__version__}")
print(f"์‹คํ–‰ ์žฅ์น˜: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")
print(f"Torch ์Šค๋ ˆ๋“œ: {torch.get_num_threads()}")
print(f"HF_TOKEN ์„ค์ • ์—ฌ๋ถ€: {'์žˆ์Œ' if HF_TOKEN else '์—†์Œ'}")
# --- ๋ชจ๋ธ ๋ฐ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ ---
print(f"--- ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘: {MODEL_ID} ---")
print("์ฒซ ์‹คํ–‰ ์‹œ ๋ช‡ ๋ถ„ ์ •๋„ ์†Œ์š”๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค...")
model = None
tokenizer = None
load_successful = False
stop_token_ids_list = []  # initialized here so it exists even if loading fails
try:
    start_load_time = time.time()

    # Choose device_map and dtype based on the available hardware
    device_map = "auto" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
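    # float16 roughly halves GPU memory use relative to float32; float32 is
    # kept on CPU, where half-precision inference is slow or unsupported for
    # many ops. bfloat16 would be a reasonable alternative on recent GPUs.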
# ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ
tokenizer_kwargs = {
"trust_remote_code": True
}
# HF_TOKEN์ด ์„ค์ •๋˜์–ด ์žˆ์œผ๋ฉด ์ถ”๊ฐ€
if HF_TOKEN:
tokenizer_kwargs["token"] = HF_TOKEN
tokenizer = AutoTokenizer.from_pretrained(
MODEL_ID,
**tokenizer_kwargs
)
# ๋ชจ๋ธ ๋กœ๋”ฉ
model_kwargs = {
"torch_dtype": dtype,
"device_map": device_map,
"trust_remote_code": True
}
# HF_TOKEN์ด ์„ค์ •๋˜์–ด ์žˆ์œผ๋ฉด ์ถ”๊ฐ€
if HF_TOKEN:
model_kwargs["token"] = HF_TOKEN
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
**model_kwargs
)
model.eval()
load_time = time.time() - start_load_time
print(f"--- ๋ชจ๋ธ ๋ฐ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ ์™„๋ฃŒ: {load_time:.2f}์ดˆ ์†Œ์š” ---")
load_successful = True
# --- ์ค‘์ง€ ํ† ํฐ ์„ค์ • ---
stop_token_strings = ["</s>", "<|endoftext|>"]
temp_stop_ids = [tokenizer.convert_tokens_to_ids(token) for token in stop_token_strings]
if tokenizer.eos_token_id is not None and tokenizer.eos_token_id not in temp_stop_ids:
temp_stop_ids.append(tokenizer.eos_token_id)
elif tokenizer.eos_token_id is None:
print("๊ฒฝ๊ณ : tokenizer.eos_token_id๊ฐ€ None์ž…๋‹ˆ๋‹ค. ์ค‘์ง€ ํ† ํฐ์— ์ถ”๊ฐ€ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
stop_token_ids_list = [tid for tid in temp_stop_ids if tid is not None]
if not stop_token_ids_list:
print("๊ฒฝ๊ณ : ์ค‘์ง€ ํ† ํฐ ID๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ๊ฐ€๋Šฅํ•˜๋ฉด ๊ธฐ๋ณธ EOS๋ฅผ ์‚ฌ์šฉํ•˜๊ณ , ๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด ์ƒ์„ฑ์ด ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ์ค‘์ง€๋˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")
if tokenizer.eos_token_id is not None:
stop_token_ids_list = [tokenizer.eos_token_id]
else:
print("์˜ค๋ฅ˜: ๊ธฐ๋ณธ EOS๋ฅผ ํฌํ•จํ•˜์—ฌ ์ค‘์ง€ ํ† ํฐ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. ์ƒ์„ฑ์ด ๋ฌดํ•œ์ • ์‹คํ–‰๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")
print(f"์‚ฌ์šฉํ•  ์ค‘์ง€ ํ† ํฐ ID: {stop_token_ids_list}")
except Exception as e:
    print(f"!!! Model loading error: {e}")
    if 'model' in locals() and model is not None: del model
    if 'tokenizer' in locals() and tokenizer is not None: del tokenizer
    gc.collect()
    raise gr.Error(f"Failed to load model {MODEL_ID}; the application cannot start. Error: {e}")
# --- ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ ์ •์˜ ---
def get_system_prompt():
current_date = datetime.datetime.now().strftime("%Y-%m-%d (%A)")
return (
f"- ์˜ค๋Š˜์€ {current_date}์ž…๋‹ˆ๋‹ค.\n"
f"- ์‚ฌ์šฉ์ž์˜ ์งˆ๋ฌธ์— ๋Œ€ํ•ด ์นœ์ ˆํ•˜๊ณ  ์ž์„ธํ•˜๊ฒŒ ํ•œ๊ตญ์–ด๋กœ ๋‹ต๋ณ€ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค."
)
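# (English gloss of the prompt above: "Today is {current_date}." /
# "You must answer the user's questions kindly and in detail, in Korean.")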
# --- ์›œ์—… ํ•จ์ˆ˜ ---
def warmup_model():
if not load_successful or model is None or tokenizer is None:
print("์›œ์—… ๊ฑด๋„ˆ๋›ฐ๊ธฐ: ๋ชจ๋ธ์ด ์„ฑ๊ณต์ ์œผ๋กœ ๋กœ๋“œ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
return
print("--- ๋ชจ๋ธ ์›œ์—… ์‹œ์ž‘ ---")
try:
start_warmup_time = time.time()
warmup_message = "์•ˆ๋…•ํ•˜์„ธ์š”"
# ๋ชจ๋ธ์— ๋งž๋Š” ํ˜•์‹์œผ๋กœ ์ž…๋ ฅ ๊ตฌ์„ฑ
system_prompt = get_system_prompt()
# MiMo ๋ชจ๋ธ์˜ ํ”„๋กฌํ”„ํŠธ ํ˜•์‹์— ๋งž๊ฒŒ ์กฐ์ •
prompt = f"Human: {warmup_message}\nAssistant:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# ์ค‘์ง€ ํ† ํฐ์ด ๋น„์–ด ์žˆ๋Š”์ง€ ํ™•์ธํ•˜๊ณ  ์ ์ ˆํžˆ ์ฒ˜๋ฆฌ
gen_kwargs = {
"max_new_tokens": 10,
"pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
"do_sample": False
}
if stop_token_ids_list:
gen_kwargs["eos_token_id"] = stop_token_ids_list
else:
print("์›œ์—… ๊ฒฝ๊ณ : ์ƒ์„ฑ์— ์ •์˜๋œ ์ค‘์ง€ ํ† ํฐ์ด ์—†์Šต๋‹ˆ๋‹ค.")
with torch.no_grad():
output_ids = model.generate(**inputs, **gen_kwargs)
del inputs
del output_ids
gc.collect()
warmup_time = time.time() - start_warmup_time
print(f"--- ๋ชจ๋ธ ์›œ์—… ์™„๋ฃŒ: {warmup_time:.2f}์ดˆ ์†Œ์š” ---")
except Exception as e:
print(f"!!! ๋ชจ๋ธ ์›œ์—… ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
finally:
gc.collect()
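# Warm-up runs one tiny generation so that one-time costs (CUDA context
# creation, kernel selection, allocator growth) are paid at startup rather
# than on the first user request.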
# --- ์ถ”๋ก  ํ•จ์ˆ˜ ---
@spaces.GPU()
def predict(message, history):
"""
HyperCLOVAX-SEED-Vision-Instruct-3B ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ์‘๋‹ต์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
'history'๋Š” Gradio 'messages' ํ˜•์‹์„ ๊ฐ€์ •ํ•ฉ๋‹ˆ๋‹ค: List[Dict].
"""
if model is None or tokenizer is None:
return "์˜ค๋ฅ˜: ๋ชจ๋ธ์ด ๋กœ๋“œ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค."
# ๋Œ€ํ™” ๊ธฐ๋ก ์ฒ˜๋ฆฌ
history_text = ""
if isinstance(history, list):
for turn in history:
if isinstance(turn, tuple) and len(turn) == 2:
history_text += f"Human: {turn[0]}\nAssistant: {turn[1]}\n"
# MiMo ๋ชจ๋ธ ์ž…๋ ฅ ํ˜•์‹์— ๋งž๊ฒŒ ํ”„๋กฌํ”„ํŠธ ๊ตฌ์„ฑ
prompt = f"{history_text}Human: {message}\nAssistant:"
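    # The assembled prompt has the shape:
    #   <system prompt>
    #   Human: <turn 1>
    #   Assistant: <reply 1>
    #   ...
    #   Human: <new message>
    #   Assistant: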
    inputs = None
    output_ids = None

    try:
        # Tokenize the prompt and move it to the model's device
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        input_length = inputs.input_ids.shape[1]
        print(f"\nInput token count: {input_length}")
    except Exception as e:
        print(f"!!! Error while processing the input: {e}")
        return f"Error: there was a problem processing the input. ({e})"
    try:
        print("Generating response...")
        generation_start_time = time.time()

        # Build the generation kwargs; an empty stop_token_ids_list is handled below
        gen_kwargs = {
            "max_new_tokens": MAX_NEW_TOKENS,
            "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "repetition_penalty": 1.1
        }
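        # temperature 0.7 mildly flattens the next-token distribution, top_p 0.9
        # restricts sampling to the smallest token set whose cumulative
        # probability reaches 0.9, and repetition_penalty 1.1 gently penalizes
        # tokens that have already appeared.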
        if stop_token_ids_list:
            gen_kwargs["eos_token_id"] = stop_token_ids_list
        else:
            print("Generation warning: no stop tokens defined.")

        with torch.no_grad():
            output_ids = model.generate(**inputs, **gen_kwargs)

        generation_time = time.time() - generation_start_time
        print(f"Generation finished in {generation_time:.2f}s.")
    except Exception as e:
        print(f"!!! Error during model generation: {e}")
        if inputs is not None: del inputs
        if output_ids is not None: del output_ids
        gc.collect()
        return f"Error: there was a problem generating the response. ({e})"
# ์‘๋‹ต ๋””์ฝ”๋”ฉ
response = "์˜ค๋ฅ˜: ์‘๋‹ต ์ƒ์„ฑ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค."
if output_ids is not None:
try:
new_tokens = output_ids[0, input_length:]
response = tokenizer.decode(new_tokens, skip_special_tokens=True)
print(f"์ถœ๋ ฅ ํ† ํฐ ์ˆ˜: {len(new_tokens)}")
del new_tokens
except Exception as e:
print(f"!!! ์‘๋‹ต ๋””์ฝ”๋”ฉ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
response = "์˜ค๋ฅ˜: ์‘๋‹ต์„ ๋””์ฝ”๋”ฉํ•˜๋Š” ์ค‘ ๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค."
# ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
if inputs is not None: del inputs
if output_ids is not None: del output_ids
gc.collect()
print("๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ ์™„๋ฃŒ.")
return response.strip()
# --- Gradio interface setup ---
print("--- Setting up the Gradio interface ---")

# Example prompts are kept in Korean, matching the model's target language
examples = [
    ["안녕하세요! 자기소개 좀 해주세요."],
    ["인공지능과 머신러닝의 차이점은 무엇인가요?"],
    ["딥러닝 모델 학습 과정을 단계별로 알려주세요."],
    ["제주도 여행 계획을 세우고 있는데, 3박 4일 추천 코스 좀 알려주세요."],
]
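# (English gloss of the examples: a self-introduction request, the difference
# between AI and machine learning, the steps of training a deep learning
# model, and a recommended 3-night/4-day Jeju Island itinerary.)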
# ๋ชจ๋ธ ์ด๋ฆ„์— ๋งž๊ฒŒ ํƒ€์ดํ‹€ ์กฐ์ •
title = "๐Ÿค– HyperCLOVAX-SEED-Vision-Instruct-3B"
# Use ChatInterface and let it manage its own Chatbot component
demo = gr.ChatInterface(
    fn=predict,
    title=title,
    description=(
        f"**Model:** {MODEL_ID}\n"
    ),
    examples=examples,
    cache_examples=False,
    theme=gr.themes.Soft(),
)
# --- ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜ ์‹คํ–‰ ---
if __name__ == "__main__":
if load_successful:
warmup_model()
else:
print("๋ชจ๋ธ ๋กœ๋”ฉ์— ์‹คํŒจํ•˜์—ฌ ์›œ์—…์„ ๊ฑด๋„ˆ๋œ๋‹ˆ๋‹ค.")
print("--- Gradio ์•ฑ ์‹คํ–‰ ์ค‘ ---")
demo.queue().launch(
# share=True # ๊ณต๊ฐœ ๋งํฌ๋ฅผ ์›ํ•˜๋ฉด ์ฃผ์„ ํ•ด์ œ
# server_name="0.0.0.0" # ๋กœ์ปฌ ๋„คํŠธ์›Œํฌ ์ ‘๊ทผ์„ ์›ํ•˜๋ฉด ์ฃผ์„ ํ•ด์ œ
)