from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# ๋ฒ ์ด์Šค ๋ชจ๋ธ์—์„œ ํ† ํฌ๋‚˜์ด์ € ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")

# Path to the locally saved fine-tuned model
model_dir = './gemma_outputs/gemma-2b-it-sum-ko-beans-1'
model = AutoModelForCausalLM.from_pretrained(model_dir)
# tokenizer = AutoTokenizer.from_pretrained(model_dir)

# ๋ชจ๋ธ์„ CPU๋กœ ์ด๋™ (๋งŒ์•ฝ GPU๋ฅผ ์“ด๋‹ค๋ฉด 'cuda'๋กœ ๋ฐ”๊ฟ”์ค˜)
model.to("cpu") #cpu

conversation_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)


def chat_with_model(input_text):
    # Build the conversational prompt as a chat message list
    messages = [{"role": "user", "content": input_text}]

    # ํ† ํฌ๋‚˜์ด์ €๋กœ ์ž…๋ ฅ์„ ํ”„๋กฌํ”„ํŠธ ํ˜•ํƒœ๋กœ ๋ณ€ํ™˜
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # ๋ชจ๋ธ์ด ์‘๋‹ต์„ ์ƒ์„ฑ
    # response = conversation_pipeline(prompt, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
    response = conversation_pipeline(prompt, do_sample=True, temperature=0.2, top_k=50, top_p=0.95, add_special_tokens=True)
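    # A hedged alternative sketch (not in the original script): passing return_full_text=False
    # makes the pipeline return only the newly generated text, so the manual prompt-stripping
    # below would not be needed:
    # response = conversation_pipeline(prompt, do_sample=True, temperature=0.2,
    #                                  top_k=50, top_p=0.95, return_full_text=False)
    # Note as well: Gemma's chat template already prepends <bos>, so add_special_tokens=True
    # above may duplicate it; adjust if the output looks off.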

    # ๋ชจ๋ธ์˜ ์ƒ์„ฑ๋œ ์‘๋‹ต ์ถ”์ถœ
    generated_text = response[0]["generated_text"]
    model_response = generated_text[len(prompt):]  # ์ž…๋ ฅ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ œ๊ฑฐํ•˜๊ณ  ์‘๋‹ต๋งŒ ๋ฐ˜ํ™˜
    return model_response


# Simple loop that keeps the conversation going turn after turn
def interactive_chat():
    print("๋Œ€ํ™”ํ˜• ๋ชจ๋“œ์— ์˜ค์‹  ๊ฒƒ์„ ํ™˜์˜ํ•ฉ๋‹ˆ๋‹ค! '์ข…๋ฃŒ'๋ผ๊ณ  ์ž…๋ ฅํ•˜๋ฉด ๋Œ€ํ™”๊ฐ€ ์ข…๋ฃŒ๋ฉ๋‹ˆ๋‹ค.")
    while True:
        user_input = input("์‚ฌ์šฉ์ž: ")  # Read the user's input ("์‚ฌ์šฉ์ž" = "User")
        if user_input.lower() == "์ข…๋ฃŒ":  # Typing "์ข…๋ฃŒ" ("quit") ends the conversation
            print("๋Œ€ํ™”๋ฅผ ์ข…๋ฃŒํ•ฉ๋‹ˆ๋‹ค.")
            break
        model_reply = chat_with_model(user_input)  # Get the model's reply
        print(f"๋ชจ๋ธ: {model_reply}")  # Print the model's reply ("๋ชจ๋ธ" = "Model")

# Start the conversation
interactive_chat()
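
# Single-shot usage without the interactive loop (a hypothetical, illustrative example;
# the prompt below is made up and assumes the fine-tuned checkpoint handles Korean input):
# print(chat_with_model("๋‹ค์Œ ๋Œ€ํ™”๋ฅผ ํ•œ ๋ฌธ์žฅ์œผ๋กœ ์š”์•ฝํ•ด ์ค˜: ์•ˆ๋…•, ์˜ค๋Š˜ ์ €๋… ๊ฐ™์ด ๋จน์„๊นŒ? ์ข‹์•„, 7์‹œ์— ๋งŒ๋‚˜์ž."))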