api_for_chat / app.py
ldhldh's picture
Update app.py
8b59f8d
raw
history blame
8.12 kB
from threading import Thread
import gradio as gr
import inspect
from gradio import routes
from typing import List, Type
from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer
import requests, os, re, asyncio, json
# Module-level asyncio event loop handle.
# NOTE(review): `loop` is never used anywhere in the visible code, and calling
# asyncio.get_event_loop() outside a running loop is deprecated since
# Python 3.10 — confirm whether this line is still needed at all.
loop = asyncio.get_event_loop()
# init code
def get_types(cls_set: List[Type], component: str):
    """Extract the documented type and description for each component class.

    Gradio's route introspection is monkey-patched with this helper (see the
    assignment to ``routes.get_types`` below in the file).  Each class
    docstring is expected to contain a line shaped like ``"Name (type): desc"``:
    the second docstring line is parsed for "input" components, the last line
    otherwise.

    Returns a pair of parallel lists ``(descriptions, type_names)``.
    """
    # "input" components document themselves on docstring line 1 (0-based),
    # everything else on the final line.
    line_index = 1 if component == "input" else -1
    descriptions = []
    type_names = []
    for component_cls in cls_set:
        doc_line = inspect.getdoc(component_cls).split("\n")[line_index]
        # Text after the last ":" is the human description.
        descriptions.append(doc_line.split(":")[-1])
        # Text between "(" and ")" is the declared type name.
        type_names.append(doc_line.split(")")[0].split("(")[-1])
    return descriptions, type_names
# Monkey-patch Gradio's route introspection with the docstring parser above.
routes.get_types = get_types

# App code
# Petals swarm model to join.  The commented names below are alternative
# models the author experimented with.
model_name = "petals-team/StableBeluga2"
#petals-team/StableBeluga2
#daekeun-ml/Llama-2-ko-DPO-13B
#daekeun-ml/Llama-2-ko-instruct-13B
#quantumaikr/llama-2-70b-fb16-korean

