# first_chatbot / app.py
# 김탱
# init
# 1b3c845
import os
from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEndpoint
import streamlit as st
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
# Load variables from a .env file before anything reads the environment.
load_dotenv()
# Hugging Face model ID used by the chatbot.
model_id = "mistralai/Mistral-7B-Instruct-v0.3"
def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
    """Return a language model wrapper for Hugging Face inference.

    Args:
        model_id (str): ID of the Hugging Face model repository.
        max_new_tokens (int): Maximum number of new tokens to generate.
        temperature (float): Sampling temperature used by the model.

    Returns:
        HuggingFaceEndpoint: The configured endpoint object.
    """
    # The API token comes from the HF_TOKEN environment variable (None when
    # unset). It is handed to the endpoint's credential field instead of a
    # bare `token` keyword, which is not a declared endpoint parameter.
    llm = HuggingFaceEndpoint(
        repo_id=model_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        huggingfacehub_api_token=os.getenv("HF_TOKEN"),
    )
    return llm
# Configure the Streamlit page, then render the title and an intro line.
st.set_page_config(page_title="HuggingFace ChatBot", page_icon="πŸ€—")
st.title("개인 HuggingFace 챗봇")
# Italic intro text; the active model ID is interpolated into it.
intro_markdown = f"*이것은 HuggingFace transformers 라이브러리λ₯Ό μ‚¬μš©ν•˜μ—¬ ν…μŠ€νŠΈ μž…λ ₯에 λŒ€ν•œ 응닡을 μƒμ„±ν•˜λŠ” κ°„λ‹¨ν•œ μ±—λ΄‡μž…λ‹ˆλ‹€. {model_id} λͺ¨λΈμ„ μ‚¬μš©ν•©λ‹ˆλ‹€.*"
st.markdown(intro_markdown)
# Seed session state with defaults. A key that is already present is left
# untouched, so values survive Streamlit reruns.
_session_defaults = {
    # Avatar for each chat role.
    "avatars": {"user": None, "assistant": None},
    # Latest text entered by the user.
    "user_text": None,
    # Model parameter: maximum response length.
    "max_response_length": 256,
    # System message given to the model.
    "system_message": "인간 μ‚¬μš©μžμ™€ λŒ€ν™”ν•˜λŠ” μΉœμ ˆν•œ AI",
    # Assistant message that opens a conversation.
    "starter_message": "μ•ˆλ…•ν•˜μ„Έμš”! 였늘 무엇을 λ„μ™€λ“œλ¦΄κΉŒμš”?",
}
for _state_key, _state_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Sidebar holding the user-adjustable settings.
# NOTE(review): SOURCE carried no indentation; the nesting below is
# reconstructed from the statement order — confirm.
with st.sidebar:
    st.header("μ‹œμŠ€ν…œ μ„€μ •")
    # AI settings. Each widget's return value is written to session state
    # unconditionally, replacing whatever was stored there before.
    st.session_state.system_message = st.text_area(
        "μ‹œμŠ€ν…œ λ©”μ‹œμ§€", value="당신은 인간 μ‚¬μš©μžμ™€ λŒ€ν™”ν•˜λŠ” μΉœμ ˆν•œ AIμž…λ‹ˆλ‹€."
    )
    st.session_state.starter_message = st.text_area(
        "첫 번째 AI λ©”μ‹œμ§€", value="μ•ˆλ…•ν•˜μ„Έμš”! 였늘 무엇을 λ„μ™€λ“œλ¦΄κΉŒμš”?"
    )
    # Model settings. This value is later passed on as max_new_tokens, so
    # min_value=1 stops the widget from yielding zero or a negative number.
    st.session_state.max_response_length = st.number_input(
        "μ΅œλŒ€ 응닡 길이", min_value=1, value=128
    )
    # Avatar selection, one selectbox per column.
    st.markdown("*아바타 선택:*")
    col1, col2 = st.columns(2)
    with col1:
        st.session_state.avatars["assistant"] = st.selectbox(
            "AI 아바타", options=["πŸ€—", "πŸ’¬", "πŸ€–"], index=0
        )
    with col2:
        st.session_state.avatars["user"] = st.selectbox(
            "μ‚¬μš©μž 아바타", options=["πŸ‘€", "πŸ‘±β€β™‚οΈ", "πŸ‘¨πŸΎ", "πŸ‘©", "πŸ‘§πŸΎ"], index=0
        )
    # Button that requests a chat-history reset.
    reset_history = st.button("μ±„νŒ… 기둝 μ΄ˆκΈ°ν™”")
