import inspect
from typing import List, Type

import gradio as gr
import requests
from gradio import routes
from petals import AutoDistributedModelForCausalLM
from transformers import AutoTokenizer

# init code
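# Monkey-patch Gradio's route type introspection: a common workaround for
# older Gradio versions whose auto-generated API docs choke while parsing
# component docstrings.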
def get_types(cls_set: List[Type], component: str):
    docset = []
    types = []
    if component == "input":
        for cls in cls_set:
            doc = inspect.getdoc(cls)
            doc_lines = doc.split("\n")
            docset.append(doc_lines[1].split(":")[-1])
            types.append(doc_lines[1].split(")")[0].split("(")[-1])
    else:
        for cls in cls_set:
            doc = inspect.getdoc(cls)
            doc_lines = doc.split("\n")
            docset.append(doc_lines[-1].split(":")[-1])
            types.append(doc_lines[-1].split(")")[0].split("(")[-1])
    return docset, types
routes.get_types = get_types

# App code

model_name = "petals-team/StableBeluga2"

# Alternative Korean model candidates:
#   daekeun-ml/Llama-2-ko-instruct-13B
#   quantumaikr/llama-2-70b-fb16-korean
tokenizer = AutoTokenizer.from_pretrained(model_name)
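# Only the tokenizer is fetched locally; the model itself runs on the public
# Petals swarm and is loaded lazily by init() below.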

model = None

# Conversation history, keyed first by NPC name and then by user id.
history = {}

# NPC descriptions, keyed by NPC name. The code below reads npc_story[npc],
# but the dict was never defined in this file; fill in the real character
# descriptions before deploying.
npc_story = {}

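# Query the public Petals health endpoint; the response JSON contains a
# "model_reports" list with "name" and "state" per model (shape inferred
# from the fields used below).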
def check(model_name):
    """Return True if the swarm reports the given model as healthy."""
    data = requests.get("https://health.petals.dev/api/v1/state").json()
    for d in data['model_reports']:
        if d['name'] == model_name and d['state'] == "healthy":
            return True
    return False

def init():
    """Load the distributed model once the swarm is healthy."""
    global model
    if check(model_name):
        model = AutoDistributedModelForCausalLM.from_pretrained(model_name)


def chat(id, npc, text):
    # Lazy-load: kick off model loading on the first call and tell the
    # client to retry.
    if model is None:
        init()
        return "no model"
    # get_coin endpoint: check the caller's coin balance before running inference
    response = requests.post("https://ldhldh-api-for-unity.hf.space/run/predict_6", json={
      "data": [
        id,
    ]}).json()
    
    coin = response["data"][0]
    if int(coin) == 0:
        return "no coin"
    
    # model inference

    if check(model_name):

        global history
        if npc not in npc_story:
            return "no npc"

        if npc not in history:
            history[npc] = {}
        if id not in history[npc]:
            history[npc][id] = ""
        # Trim the rolling history: turns are delimited by "###"; keep only
        # the newest ones. NOTE: nothing below appends to history[npc][id],
        # so persist each turn here if multi-turn memory is intended.
        if len(history[npc][id].split("###")) > 10:
            history[npc][id] = "###" + history[npc][id].split("###", 3)[3]
        npc_list = str(list(npc_story.keys())).replace("'", "")
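        # Scene preamble (Korean): a small, remote island village where the
        # NPCs listed above live.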
        town_story = f"""[{id}์˜ ๋งˆ์„]
์™ธ๋”ด ๊ณณ์˜ ์กฐ๊ทธ๋งŒ ์„ฌ์— ์—ฌ๋Ÿฌ ์ฃผ๋ฏผ๋“ค์ด ๋ชจ์—ฌ ์‚ด๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

ํ˜„์žฌ {npc_list}์ด ์‚ด๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค."""

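        # Korean role-play system prompt: play {npc} in character, keep the
        # replies plain and natural, and never write the User's lines.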
        system_message = f"""1. ๋‹น์‹ ์€ ํ•œ๊ตญ์–ด์— ๋Šฅ์ˆ™ํ•ฉ๋‹ˆ๋‹ค.
2. ๋‹น์‹ ์€ ์ง€๊ธˆ ์—ญํ• ๊ทน์„ ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. {npc}์˜ ๋ฐ˜์‘์„ ์ƒ์ƒํ•˜๊ณ  ๋งค๋ ฅ์ ์ด๊ฒŒ ํ‘œํ˜„ํ•ฉ๋‹ˆ๋‹ค.
3. ๋‹น์‹ ์€ {npc}์ž…๋‹ˆ๋‹ค. {npc}์˜ ์ž…์žฅ์—์„œ ์ƒ๊ฐํ•˜๊ณ  ๋งํ•ฉ๋‹ˆ๋‹ค.
4. ์ฃผ์–ด์ง€๋Š” ์ •๋ณด๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ๊ฐœ์—ฐ์„ฑ์žˆ๊ณ  ์‹ค๊ฐ๋‚˜๋Š” {npc}์˜ ๋Œ€์‚ฌ๋ฅผ ์™„์„ฑํ•˜์„ธ์š”.
5. ์ฃผ์–ด์ง€๋Š” {npc}์˜ ์ •๋ณด๋ฅผ ์‹ ์ค‘ํ•˜๊ฒŒ ์ฝ๊ณ , ๊ณผํ•˜์ง€ ์•Š๊ณ  ๋‹ด๋ฐฑํ•˜๊ฒŒ ์บ๋ฆญํ„ฐ๋ฅผ ์—ฐ๊ธฐํ•˜์„ธ์š”.
6. User์˜ ์—ญํ• ์„ ์ ˆ๋Œ€๋กœ ์นจ๋ฒ”ํ•˜์ง€ ๋งˆ์„ธ์š”. ๊ฐ™์€ ๋ง์„ ๋ฐ˜๋ณตํ•˜์ง€ ๋งˆ์„ธ์š”.
7. {npc}์˜ ๋งํˆฌ๋ฅผ ์ง€์ผœ์„œ ์ž‘์„ฑํ•˜์„ธ์š”."""

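        # Assemble a Llama-2-style prompt: <<SYS>> system block, scene,
        # character sheet, rolling history (Korean section headers), then the
        # user's new message; the model completes {npc}'s next line.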
        prom = f"""<<SYS>>
{system_message}<</SYS>>

{town_story}

### ์บ๋ฆญํ„ฐ ์ •๋ณด: {npc_story[npc]}

### ๋ช…๋ น์–ด:
{npc}์˜ ์ •๋ณด๋ฅผ ์ฐธ๊ณ ํ•˜์—ฌ {npc}์ด ํ•  ๋ง์„ ์ƒํ™ฉ์— ๋งž์ถฐ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”.
{history[npc][id]}

### User:
{text}

### {npc}:
"""

        inputs = tokenizer(prom, return_tensors="pt")["input_ids"]
        outputs = model.generate(inputs, do_sample=True, temperature=0.6, top_p=0.75, max_new_tokens=100)
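        # Post-process: skip the echoed prompt (the +3 roughly accounts for
        # the decoded BOS token), cut at the first special token or a new
        # "###" section, and put each sentence on its own line.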
        output = tokenizer.decode(outputs[0])[len(prom)+3:-1].split("<")[0].split("###")[0].replace(". ", ".\n")
        print(outputs)
        print(output)
    else:
        output = "no model"
        prom = ""  # so the transaction log below still has a defined prompt


    # add_transaction endpoint: record the prompt/response pair for this user
    response = requests.post("https://ldhldh-api-for-unity.hf.space/run/predict_5", json={
      "data": [
        id,
        "inference",
        "### input:\n" + prom + "\n\n### output:\n" + output
    ]}).json()
    
    d = response["data"][0]
    
    return output

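# Example client call (hypothetical Space URL and ids), matching the
# Interface below:
#   requests.post("https://<user>-<space>.hf.space/run/predict",
#                 json={"data": ["player_1", "some_npc", "์•ˆ๋…•!"]})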

with gr.Blocks() as demo:
    aa = gr.Interface(
      fn=chat,
      inputs=["text", "text", "text"],  # id, npc, text
      outputs="text",
      description="chat: returns the AI response and creates a transaction internally. \n /run/predict",
    )

demo.queue(max_size=32).launch()