from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load the tokenizer from the base model
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")

# Path to the saved fine-tuned model
model_dir = './gemma_outputs/gemma-2b-it-sum-ko-beans-1'
model = AutoModelForCausalLM.from_pretrained(model_dir)
# tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Move the model to CPU (change to "cuda" if a GPU is available)
model.to("cpu")

# Text-generation pipeline wrapping the fine-tuned model and the base tokenizer
conversation_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)

def chat_with_model(input_text):
    # Build the chat-style message list for the prompt
    messages = [{"role": "user", "content": input_text}]

    # Convert the messages into a prompt string using the tokenizer's chat template
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Generate a response with the model
    # response = conversation_pipeline(prompt, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
    response = conversation_pipeline(prompt, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, add_special_tokens=True)

    # Extract the generated text and strip the input prompt so only the model's reply is returned
    generated_text = response[0]["generated_text"]
    model_response = generated_text[len(prompt):]
    return model_response

# Simple loop that keeps the conversation going
def interactive_chat():
    print("Welcome to interactive mode! Type 'exit' to end the conversation.")
    while True:
        user_input = input("User: ")  # read user input
        if user_input.lower() == "exit":  # typing 'exit' ends the conversation
            print("Ending the conversation.")
            break
        model_reply = chat_with_model(user_input)  # get the model's reply
        print(f"Model: {model_reply}")  # print the model's reply

# Start the conversation
interactive_chat()