File size: 7,765 Bytes
1b3c845
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import os
from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEndpoint
import streamlit as st
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Load environment variables (e.g. HF_TOKEN) from a local .env file first,
# so anything below can rely on os.getenv.
load_dotenv()

# Hugging Face model repository ID used throughout the app.
model_id = "mistralai/Mistral-7B-Instruct-v0.3"


def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
    """Build and return a language model client for Hugging Face inference.

    Parameters:
    - model_id (str): ID of the Hugging Face model repository.
    - max_new_tokens (int): maximum number of new tokens to generate.
    - temperature (float): sampling temperature for the model.

    Returns:
    - HuggingFaceEndpoint: the language model client for Hugging Face inference.
    """
    # The API token comes from the HF_TOKEN environment variable
    # (populated via load_dotenv at module import time).
    return HuggingFaceEndpoint(
        repo_id=model_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        token=os.getenv("HF_TOKEN"),
    )


# Configure the Streamlit page chrome.
st.set_page_config(page_title="HuggingFace ChatBot", page_icon="πŸ€—")
st.title("개인 HuggingFace 챗봇")
st.markdown(
    f"*이것은 HuggingFace transformers 라이브러리λ₯Ό μ‚¬μš©ν•˜μ—¬ ν…μŠ€νŠΈ μž…λ ₯에 λŒ€ν•œ 응닡을 μƒμ„±ν•˜λŠ” κ°„λ‹¨ν•œ μ±—λ΄‡μž…λ‹ˆλ‹€. {model_id} λͺ¨λΈμ„ μ‚¬μš©ν•©λ‹ˆλ‹€.*"
)

# Seed session state with defaults on the first run only; each key is
# assigned solely when missing, so values survive Streamlit reruns —
# identical to a series of "if key not in st.session_state" guards.
_session_defaults = {
    # Chat avatars for each role (chosen in the sidebar later).
    "avatars": {"user": None, "assistant": None},
    # Latest user text input.
    "user_text": None,
    # Model parameter: cap on response length.
    "max_response_length": 256,
    # System message steering the assistant.
    "system_message": "인간 μ‚¬μš©μžμ™€ λŒ€ν™”ν•˜λŠ” μΉœμ ˆν•œ AI",
    # Opening assistant message.
    "starter_message": "μ•ˆλ…•ν•˜μ„Έμš”! 였늘 무엇을 λ„μ™€λ“œλ¦΄κΉŒμš”?",
}
for _key, _value in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _value


# Sidebar with the app's configuration controls.
with st.sidebar:
    st.header("μ‹œμŠ€ν…œ μ„€μ •")

    # AI settings. NOTE(review): the widget defaults (value=...) differ from
    # the session-state defaults initialized above, and the widget value
    # overwrites session state on every rerun — confirm that is intended.
    st.session_state.system_message = st.text_area(
        "μ‹œμŠ€ν…œ λ©”μ‹œμ§€", value="당신은 인간 μ‚¬μš©μžμ™€ λŒ€ν™”ν•˜λŠ” μΉœμ ˆν•œ AIμž…λ‹ˆλ‹€."
    )
    st.session_state.starter_message = st.text_area(
        "첫 번째 AI λ©”μ‹œμ§€", value="μ•ˆλ…•ν•˜μ„Έμš”! 였늘 무엇을 λ„μ™€λ“œλ¦΄κΉŒμš”?"
    )

    # Model settings: maximum response length in tokens.
    st.session_state.max_response_length = st.number_input("μ΅œλŒ€ 응닡 길이", value=128)

    # Avatar selection, one column per chat role.
    st.markdown("*아바타 선택:*")
    col1, col2 = st.columns(2)
    with col1:
        st.session_state.avatars["assistant"] = st.selectbox(
            "AI 아바타", options=["πŸ€—", "πŸ’¬", "πŸ€–"], index=0
        )
    with col2:
        st.session_state.avatars["user"] = st.selectbox(
            "μ‚¬μš©μž 아바타", options=["πŸ‘€", "πŸ‘±β€β™‚οΈ", "πŸ‘¨πŸΎ", "πŸ‘©", "πŸ‘§πŸΎ"], index=0
        )
    # Button that clears the conversation history.
    reset_history = st.button("μ±„νŒ… 기둝 μ΄ˆκΈ°ν™”")

# (Re)initialize chat history on first run or when the reset button was pressed.
if "chat_history" not in st.session_state or reset_history:
    st.session_state.chat_history = [
        {"role": "assistant", "content": st.session_state.starter_message}
    ]