# Create the chat history on first use, or rebuild it when the reset button
# was pressed. Either way it then holds only the assistant's starter message.
if reset_history or "chat_history" not in st.session_state:
    st.session_state.chat_history = [
        dict(role="assistant", content=st.session_state.starter_message)
    ]
def get_response(
    system_message,
    chat_history,
    user_text,
    eos_token_id=("User",),
    max_new_tokens=256,
    get_llm_hf_kws=None,
):
    """Generate a reply from the chatbot model.

    Args:
        system_message (str): System message for the conversation.
        chat_history (list): Previous chat messages as dicts with "role" and
            "content" keys. The list is modified in place: the user turn and
            the assistant turn are appended to it.
        user_text (str): The user's input text.
        eos_token_id (sequence, optional): End-of-sequence token list.
            Accepted for interface compatibility; this function does not
            read it.
        max_new_tokens (int, optional): Maximum number of new tokens to
            generate.
        get_llm_hf_kws (dict, optional): Extra keyword arguments forwarded to
            get_llm_hf_inference. Entries here take precedence over
            max_new_tokens and the default temperature of 0.1.

    Returns:
        tuple: (response, chat_history), the generated text and the same
        list object that was passed in, now extended by two entries.
    """
    # Build the model. Defaults come first so caller-supplied keyword
    # arguments can override them without a duplicate-argument error.
    llm_kwargs = {"max_new_tokens": max_new_tokens, "temperature": 0.1}
    if get_llm_hf_kws:
        llm_kwargs.update(get_llm_hf_kws)
    hf = get_llm_hf_inference(**llm_kwargs)

    # Prompt template: system message, conversation so far, then the new
    # user text wrapped in [INST] ... [/INST] markers.
    prompt = PromptTemplate.from_template(
        (
            "[INST] {system_message}"
            "\nν˜„μž¬ λŒ€ν™”:\n{chat_history}\n\n"
            "\nμ‚¬μš©μž: {user_text}.\n [/INST]"
            "\nAI:"
        )
    )

    # Chain the prompt, the model and the string output parser.
    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key="content")

    # Generate the response.
    response = chat.invoke(
        input=dict(
            system_message=system_message,
            user_text=user_text,
            chat_history=chat_history,
        )
    )

    # Keep only the text after the last "AI:" marker, if one is present.
    response = response.split("AI:")[-1]

    # Record both turns in the history.
    chat_history.append({"role": "user", "content": user_text})
    chat_history.append({"role": "assistant", "content": response})
    return response, chat_history
# Bordered container holding the message area and the chat input.
# NOTE(review): SOURCE carried no indentation; the nesting below is
# reconstructed from the statement order — confirm.
chat_interface = st.container(border=True)
with chat_interface:
    # The message area is created ahead of the input widget.
    output_container = st.container()
    st.session_state.user_text = st.chat_input(
        placeholder="여기에 ν…μŠ€νŠΈλ₯Ό μž…λ ₯ν•˜μ„Έμš”."
    )

# Render the conversation inside the message area.
with output_container:
    # Replay every stored message except system ones, each with the avatar
    # configured for its role.
    for message in st.session_state.chat_history:
        role = message["role"]
        if role != "system":
            with st.chat_message(role, avatar=st.session_state["avatars"][role]):
                st.markdown(message["content"])

    # Handle newly submitted user text, if any.
    pending_text = st.session_state.user_text
    if pending_text:
        # Show the user's new message straight away.
        with st.chat_message("user", avatar=st.session_state.avatars["user"]):
            st.markdown(pending_text)

        # Show a spinner in the assistant bubble while the reply is generated.
        with st.chat_message("assistant", avatar=st.session_state.avatars["assistant"]):
            with st.spinner("생각 쀑..."):
                # Call the model with the system message, user text and
                # history; store the returned history back in session state.
                response, st.session_state.chat_history = get_response(
                    system_message=st.session_state.system_message,
                    user_text=pending_text,
                    chat_history=st.session_state.chat_history,
                    max_new_tokens=st.session_state.max_response_length,
                )
                st.markdown(response)