from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Load the tokenizer of the base instruction-tuned Gemma model
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")

# Load the fine-tuned model weights from the local output directory
model_dir = './gemma_outputs/gemma-2b-it-sum-ko-beans-1'
model = AutoModelForCausalLM.from_pretrained(model_dir)

# Run inference on the CPU
model.to("cpu")

# Build a text-generation pipeline around the fine-tuned model and its tokenizer
conversation_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)


def chat_with_model(input_text):
    # Wrap the user input in the chat message format expected by the Gemma chat template
    messages = [{"role": "user", "content": input_text}]

    # Render the messages into a single prompt string that ends with the generation cue
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Sample a completion; a low temperature keeps the reply focused
    response = conversation_pipeline(prompt, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, add_special_tokens=True)

    # The pipeline returns the prompt followed by the completion; strip the prompt to keep only the reply
    generated_text = response[0]["generated_text"]
    model_response = generated_text[len(prompt):]
    return model_response
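
# Quick sanity check of chat_with_model (hypothetical prompt, for illustration only):
# print(chat_with_model("Summarize this paragraph in one sentence."))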


def interactive_chat():
    print("Welcome to interactive mode! Type 'exit' to end the conversation.")
    while True:
        user_input = input("User: ")
        if user_input.lower() == "exit":
            print("Ending the conversation.")
            break
        model_reply = chat_with_model(user_input)
        print(f"Model: {model_reply}")


interactive_chat()