# BEANs / gemma_Ko_coffee_load_model.py
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Load the tokenizer from the base model
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
# Path to the saved model
model_dir = './gemma_outputs/gemma-2b-it-sum-ko-beans-1'
model = AutoModelForCausalLM.from_pretrained(model_dir)
# tokenizer = AutoTokenizer.from_pretrained(model_dir)
# Move the model to the CPU (change "cpu" to "cuda" if you are using a GPU)
model.to("cpu")
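# Optional sketch (assumption, not part of the original script): select the device
# automatically instead of hard-coding "cpu". torch is already available as a
# transformers dependency. Kept commented out so the behavior above is unchanged.
# import torch
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)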
conversation_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
def chat_with_model(input_text):
    # Build the chat message for a single user turn
    messages = [{"role": "user", "content": input_text}]
    # Convert the message into the model's prompt format via the chat template
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Generate the model's response
    # response = conversation_pipeline(prompt, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
    response = conversation_pipeline(prompt, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, add_special_tokens=True)
    # Extract the generated text and strip the input prompt so only the reply is returned
    generated_text = response[0]["generated_text"]
    model_response = generated_text[len(prompt):]
    return model_response
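# Optional single-turn usage sketch (hypothetical prompt), following this file's
# convention of commented-out alternatives; uncomment to test without the loop below.
# print(chat_with_model("Recommend a light-roast single-origin coffee."))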
# Loop so the conversation can continue turn by turn
def interactive_chat():
print("๋Œ€ํ™”ํ˜• ๋ชจ๋“œ์— ์˜ค์‹  ๊ฒƒ์„ ํ™˜์˜ํ•ฉ๋‹ˆ๋‹ค! '์ข…๋ฃŒ'๋ผ๊ณ  ์ž…๋ ฅํ•˜๋ฉด ๋Œ€ํ™”๊ฐ€ ์ข…๋ฃŒ๋ฉ๋‹ˆ๋‹ค.")
while True:
user_input = input("์‚ฌ์šฉ์ž: ") # ์‚ฌ์šฉ์ž ์ž…๋ ฅ ๋ฐ›๊ธฐ
if user_input.lower() == "์ข…๋ฃŒ": # '์ข…๋ฃŒ'๋ผ๊ณ  ์ž…๋ ฅํ•˜๋ฉด ๋Œ€ํ™” ์ข…๋ฃŒ
print("๋Œ€ํ™”๋ฅผ ์ข…๋ฃŒํ•ฉ๋‹ˆ๋‹ค.")
break
model_reply = chat_with_model(user_input) # ๋ชจ๋ธ์˜ ์‘๋‹ต ๋ฐ›๊ธฐ
print(f"๋ชจ๋ธ: {model_reply}") # ๋ชจ๋ธ์˜ ์‘๋‹ต ์ถœ๋ ฅ
# Start the conversation when the script is run directly
if __name__ == "__main__":
    interactive_chat()