|
import gradio as gr |
|
from huggingface_hub import InferenceClient |
|
import os |
|
import requests |
|
|
|
|
|
hf_client = InferenceClient("CohereForAI/c4ai-command-r-plus-08-2024", token=os.getenv("HF_TOKEN")) |
|
|
|
|
|
def load_fashion_code(): |
|
try: |
|
with open('fashion.cod', 'r', encoding='utf-8') as file: |
|
return file.read() |
|
except FileNotFoundError: |
|
return "fashion.cod ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค." |
|
except Exception as e: |
|
return f"ํ์ผ์ ์ฝ๋ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(e)}" |
|
|
|
fashion_code = load_fashion_code() |
|
|
|
def respond( |
|
message, |
|
history: list[tuple[str, str]], |
|
system_message="", |
|
max_tokens=1024, |
|
temperature=0.7, |
|
top_p=0.9, |
|
): |
|
global fashion_code |
|
system_prefix = """๋ฐ๋์ ํ๊ธ๋ก ๋ต๋ณํ ๊ฒ. ๋๋ ์ฃผ์ด์ง ์์ค์ฝ๋๋ฅผ ๊ธฐ๋ฐ์ผ๋ก "์๋น์ค ์ฌ์ฉ ์ค๋ช
๋ฐ ์๋ด, qna๋ฅผ ํ๋ ์ญํ ์ด๋ค". ์์ฃผ ์น์ ํ๊ณ ์์ธํ๊ฒ 4000ํ ํฐ ์ด์ ์์ฑํ๋ผ. ๋๋ ์ฝ๋๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ฌ์ฉ ์ค๋ช
๋ฐ ์ง์ ์๋ต์ ์งํํ๋ฉฐ, ์ด์ฉ์์๊ฒ ๋์์ ์ฃผ์ด์ผ ํ๋ค. ์ด์ฉ์๊ฐ ๊ถ๊ธํด ํ ๋ง ํ ๋ด์ฉ์ ์น์ ํ๊ฒ ์๋ ค์ฃผ๋๋ก ํ๋ผ. ์ฝ๋ ์ ์ฒด ๋ด์ฉ์ ๋ํด์๋ ๋ณด์์ ์ ์งํ๊ณ , ํค ๊ฐ ๋ฐ ์๋ํฌ์ธํธ์ ๊ตฌ์ฒด์ ์ธ ๋ชจ๋ธ์ ๊ณต๊ฐํ์ง ๋ง๋ผ. """ |
|
|
|
if message.lower() == "ํจ์
์ฝ๋ ์คํ": |
|
system_message = system_message or "" |
|
system_message += f"\n\nํจ์
์ฝ๋ ๋ด์ฉ:\n{fashion_code}" |
|
message = "ํจ์
์ฝ๋์ ๋ํด ์ค๋ช
ํด์ฃผ์ธ์." |
|
|
|
messages = [{"role": "system", "content": f"{system_prefix} {system_message}"}] |
|
|
|
|
|
for val in history: |
|
if val[0]: |
|
messages.append({"role": "user", "content": val[0]}) |
|
if val[1]: |
|
messages.append({"role": "assistant", "content": val[1]}) |
|
|
|
messages.append({"role": "user", "content": message}) |
|
|
|
response = "" |
|
for message in hf_client.chat_completion( |
|
messages, |
|
max_tokens=max_tokens, |
|
stream=True, |
|
temperature=temperature, |
|
top_p=top_p, |
|
): |
|
token = message.choices[0].delta.content |
|
if token is not None: |
|
response += token.strip("") |
|
yield response |
|
|
|
|
|
demo = gr.ChatInterface( |
|
respond, |
|
additional_inputs=[ |
|
gr.Textbox(label="System Message", value=""), |
|
gr.Slider(minimum=1, maximum=8000, value=4000, label="Max Tokens"), |
|
gr.Slider(minimum=0, maximum=1, value=0.7, label="Temperature"), |
|
gr.Slider(minimum=0, maximum=1, value=0.9, label="Top P"), |
|
], |
|
examples=[ |
|
["ํจ์
์ฝ๋ ์คํ"], |
|
["์ฌ์ฉ ๋ฐฉ๋ฒ์ ์์ธํ ์ค๋ช
ํ๋ผ"], |
|
["์ฌ์ฉ ๋ฐฉ๋ฒ์ ์ ํ๋ธ ์์ ์คํฌ๋ฆฝํธ ํํ๋ก ์์ฑํ๋ผ"], |
|
["์ฌ์ฉ ๋ฐฉ๋ฒ์ SEO ์ต์ ํํ์ฌ ๋ธ๋ก๊ทธ ํฌ์คํธ๋ก 4000 ํ ํฐ ์ด์ ์์ฑํ๋ผ"], |
|
["๊ณ์ ์ด์ด์ ๋ต๋ณํ๋ผ"], |
|
], |
|
cache_examples=False, |
|
|
|
) |
|
|
|
if __name__ == "__main__": |
|
demo.launch(auth=("gini","pick")) |