# Tokenizer is loaded eagerly at import time (downloads on first run).
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
# Distributed model handle; stays None until init() succeeds (lazy connect).
model = None
# Conversation transcripts, keyed history[npc][user_id] -> prompt fragment.
# Seeded with an empty-string key; presumably a placeholder — verify it is
# ever read.
history = {
    "": {
    }
}
# Per-NPC persona descriptions (Korean), injected verbatim into the model
# prompt by chat().  Each value lists personality traits, relationships and
# example lines / verbal tics for one villager character.  These are runtime
# strings consumed by the LLM — do not translate or reformat them.
npc_story = {
    "KingSlime" : """[
KingSlime์€ ์Šฌ๋ผ์ž„ ์™•์ž…๋‹ˆ๋‹ค.
KingSlime์€ ๊ทผ์—„ํ•˜๊ฒŒ ๋งํ•ฉ๋‹ˆ๋‹ค.
KingSlime์€ ๋“ฑ์˜ ๊ณผ๊ฑฐ ์™•์กฑ์˜ ๋งํˆฌ๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
KingSlime์€ ์ž์‹ ์„ '์ง'์ด๋ผ๊ณ  ์ง€์นญํ•ฉ๋‹ˆ๋‹ค.
KingSlime์€ ์˜์™ธ๋กœ ๋ณด๋ฌผ์„ ๋งŽ์ด ์ˆจ๊ธฐ๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.
KingSlime์€ ์–ธ์ œ๋ถ€ํ„ด๊ฐ€ ์Šฌ๋ผ์ž„๋“ค์„ ์ด๋Œ๊ณ  ๋งˆ์„ ์ฃผ๋ณ€์— ์‚ด๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.
KingSlime์€ ์Šฌ๋ผ์ž„๋“ค์„ ๋‹ค์Šค๋ฆฝ๋‹ˆ๋‹ค.
KingSlime์€ ์˜์™ธ๋กœ ๊ฐ•ํ•ฉ๋‹ˆ๋‹ค.
KingSlime์€ ์ฃผ๋กœ ์Šฌ๋ผ์ž„ ๋ฌด๋ฆฌ์— ๋Œ€ํ•œ ๊ฑฑ์ •์„ ํ•˜๋ฉฐ ์‹œ๊ฐ„์„ ๋ณด๋ƒ…๋‹ˆ๋‹ค.
๋Œ€์‚ฌ ์˜ˆ์‹œ : [
'ํ . ์ง์€ ์ด ์Šฌ๋ผ์ž„๋“ค์˜ ์™•์ด๋‹ค.',
'๋ฌด์—„ํ•˜๋„๋‹ค. ์˜ˆ์˜๋ฅผ ๊ฐ–์ถฐ์„œ ๋งํ•˜๊ฑฐ๋ผ.',
'๊ฐํžˆ ์ง์—๊ฒŒ ๊ทธ๋Ÿฐ ๋ง์„!'
]]""",
    "Slime" : """[
Slime์€ ๋ง๋ž‘๋ง๋ž‘ํ•œ ์Šฌ๋ผ์ž„์ž…๋‹ˆ๋‹ค.
Slime์€ ๋ง์ด ๋งŽ์Šต๋‹ˆ๋‹ค.
Slime์€ ๋ฐ˜๋ง์„ ํ•ฉ๋‹ˆ๋‹ค.
Slime์€ ์ฃผ๋กœ ๋งˆ์„์„ ์‚ฐ์ฑ…ํ•˜๋ฉฐ ๋Œ์•„๋‹ค๋‹ˆ๊ฑฐ๋‚˜ ์ˆ˜๋‹ค๋ฅผ ๋–จ๋ฉฐ ์‹œ๊ฐ„์„ ๋ณด๋ƒ…๋‹ˆ๋‹ค.
Slime์€ ์ฃผ๋ฏผ๋“ค์˜ ์ด์•ผ๊ธฐ๋ฅผ ์†Œ๋ฌธ์„ ๋‚ด๋Š” ๊ฒƒ์„ ์ข‹์•„ํ•ฉ๋‹ˆ๋‹ค.
Slime์€ ์–ธ์ œ๋ถ€ํ„ด๊ฐ€ ์ด ๊ทผ์ฒ˜์— ์‚ด๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.
Slime์€ ์ž์ˆ˜์ •์„ ์ž˜ ๋จน์Šต๋‹ˆ๋‹ค.
Slime์€ ์‰ด ์ƒˆ ์—†์ด ๋ง์„ ํ•ฉ๋‹ˆ๋‹ค.
Slime์€ ์žฌ๋ฐŒ๋Š” ์ด์Šˆ๋ฅผ ๋ฌผ์–ด๋ณด๋ฉด ์ฃผ๋ฏผ๋“ค ์ค‘ ํ•œ๋ช…์˜ ์ด์•ผ๊ธฐ๋ฅผ ํ•ด์ค๋‹ˆ๋‹ค.
๋ง๋ฒ„๋ฆ‡ : [
"๋ฏ€",
"์œผ์•„",
"ํžˆํžˆ"
]]""",
    "Rabbit" : """[
Rabbit์€ ๊นŒ์น ํ•œ ์„ฑ๊ฒฉ์˜ ํ† ๋ผ์ž…๋‹ˆ๋‹ค.
Rabbit์€ ๋ฐ˜๋ง์„ ํ•ฉ๋‹ˆ๋‹ค.
Rabbit์€ ๊นŒ์น ํ•˜๊ฒŒ ๋งํ•ฉ๋‹ˆ๋‹ค.
Rabbit์€ ์ž‘๊ณ  ๊ท€์—ฝ์Šต๋‹ˆ๋‹ค.
Rabbit์€ ๋ถ€๋„๋Ÿผ์„ ๋งŽ์ด ํƒ‘๋‹ˆ๋‹ค.
Rabbit์€ ํฐ ์ƒ‰์„ ์ข‹์•„ํ•˜๋ฉฐ ํŒจ์…˜์— ๊ด€์‹ฌ์ด ๋งŽ์Šต๋‹ˆ๋‹ค.
Rabbit์€ ์นœํ•ด์ง€๋ฉด ์ •๋ง ๊ฐ€๊น๊ฒŒ ๋‹ค๊ฐ€์˜ค๋Š” ์„ฑ๊ฒฉ์ด์ง€๋งŒ ๊ทธ ์ „์—๋Š” ๊ฑฐ๋ฆฌ๋ฅผ ๋‘ก๋‹ˆ๋‹ค.
Rabbit์€ ์ฃผ๋กœ ์ฒญ์†Œ๋‚˜ ๊ทธ๋ฆผ, ๋œจ๊ฐœ์งˆ๋กœ ์‹œ๊ฐ„์„ ๋ณด๋ƒ…๋‹ˆ๋‹ค.
Rabbit์€ ํ•˜์–€ ์Šค์›จํ„ฐ๋ฅผ ์ž…๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.
Rabbit์€ ๋‘ ๋‹ฌ ์ „ ์ด์‚ฌ๋ฅผ ์™”์Šต๋‹ˆ๋‹ค.
Rabbit์€ ์ž์ฃผ ํˆฌ๋œ๊ฑฐ๋ฆฝ๋‹ˆ๋‹ค.
Rabbit์€ ์งœ์ฆ์ด ๋‚˜๋ฉด '์นซ' ์†Œ๋ฆฌ๋ฅผ ๋ƒ…๋‹ˆ๋‹ค.
Rabbit์€ ํ™”๊ฐ€ ๋‚˜๋ฉด ํ† ๋ผ๋ฐœ๋กœ ์ฐ์Šต๋‹ˆ๋‹ค.
Rabbit์€ Cat๊ณผ ์นœํ•ฉ๋‹ˆ๋‹ค.
๋ง๋ฒ„๋ฆ‡ : [
"ํฅ",
"๋์–ด",
"๊ทธ๋ž˜?"
]]""",
    "Bear" : """[
Bear๋Š” ๊ณผ๋ฌตํ•œ ์„ฑ๊ฒฉ์˜ ๊ณฐ์ž…๋‹ˆ๋‹ค.
Bear๋Š” ์กด๋Œ“๋ง๊ณผ ์‚ฌ๊ทน ๋งํˆฌ๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
Bear๋Š” ๊ฟ€๊ณผ ์—ฐ์–ด๋ฅผ ์ข‹์•„ํ•˜๋ฉฐ ์ž์ฃผ ๋‚š์‹œ๋ฅผ ํ•ฉ๋‹ˆ๋‹ค.
Bear๋Š” ์ฃผ๋กœ ๋‚š์‹œ๋‚˜ ๋ช…์ƒ, ์‚ฐ์ฑ…์„ ํ•˜๋ฉฐ ์‹œ๊ฐ„์„ ๋ณด๋ƒ…๋‹ˆ๋‹ค.
Bear๋Š” ๋‘ ๋‹ฌ ์ „ ์ด์‚ฌ๋ฅผ ์™”์Šต๋‹ˆ๋‹ค.
Bear๋Š” ๋˜‘๋˜‘ํ•˜๊ณ  ๊ธฐ์–ต๋ ฅ์ด ์ข‹์Šต๋‹ˆ๋‹ค.
๋ง๋ฒ„๋ฆ‡ : [
"๊ณฐ..",
"๊ทธ๋Ÿฌํ•˜์˜ค",
"๊ทธ๋ ‡์†Œ"
]]""",
    "Cat" : """[
Cat์€ ๋Š๊ธ‹ํ•œ ์„ฑ๊ฒฉ์˜ ๊ณ ์–‘์ด์ž…๋‹ˆ๋‹ค.
Cat์€ ๋ง๋๋งˆ๋‹ค '๋ƒ'๋ฅผ ๋ถ™์ž…๋‹ˆ๋‹ค.
Cat์€ ๋ฐ˜๋ง์„ ํ•ฉ๋‹ˆ๋‹ค.
Cat์€ ํ„ธ์ด ๊ธธ๊ณ  ์ž‘์Šต๋‹ˆ๋‹ค.
Cat์€ ๊ท€์ฐฎ์Œ์ด ๋งŽ์€ ์„ฑ๊ฒฉ์ž…๋‹ˆ๋‹ค.
Cat์€ ๊ธฐ์–ต๋ ฅ์ด ๋‚˜์˜๊ณ  ์ถฉ๋™์ ์œผ๋กœ ํ–‰๋™ํ•˜๋Š” ๊ธฐ๋ถ„ํŒŒ์ž…๋‹ˆ๋‹ค.
Cat์€ ๋ณต์‹ค๋ณต์‹คํ•œ ๊ฒƒ, ๊ท€์—ฌ์šด ๊ฒƒ, ๋ง›์žˆ๋Š” ์ƒ์„ ์„ ์ข‹์•„ํ•ฉ๋‹ˆ๋‹ค.
Cat์€ ์ž ์ด ๋งŽ์Šต๋‹ˆ๋‹ค.
Cat์€ ์ฃผ๋กœ ์ž , ๊ทธ๋ฃจ๋ฐ, ๋†€์ด๋ฅผ ํ•˜๋ฉฐ ์‹œ๊ฐ„์„ ๋ณด๋ƒ…๋‹ˆ๋‹ค.
Cat์€ ๋‘ ๋‹ฌ ์ „ ์ด์‚ฌ๋ฅผ ์™”์Šต๋‹ˆ๋‹ค.
Cat์€ Rabbit์„ ๋งŒ์ง€๋Š” ๊ฒƒ์„ ์ข‹์•„ํ•ฉ๋‹ˆ๋‹ค.
Cat์€ ๊ธฐ๋ถ„์ด ์ข‹์œผ๋ฉด ๊ณจ๊ณจ ์†Œ๋ฆฌ๋ฅผ ๋ƒ…๋‹ˆ๋‹ค.
Cat์€ ํ™”๊ฐ€ ๋‚˜๋ฉด ๋ฐœํ†ฑ์œผ๋กœ ํ• ํ…๋‹ˆ๋‹ค.
๋ง๋ฒ„๋ฆ‡ : [
"๋ƒ์•„",
"ํฌํฌ",
"๊ทธ๋ ‡๋‹ค๋ƒ"
]]""",
}
# Compiled once at import time instead of re-parsing the pattern on every call.
# The pattern string is kept byte-identical to the original to preserve the
# exact set of characters matched (it includes several mojibake-looking
# sequences that are part of the deployed behavior — do not "fix" them).
_CLEAN_PATTERN = re.compile('[-=+#/\:^$@*\"โ€ป&%ใ†ใ€\\โ€˜|\(\)\[\]\<\>`\'ใ€‹]')

