File size: 2,100 Bytes
af4574d
470180a
 
81b24be
af4574d
cb44112
81b24be
 
 
470180a
fefdb18
470180a
fefdb18
 
470180a
 
fefdb18
 
 
 
470180a
af4574d
 
 
470180a
af4574d
 
 
 
 
470180a
 
 
 
 
af4574d
fefdb18
470180a
af4574d
fefdb18
470180a
 
 
af4574d
 
470180a
 
 
 
af4574d
fefdb18
470180a
 
af4574d
fefdb18
470180a
 
 
af4574d
fefdb18
af4574d
470180a
af4574d
fefdb18
 
 
 
af4574d
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os  # read the Hugging Face access token from the environment

import gradio as gr
import torch
from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer

model_id = 'Bllossom/llama-3.2-Korean-Bllossom-3B'

# Hugging Face access token from the environment (None means anonymous access,
# which works because this checkpoint is public).
hf_access_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')

# Load tokenizer and model.
# AutoTokenizer instead of LlamaTokenizer: Llama-3.x checkpoints ship only the
# fast `tokenizer.json` tokenizer, while the slow LlamaTokenizer requires a
# sentencepiece `tokenizer.model` file these repos do not contain.
# `token=` replaces the deprecated `use_auth_token=` argument.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    token=hf_access_token,
)
model = LlamaForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # halves memory vs fp32; fine for inference
    device_map="auto",           # place layers on available GPU(s)/CPU
    token=hf_access_token,
)

def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate one chat reply for *message* given the conversation *history*.

    Called by gr.ChatInterface with the extra controls (system message,
    max tokens, temperature, top-p) appended after (message, history).

    Returns the assistant's reply as a plain string — ChatInterface appends
    it to the displayed history itself.
    """
    # Build a flat text prompt: system message, then alternating turns.
    prompt = system_message + "\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"

    # Tokenize and move tensors to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Sampled generation; pad_token_id set explicitly because Llama
    # tokenizers define no pad token.
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode, then strip the echoed prompt so only the new completion remains.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = response[len(prompt):].strip()

    # BUG FIX: gr.ChatInterface expects the fn to return the bot reply
    # string; the old code appended to `history` and returned the whole
    # list, which the UI would render as the answer.
    return response

# Gradio μΈν„°νŽ˜μ΄μŠ€ 생성
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly Chatbot.",
            label="System message"
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    demo.launch()