# NOTE: the "Spaces: Sleeping" lines that preceded this file are Hugging Face
# Spaces page-status residue from a web scrape, not part of the program.
import os

import streamlit as st
from dotenv import load_dotenv
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint

# Hugging Face model repo ID used for inference.
model_id = "mistralai/Mistral-7B-Instruct-v0.3"

# Load environment variables (e.g. HF_TOKEN) from a .env file.
load_dotenv()
def get_llm_hf_inference(model_id=model_id, max_new_tokens=128, temperature=0.1):
    """Build a language-model client backed by the Hugging Face Inference API.

    Parameters:
        model_id (str): ID of the model repository on the Hugging Face Hub.
        max_new_tokens (int): Maximum number of new tokens to generate.
        temperature (float): Sampling temperature used during generation.

    Returns:
        HuggingFaceEndpoint: The configured inference endpoint client.
    """
    # The API token is read from the HF_TOKEN environment variable
    # (populated by load_dotenv() at module import time).
    return HuggingFaceEndpoint(
        repo_id=model_id,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        token=os.getenv("HF_TOKEN"),
    )
# --- Streamlit page configuration ---
st.set_page_config(page_title="HuggingFace ChatBot", page_icon="π€")
st.title("κ°μΈ HuggingFace μ±λ΄")

# Short app description shown under the title.
st.markdown(
    f"*μ΄κ²μ HuggingFace transformers λΌμ΄λΈλ¬λ¦¬λ₯Ό μ¬μ©νμ¬ ν μ€νΈ μ λ ₯μ λν μλ΅μ μμ±νλ κ°λ¨ν μ±λ΄μ λλ€. {model_id} λͺ¨λΈμ μ¬μ©ν©λλ€.*"
)
# --- Session-state defaults (applied only when the key is missing, so they
# --- survive Streamlit reruns without clobbering user changes) ---
_SESSION_DEFAULTS = {
    "avatars": {"user": None, "assistant": None},  # per-role chat avatars
    "user_text": None,                             # latest user input
    "max_response_length": 256,                    # response token cap
    "system_message": "μΈκ° μ¬μ©μμ λννλ μΉμ ν AI",
    "starter_message": "μλ νμΈμ! μ€λ 무μμ λμλ릴κΉμ?",
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# --- Sidebar: system / model / avatar settings ---
with st.sidebar:
    st.header("μμ€ν μ€μ ")

    # AI settings. Widgets are seeded from session state so the values shown
    # match the initialized defaults (previously the hard-coded widget values
    # disagreed with the session-state defaults).
    st.session_state.system_message = st.text_area(
        "μμ€ν λ©μμ§", value=st.session_state.system_message
    )
    st.session_state.starter_message = st.text_area(
        "첫 λ²μ§Έ AI λ©μμ§", value=st.session_state.starter_message
    )

    # Model settings (was hard-coded to 128 while the session default was 256).
    st.session_state.max_response_length = st.number_input(
        "μ΅λ μλ΅ κΈΈμ΄", value=st.session_state.max_response_length
    )

    # Avatar selection.
    st.markdown("*μλ°ν μ ν:*")
    col1, col2 = st.columns(2)
    with col1:
        st.session_state.avatars["assistant"] = st.selectbox(
            "AI μλ°ν", options=["π€", "π¬", "π€"], index=0
        )
    with col2:
        st.session_state.avatars["user"] = st.selectbox(
            "μ¬μ©μ μλ°ν", options=["π€", "π±ββοΈ", "π¨πΎ", "π©", "π§πΎ"], index=0
        )

    # Button that clears the conversation.
    reset_history = st.button("μ±ν κΈ°λ‘ μ΄κΈ°ν")

# Initialize the chat history, or reset it when the button was pressed.
if "chat_history" not in st.session_state or reset_history:
    st.session_state.chat_history = [
        {"role": "assistant", "content": st.session_state.starter_message}
    ]
def get_response(
    system_message,
    chat_history,
    user_text,
    eos_token_id=None,
    max_new_tokens=256,
    get_llm_hf_kws=None,
):
    """Generate a chatbot response from the language model.

    Parameters:
        system_message (str): System message for the conversation.
        chat_history (list): Previous chat messages; mutated in place and
            also returned.
        user_text (str): The user's input text.
        eos_token_id (list, optional): End-of-sentence token IDs.
            Defaults to ["User"]. NOTE(review): currently unused here.
        max_new_tokens (int, optional): Maximum number of new tokens to generate.
        get_llm_hf_kws (dict, optional): Extra keyword arguments forwarded to
            get_llm_hf_inference (e.g. model_id, temperature); entries here
            override the defaults below.

    Returns:
        tuple: (generated response, updated chat history).
    """
    # Avoid the mutable-default-argument pitfall: bind fresh values per call.
    if eos_token_id is None:
        eos_token_id = ["User"]
    if get_llm_hf_kws is None:
        get_llm_hf_kws = {}

    # Set up the model; get_llm_hf_kws was previously accepted but ignored —
    # it is now forwarded (caller-supplied entries win over the defaults).
    hf = get_llm_hf_inference(
        **{"max_new_tokens": max_new_tokens, "temperature": 0.1, **get_llm_hf_kws}
    )

    # Build the prompt template.
    prompt = PromptTemplate.from_template(
        (
            "[INST] {system_message}"
            "\nνμ¬ λν:\n{chat_history}\n\n"
            "\nμ¬μ©μ: {user_text}.\n [/INST]"
            "\nAI:"
        )
    )

    # Chain: prompt -> model -> plain-string output parser.
    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key="content")

    # Generate the response.
    response = chat.invoke(
        input=dict(
            system_message=system_message,
            user_text=user_text,
            chat_history=chat_history,
        )
    )
    # Keep only the text after the last "AI:" prefix.
    response = response.split("AI:")[-1]

    # Record both turns in the (shared) chat history.
    chat_history.append({"role": "user", "content": user_text})
    chat_history.append({"role": "assistant", "content": response})
    return response, chat_history
# --- Chat interface ---
chat_interface = st.container(border=True)
with chat_interface:
    output_container = st.container()
    st.session_state.user_text = st.chat_input(
        placeholder="μ¬κΈ°μ ν μ€νΈλ₯Ό μ λ ₯νμΈμ."
    )

# Render the conversation.
with output_container:
    # Replay every stored message, with the avatar chosen for its role.
    for message in st.session_state.chat_history:
        role = message["role"]
        # System messages are never rendered.
        if role == "system":
            continue
        with st.chat_message(role, avatar=st.session_state["avatars"][role]):
            st.markdown(message["content"])

    # A new user message was submitted this run:
    if st.session_state.user_text:
        # Echo the user's message immediately.
        with st.chat_message("user", avatar=st.session_state.avatars["user"]):
            st.markdown(st.session_state.user_text)

        # Show a spinner while waiting for the model's reply.
        with st.chat_message(
            "assistant", avatar=st.session_state.avatars["assistant"]
        ):
            with st.spinner("μκ° μ€..."):
                # Call the inference API with the system prompt, user text,
                # and accumulated history.
                response, st.session_state.chat_history = get_response(
                    system_message=st.session_state.system_message,
                    user_text=st.session_state.user_text,
                    chat_history=st.session_state.chat_history,
                    max_new_tokens=st.session_state.max_response_length,
                )
                st.markdown(response)