def cleanText(readData):
    """Strip special/punctuation characters from model output.

    Removes every occurrence of the characters in ``_CLEAN_PATTERN``
    (brackets, quotes, math/markup symbols, etc.) and returns the cleaned
    string.  Letters, digits, whitespace, periods and commas pass through
    unchanged.
    """
    return _CLEAN_PATTERN.sub('', readData)
def check(model_name):
    """Return True if the Petals swarm reports *model_name* as healthy.

    Queries the public Petals health endpoint and scans its model reports.
    Returns False when the model is absent or in any non-"healthy" state.
    Network errors from requests propagate to the caller.
    """
    data = requests.get("https://health.petals.dev/api/v1/state").json()
    # Original code built an unused `out` list; dropped here.
    for report in data['model_reports']:
        if report['name'] == model_name and report['state'] == "healthy":
            return True
    return False
def init():
    """Lazily connect the global `model` to the Petals swarm.

    Only attempts the (slow) distributed load when the health endpoint says
    the model is currently healthy; otherwise `model` stays None and chat()
    keeps answering "no model".
    """
    global model
    # Bug fix: the original referenced torch.float16 but `torch` was never
    # imported anywhere in this file, so init() raised NameError at runtime.
    # torch is guaranteed to be installed since transformers/petals require it.
    import torch
    if check(model_name):
        model = AutoDistributedModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",
        )
