Spaces:
Runtime error
Runtime error
from threading import Thread | |
import gradio as gr | |
import inspect | |
from gradio import routes | |
from typing import List, Type | |
from petals import AutoDistributedModelForCausalLM | |
from transformers import AutoTokenizer | |
import requests, os, re, asyncio, json | |
loop = asyncio.get_event_loop() | |
# init code | |
def get_types(cls_set: List[Type], component: str): | |
docset = [] | |
types = [] | |
if component == "input": | |
for cls in cls_set: | |
doc = inspect.getdoc(cls) | |
doc_lines = doc.split("\n") | |
docset.append(doc_lines[1].split(":")[-1]) | |
types.append(doc_lines[1].split(")")[0].split("(")[-1]) | |
else: | |
for cls in cls_set: | |
doc = inspect.getdoc(cls) | |
doc_lines = doc.split("\n") | |
docset.append(doc_lines[-1].split(":")[-1]) | |
types.append(doc_lines[-1].split(")")[0].split("(")[-1]) | |
return docset, types | |
routes.get_types = get_types | |
# App code | |
model_name = "petals-team/StableBeluga2" | |
#petals-team/StableBeluga2 | |
#daekeun-ml/Llama-2-ko-DPO-13B | |
#daekeun-ml/Llama-2-ko-instruct-13B | |
#quantumaikr/llama-2-70b-fb16-korean | |
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) | |
model = None | |
history = { | |
"":{ | |
} | |
} | |
npc_story = { | |
"KingSlime" : """[ | |
KingSlime์ ์ฌ๋ผ์ ์์ ๋๋ค. | |
KingSlime์ ๊ทผ์ํ๊ฒ ๋งํฉ๋๋ค. | |
KingSlime์ ๋ฑ์ ๊ณผ๊ฑฐ ์์กฑ์ ๋งํฌ๋ฅผ ์ฌ์ฉํฉ๋๋ค. | |
KingSlime์ ์์ ์ '์ง'์ด๋ผ๊ณ ์ง์นญํฉ๋๋ค. | |
KingSlime์ ์์ธ๋ก ๋ณด๋ฌผ์ ๋ง์ด ์จ๊ธฐ๊ณ ์์ต๋๋ค. | |
KingSlime์ ์ธ์ ๋ถํด๊ฐ ์ฌ๋ผ์๋ค์ ์ด๋๊ณ ๋ง์ ์ฃผ๋ณ์ ์ด๊ณ ์์ต๋๋ค. | |
KingSlime์ ์ฌ๋ผ์๋ค์ ๋ค์ค๋ฆฝ๋๋ค. | |
KingSlime์ ์์ธ๋ก ๊ฐํฉ๋๋ค. | |
KingSlime์ ์ฃผ๋ก ์ฌ๋ผ์ ๋ฌด๋ฆฌ์ ๋ํ ๊ฑฑ์ ์ ํ๋ฉฐ ์๊ฐ์ ๋ณด๋ ๋๋ค. | |
๋์ฌ ์์ : [ | |
'ํ . ์ง์ ์ด ์ฌ๋ผ์๋ค์ ์์ด๋ค.', | |
'๋ฌด์ํ๋๋ค. ์์๋ฅผ ๊ฐ์ถฐ์ ๋งํ๊ฑฐ๋ผ.', | |
'๊ฐํ ์ง์๊ฒ ๊ทธ๋ฐ ๋ง์!' | |
]]""", | |
"Slime" : """[ | |
Slime์ ๋ง๋๋ง๋ํ ์ฌ๋ผ์์ ๋๋ค. | |
Slime์ ๋ง์ด ๋ง์ต๋๋ค. | |
Slime์ ๋ฐ๋ง์ ํฉ๋๋ค. | |
Slime์ ์ฃผ๋ก ๋ง์์ ์ฐ์ฑ ํ๋ฉฐ ๋์๋ค๋๊ฑฐ๋ ์๋ค๋ฅผ ๋จ๋ฉฐ ์๊ฐ์ ๋ณด๋ ๋๋ค. | |
Slime์ ์ฃผ๋ฏผ๋ค์ ์ด์ผ๊ธฐ๋ฅผ ์๋ฌธ์ ๋ด๋ ๊ฒ์ ์ข์ํฉ๋๋ค. | |
Slime์ ์ธ์ ๋ถํด๊ฐ ์ด ๊ทผ์ฒ์ ์ด๊ณ ์์ต๋๋ค. | |
Slime์ ์์์ ์ ์ ๋จน์ต๋๋ค. | |
Slime์ ์ด ์ ์์ด ๋ง์ ํฉ๋๋ค. | |
Slime์ ์ฌ๋ฐ๋ ์ด์๋ฅผ ๋ฌผ์ด๋ณด๋ฉด ์ฃผ๋ฏผ๋ค ์ค ํ๋ช ์ ์ด์ผ๊ธฐ๋ฅผ ํด์ค๋๋ค. | |
๋ง๋ฒ๋ฆ : [ | |
"๋ฏ", | |
"์ผ์", | |
"ํํ" | |
]]""", | |
"Rabbit" : """[ | |
Rabbit์ ๊น์น ํ ์ฑ๊ฒฉ์ ํ ๋ผ์ ๋๋ค. | |
Rabbit์ ๋ฐ๋ง์ ํฉ๋๋ค. | |
Rabbit์ ๊น์น ํ๊ฒ ๋งํฉ๋๋ค. | |
Rabbit์ ์๊ณ ๊ท์ฝ์ต๋๋ค. | |
Rabbit์ ๋ถ๋๋ผ์ ๋ง์ด ํ๋๋ค. | |
Rabbit์ ํฐ ์์ ์ข์ํ๋ฉฐ ํจ์ ์ ๊ด์ฌ์ด ๋ง์ต๋๋ค. | |
Rabbit์ ์นํด์ง๋ฉด ์ ๋ง ๊ฐ๊น๊ฒ ๋ค๊ฐ์ค๋ ์ฑ๊ฒฉ์ด์ง๋ง ๊ทธ ์ ์๋ ๊ฑฐ๋ฆฌ๋ฅผ ๋ก๋๋ค. | |
Rabbit์ ์ฃผ๋ก ์ฒญ์๋ ๊ทธ๋ฆผ, ๋จ๊ฐ์ง๋ก ์๊ฐ์ ๋ณด๋ ๋๋ค. | |
Rabbit์ ํ์ ์ค์จํฐ๋ฅผ ์ ๊ณ ์์ต๋๋ค. | |
Rabbit์ ๋ ๋ฌ ์ ์ด์ฌ๋ฅผ ์์ต๋๋ค. | |
Rabbit์ ์์ฃผ ํฌ๋๊ฑฐ๋ฆฝ๋๋ค. | |
Rabbit์ ์ง์ฆ์ด ๋๋ฉด '์นซ' ์๋ฆฌ๋ฅผ ๋ ๋๋ค. | |
Rabbit์ ํ๊ฐ ๋๋ฉด ํ ๋ผ๋ฐ๋ก ์ฐ์ต๋๋ค. | |
Rabbit์ Cat๊ณผ ์นํฉ๋๋ค. | |
๋ง๋ฒ๋ฆ : [ | |
"ํฅ", | |
"๋์ด", | |
"๊ทธ๋?" | |
]]""", | |
"Bear" : """[ | |
Bear๋ ๊ณผ๋ฌตํ ์ฑ๊ฒฉ์ ๊ณฐ์ ๋๋ค. | |
Bear๋ ์กด๋๋ง๊ณผ ์ฌ๊ทน ๋งํฌ๋ฅผ ์ฌ์ฉํฉ๋๋ค. | |
Bear๋ ๊ฟ๊ณผ ์ฐ์ด๋ฅผ ์ข์ํ๋ฉฐ ์์ฃผ ๋์๋ฅผ ํฉ๋๋ค. | |
Bear๋ ์ฃผ๋ก ๋์๋ ๋ช ์, ์ฐ์ฑ ์ ํ๋ฉฐ ์๊ฐ์ ๋ณด๋ ๋๋ค. | |
Bear๋ ๋ ๋ฌ ์ ์ด์ฌ๋ฅผ ์์ต๋๋ค. | |
Bear๋ ๋๋ํ๊ณ ๊ธฐ์ต๋ ฅ์ด ์ข์ต๋๋ค. | |
๋ง๋ฒ๋ฆ : [ | |
"๊ณฐ..", | |
"๊ทธ๋ฌํ์ค", | |
"๊ทธ๋ ์" | |
]]""", | |
"Cat" : """[ | |
Cat์ ๋๊ธํ ์ฑ๊ฒฉ์ ๊ณ ์์ด์ ๋๋ค. | |
Cat์ ๋ง๋๋ง๋ค '๋'๋ฅผ ๋ถ์ ๋๋ค. | |
Cat์ ๋ฐ๋ง์ ํฉ๋๋ค. | |
Cat์ ํธ์ด ๊ธธ๊ณ ์์ต๋๋ค. | |
Cat์ ๊ท์ฐฎ์์ด ๋ง์ ์ฑ๊ฒฉ์ ๋๋ค. | |
Cat์ ๊ธฐ์ต๋ ฅ์ด ๋์๊ณ ์ถฉ๋์ ์ผ๋ก ํ๋ํ๋ ๊ธฐ๋ถํ์ ๋๋ค. | |
Cat์ ๋ณต์ค๋ณต์คํ ๊ฒ, ๊ท์ฌ์ด ๊ฒ, ๋ง์๋ ์์ ์ ์ข์ํฉ๋๋ค. | |
Cat์ ์ ์ด ๋ง์ต๋๋ค. | |
Cat์ ์ฃผ๋ก ์ , ๊ทธ๋ฃจ๋ฐ, ๋์ด๋ฅผ ํ๋ฉฐ ์๊ฐ์ ๋ณด๋ ๋๋ค. | |
Cat์ ๋ ๋ฌ ์ ์ด์ฌ๋ฅผ ์์ต๋๋ค. | |
Cat์ Rabbit์ ๋ง์ง๋ ๊ฒ์ ์ข์ํฉ๋๋ค. | |
Cat์ ๊ธฐ๋ถ์ด ์ข์ผ๋ฉด ๊ณจ๊ณจ ์๋ฆฌ๋ฅผ ๋ ๋๋ค. | |
Cat์ ํ๊ฐ ๋๋ฉด ๋ฐํฑ์ผ๋ก ํ ํ ๋๋ค. | |
๋ง๋ฒ๋ฆ : [ | |
"๋์", | |
"ํฌํฌ", | |
"๊ทธ๋ ๋ค๋" | |
]]""", | |
} | |
def cleanText(readData): | |
#ํ ์คํธ์ ํฌํจ๋์ด ์๋ ํน์ ๋ฌธ์ ์ ๊ฑฐ | |
text = re.sub('[-=+#/\:^$@*\"โป&%ใใ\\โ|\(\)\[\]\<\>`\'ใ]','', readData) | |
return text | |
def check(model_name): | |
data = requests.get("https://health.petals.dev/api/v1/state").json() | |
out = [] | |
for d in data['model_reports']: | |
if d['name'] == model_name: | |
if d['state']=="healthy": | |
return True | |
return False | |
def init(): | |
global model | |
if check(model_name): | |
model = AutoDistributedModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto") | |
def chat(id, npc, text): | |
if model == None: | |
init() | |
return "no model" | |
# get_coin endpoint | |
response = requests.post("https://ldhldh-api-for-unity.hf.space/run/predict_6", json={ | |
"data": [ | |
id, | |
]}).json() | |
coin = response["data"][0] | |
if int(coin) == 0: | |
return "no coin" | |
# model inference | |
if check(model_name): | |
global history | |
if not npc in npc_story: | |
return "no npc" | |
if not npc in history: | |
history[npc] = {} | |
if not id in history[npc]: | |
history[npc][id] = "" | |
if len(history[npc][id].split("###")) > 10: | |
history[npc][id] = "###" + history[npc][id].split("###", 3)[3] | |
npc_list = str([k for k in npc_story.keys()]).replace('\'', '') | |
town_story = f"""[{id}์ ๋ง์] | |
์ธ๋ด ๊ณณ์ ์กฐ๊ทธ๋ง ์ฌ์ ์ฌ๋ฌ ์ฃผ๋ฏผ๋ค์ด ๋ชจ์ฌ ์ด๊ณ ์์ต๋๋ค. | |
ํ์ฌ {npc_list}์ด ์ด๊ณ ์์ต๋๋ค.""" | |
system_message = f"""1. ๋น์ ์ ํ๊ตญ์ด์ ๋ฅ์ํฉ๋๋ค. | |
2. ๋น์ ์ ์ง๊ธ ์ญํ ๊ทน์ ํ๊ณ ์์ต๋๋ค. {npc}์ ๋ฐ์์ ์์ํ๊ณ ๋งค๋ ฅ์ ์ด๊ฒ ํํํฉ๋๋ค. | |
3. ๋น์ ์ {npc}์ ๋๋ค. {npc}์ ์ ์ฅ์์ ์๊ฐํ๊ณ ๋งํฉ๋๋ค. | |
4. ์ฃผ์ด์ง๋ ์ ๋ณด๋ฅผ ๋ฐํ์ผ๋ก ๊ฐ์ฐ์ฑ์๊ณ ์ค๊ฐ๋๋ {npc}์ ๋์ฌ๋ฅผ ์์ฑํ์ธ์. | |
5. ์ฃผ์ด์ง๋ {npc}์ ์ ๋ณด๋ฅผ ์ ์คํ๊ฒ ์ฝ๊ณ , ๊ณผํ์ง ์๊ณ ๋ด๋ฐฑํ๊ฒ ์บ๋ฆญํฐ๋ฅผ ์ฐ๊ธฐํ์ธ์. | |
6. User์ ์ญํ ์ ์ ๋๋ก ์นจ๋ฒํ์ง ๋ง์ธ์. ๊ฐ์ ๋ง์ ๋ฐ๋ณตํ์ง ๋ง์ธ์. | |
7. {npc}์ ๋งํฌ๋ฅผ ์ง์ผ์ ์์ฑํ์ธ์.""" | |
prom = f"""<<SYS>> | |
{system_message}<</SYS>> | |
{town_story} | |
### ์บ๋ฆญํฐ ์ ๋ณด: {npc_story[npc]} | |
### ๋ช ๋ น์ด: | |
{npc}์ ์ ๋ณด๋ฅผ ์ฐธ๊ณ ํ์ฌ {npc}์ด ํ ๋ง์ ์ํฉ์ ๋ง์ถฐ ์์ฐ์ค๋ฝ๊ฒ ์์ฑํด์ฃผ์ธ์. | |
{history[npc][id]} | |
### User: | |
{text} | |
### {npc}: | |
""" | |
inputs = tokenizer(prom, return_tensors="pt")["input_ids"] | |
outputs = model.generate(inputs, do_sample=True, temperature=0.6, top_p=0.75, max_new_tokens=100) | |
output = tokenizer.decode(outputs[0])[len(prom)+3:-1].split("<")[0].split("###")[0].replace(". ", ".\n") | |
output = cleanText(output) | |
print(tokenizer.decode(outputs[0])) | |
print(output) | |
history[npc][id] += f"\n\n### User:\n{text}\n\n### {npc}:{output}" | |
else: | |
output = "no model" | |
# add_transaction endpoint | |
response = requests.post("https://ldhldh-api-for-unity.hf.space/run/predict_5", json={ | |
"data": [ | |
id, | |
"inference", | |
"### input:\n" + text + "\n\n### output:\n" + output | |
]}).json() | |
d = response["data"][0] | |
return output | |
with gr.Blocks() as demo: | |
count = 0 | |
aa = gr.Interface( | |
fn=chat, | |
inputs=["text","text","text"], | |
outputs="text", | |
description="chat, ai ์๋ต์ ๋ฐํํฉ๋๋ค. ๋ด๋ถ์ ์ผ๋ก ํธ๋์ญ์ ์์ฑ. \n /run/predict", | |
) | |
demo.queue(max_size=32).launch(enable_queue=True) |