def get_response(
    system_message,
    chat_history,
    user_text,
    eos_token_id=None,
    max_new_tokens=256,
    get_llm_hf_kws=None,
):
    """Generate a response from the chatbot model.

    Parameters:
        system_message (str): System message for the conversation.
        chat_history (list): List of previous chat messages. Mutated in
            place: the new user and assistant turns are appended.
        user_text (str): The user's input text.
        eos_token_id (list, optional): End-of-sentence token IDs.
            Currently unused; accepted for backward compatibility.
        max_new_tokens (int, optional): Maximum number of new tokens to
            generate.
        get_llm_hf_kws (dict, optional): Extra keyword arguments forwarded
            to get_llm_hf_inference (e.g. a different model_id).

    Returns:
        tuple: (generated response, updated chat history).
    """
    # Avoid mutable default arguments (shared across calls):
    # normalize the None sentinels to fresh concrete values here.
    if eos_token_id is None:
        eos_token_id = ["User"]
    if get_llm_hf_kws is None:
        get_llm_hf_kws = {}

    # Set up the model. Forward get_llm_hf_kws so the documented
    # pass-through parameter actually has an effect (it was ignored before).
    hf = get_llm_hf_inference(
        max_new_tokens=max_new_tokens, temperature=0.1, **get_llm_hf_kws
    )

    # Mistral-style [INST] prompt wrapping the system message, the running
    # conversation, and the latest user turn.
    prompt = PromptTemplate.from_template(
        (
            "[INST] {system_message}"
            "\nν˜„μž¬ λŒ€ν™”:\n{chat_history}\n\n"
            "\nμ‚¬μš©μž: {user_text}.\n [/INST]"
            "\nAI:"
        )
    )
    # Chain: prompt -> model -> plain string output.
    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key="content")

    # Generate the response.
    response = chat.invoke(
        input=dict(
            system_message=system_message,
            user_text=user_text,
            chat_history=chat_history,
        )
    )
    # Keep only the text after the last "AI:" marker, dropping any echoed prompt.
    response = response.split("AI:")[-1]

    # Record both turns in the history.
    chat_history.append({"role": "user", "content": user_text})
    chat_history.append({"role": "assistant", "content": response})
    return response, chat_history


# μ±„νŒ… μΈν„°νŽ˜μ΄μŠ€λ₯Ό μ„€μ •ν•©λ‹ˆλ‹€.
chat_interface = st.container(border=True)
with chat_interface:
    output_container = st.container()
    st.session_state.user_text = st.chat_input(
        placeholder="여기에 ν…μŠ€νŠΈλ₯Ό μž…λ ₯ν•˜μ„Έμš”."
    )

# μ±„νŒ… λ©”μ‹œμ§€λ₯Ό ν‘œμ‹œν•©λ‹ˆλ‹€.
with output_container:
    # μ±„νŒ… 기둝에 μžˆλŠ” 각 λ©”μ‹œμ§€μ— λŒ€ν•΄ λ°˜λ³΅ν•©λ‹ˆλ‹€.
    for message in st.session_state.chat_history:
        # μ‹œμŠ€ν…œ λ©”μ‹œμ§€λŠ” κ±΄λ„ˆλœλ‹ˆλ‹€.
        if message["role"] == "system":
            continue

        # μ˜¬λ°”λ₯Έ 아바타λ₯Ό μ‚¬μš©ν•˜μ—¬ μ±„νŒ… λ©”μ‹œμ§€λ₯Ό ν‘œμ‹œν•©λ‹ˆλ‹€.
        with st.chat_message(
            message["role"], avatar=st.session_state["avatars"][message["role"]]
        ):
            st.markdown(message["content"])

    # μ‚¬μš©μžκ°€ μƒˆ ν…μŠ€νŠΈλ₯Ό μž…λ ₯ν–ˆμ„ λ•Œ:
    if st.session_state.user_text:
        # μ‚¬μš©μžμ˜ μƒˆ λ©”μ‹œμ§€λ₯Ό μ¦‰μ‹œ ν‘œμ‹œν•©λ‹ˆλ‹€.
        with st.chat_message("user", avatar=st.session_state.avatars["user"]):
            st.markdown(st.session_state.user_text)

        # 응닡을 κΈ°λ‹€λ¦¬λŠ” λ™μ•ˆ μŠ€ν”Όλ„ˆ μƒνƒœ ν‘œμ‹œμ€„μ„ ν‘œμ‹œν•©λ‹ˆλ‹€.
        with st.chat_message("assistant", avatar=st.session_state.avatars["assistant"]):
            with st.spinner("생각 쀑..."):
                # μ‹œμŠ€ν…œ ν”„λ‘¬ν”„νŠΈ, μ‚¬μš©μž ν…μŠ€νŠΈ 및 기둝을 μ‚¬μš©ν•˜μ—¬ μΆ”λ‘  APIλ₯Ό ν˜ΈμΆœν•©λ‹ˆλ‹€.
                response, st.session_state.chat_history = get_response(
                    system_message=st.session_state.system_message,
                    user_text=st.session_state.user_text,
                    chat_history=st.session_state.chat_history,
                    max_new_tokens=st.session_state.max_response_length,
                )
                st.markdown(response)