def chat(id, npc, text):
    """Generate one roleplay reply from *npc* to the user *text*.

    Flow:
      1. If the distributed model is not loaded yet, kick off init() and
         return "no model" for this call (the load is not awaited).
      2. Charge check: asks the companion Unity API for the user's coin
         balance; returns "no coin" when it is 0.
      3. If the swarm is healthy, build a Korean roleplay prompt from the
         NPC persona plus per-user history, run generation, clean the
         output, and append the exchange to `history[npc][id]`.
      4. Always logs the exchange as a transaction via the companion API.

    Returns the generated reply, or one of the sentinel strings
    "no model" / "no coin" / "no npc".
    """
    if model is None:  # fixed: `== None` -> identity check
        init()
        return "no model"
    # get_coin endpoint
    response = requests.post("https://ldhldh-api-for-unity.hf.space/run/predict_6", json={
        "data": [
            id,
        ]}).json()
    coin = response["data"][0]
    if int(coin) == 0:
        return "no coin"
    # model inference
    if check(model_name):
        global history
        if npc not in npc_story:
            return "no npc"
        if npc not in history:
            history[npc] = {}
        if id not in history[npc]:
            history[npc][id] = ""
        # Trim the transcript to roughly the last few turns: the history is a
        # "###"-delimited prompt fragment; keep everything after the third
        # separator once more than 10 segments accumulate.
        if len(history[npc][id].split("###")) > 10:
            history[npc][id] = "###" + history[npc][id].split("###", 3)[3]
        # e.g. "[KingSlime, Slime, ...]" — list repr with quotes stripped.
        npc_list = str([k for k in npc_story.keys()]).replace('\'', '')
        town_story = f"""[{id}์˜ ๋งˆ์„]
์™ธ๋”ด ๊ณณ์˜ ์กฐ๊ทธ๋งŒ ์„ฌ์— ์—ฌ๋Ÿฌ ์ฃผ๋ฏผ๋“ค์ด ๋ชจ์—ฌ ์‚ด๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.
ํ˜„์žฌ {npc_list}์ด ์‚ด๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค."""
        system_message = f"""1. ๋‹น์‹ ์€ ํ•œ๊ตญ์–ด์— ๋Šฅ์ˆ™ํ•ฉ๋‹ˆ๋‹ค.
2. ๋‹น์‹ ์€ ์ง€๊ธˆ ์—ญํ• ๊ทน์„ ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. {npc}์˜ ๋ฐ˜์‘์„ ์ƒ์ƒํ•˜๊ณ  ๋งค๋ ฅ์ ์ด๊ฒŒ ํ‘œํ˜„ํ•ฉ๋‹ˆ๋‹ค.
3. ๋‹น์‹ ์€ {npc}์ž…๋‹ˆ๋‹ค. {npc}์˜ ์ž…์žฅ์—์„œ ์ƒ๊ฐํ•˜๊ณ  ๋งํ•ฉ๋‹ˆ๋‹ค.
4. ์ฃผ์–ด์ง€๋Š” ์ •๋ณด๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๊ฐœ์—ฐ์„ฑ์žˆ๊ณ  ์‹ค๊ฐ๋‚˜๋Š” {npc}์˜ ๋Œ€์‚ฌ๋ฅผ ์™„์„ฑํ•˜์„ธ์š”.
5. ์ฃผ์–ด์ง€๋Š” {npc}์˜ ์ •๋ณด๋ฅผ ์‹ ์ค‘ํ•˜๊ฒŒ ์ฝ๊ณ , ๊ณผํ•˜์ง€ ์•Š๊ณ  ๋‹ด๋ฐฑํ•˜๊ฒŒ ์บ๋ฆญํ„ฐ๋ฅผ ์—ฐ๊ธฐํ•˜์„ธ์š”.
6. User์˜ ์—ญํ• ์„ ์ ˆ๋Œ€๋กœ ์นจ๋ฒ”ํ•˜์ง€ ๋งˆ์„ธ์š”. ๊ฐ™์€ ๋ง์„ ๋ฐ˜๋ณตํ•˜์ง€ ๋งˆ์„ธ์š”.
7. {npc}์˜ ๋งํˆฌ๋ฅผ ์ง€์ผœ์„œ ์ž‘์„ฑํ•˜์„ธ์š”."""
        prom = f"""<<SYS>>
{system_message}<</SYS>>
{town_story}
### ์บ๋ฆญํ„ฐ ์ •๋ณด: {npc_story[npc]}
### ๋ช…๋ น์–ด:
{npc}์˜ ์ •๋ณด๋ฅผ ์ฐธ๊ณ ํ•˜์—ฌ {npc}์ด ํ•  ๋ง์„ ์ƒํ™ฉ์— ๋งž์ถฐ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”.
{history[npc][id]}
### User:
{text}
### {npc}:
"""
        inputs = tokenizer(prom, return_tensors="pt")["input_ids"]
        outputs = model.generate(inputs, do_sample=True, temperature=0.6, top_p=0.75, max_new_tokens=100)
        # Slice off the echoed prompt (+3 fudge for decode artifacts), cut at the
        # first "<" or "###" marker, and start a new line after each sentence.
        output = tokenizer.decode(outputs[0])[len(prom)+3:-1].split("<")[0].split("###")[0].replace(". ", ".\n")
        output = cleanText(output)
        print(tokenizer.decode(outputs[0]))
        print(output)
        history[npc][id] += f"\n\n### User:\n{text}\n\n### {npc}:{output}"
    else:
        output = "no model"
    # add_transaction endpoint
    response = requests.post("https://ldhldh-api-for-unity.hf.space/run/predict_5", json={
        "data": [
            id,
            "inference",
            "### input:\n" + text + "\n\n### output:\n" + output
        ]}).json()
    # NOTE(review): the original bound response["data"][0] to an unused name;
    # the lookup is kept so a malformed response still fails the same way.
    d = response["data"][0]
    return output
# Expose chat() as a Gradio API: three text inputs (user id, npc name, message)
# mapped to one text output, reachable at /run/predict.
with gr.Blocks() as demo:
    # NOTE(review): `count` is never read anywhere in this file — confirm it
    # can be deleted.
    count = 0
    aa = gr.Interface(
        fn=chat,
        inputs=["text","text","text"],
        outputs="text",
        description="chat, ai ์‘๋‹ต์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. ๋‚ด๋ถ€์ ์œผ๋กœ ํŠธ๋žœ์žญ์…˜ ์ƒ์„ฑ. \n /run/predict",
    )
    # NOTE(review): `enable_queue` was removed in Gradio 4.x; `.queue()` alone
    # is sufficient on modern versions — verify against the pinned Gradio.
    demo.queue(max_size=32).launch(enable_queue=True)