File size: 4,405 Bytes
2673459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

from modules import shared
from modules.text_generation import encode, generate_reply

# Extension settings; run_server() binds the API server to params['port'].
params = dict(port=5000)


class Handler(BaseHTTPRequestHandler):
    """HTTP handler exposing a KoboldAI-style API on top of the loaded model.

    Routes:
        GET  /api/v1/model        -> {"result": <shared.model_name>}
        POST /api/v1/generate     -> {"results": [{"text": <generated text>}]}
        POST /api/v1/token-count  -> {"results": [{"tokens": <int>}]}
                                     (not part of the KoboldAI API)

    Malformed requests (bad Content-Length, non-JSON body, missing or
    ill-typed fields) are answered with 400 instead of raising inside the
    handler. Unknown paths get 404.
    """

    _POST_ROUTES = ('/api/v1/generate', '/api/v1/token-count')

    def _send_json(self, payload, status=200):
        """Serialize *payload* to UTF-8 JSON and send it as a complete response."""
        data = json.dumps(payload).encode('utf-8')
        self.send_response(status)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def _read_json_object(self):
        """Read the request body and decode it as a JSON object.

        Returns:
            The decoded ``dict``, or ``None`` if the request was malformed.
            In that case a 400 error response has already been sent.
        """
        try:
            length = int(self.headers.get('Content-Length') or 0)
        except ValueError:
            length = -1
        if length < 0:
            # A negative length would make rfile.read() block until EOF.
            self.send_error(400, 'Invalid Content-Length header')
            return None

        raw = self.rfile.read(length) if length else b''
        try:
            body = json.loads(raw.decode('utf-8'))
        except ValueError:  # covers UnicodeDecodeError and json.JSONDecodeError
            self.send_error(400, 'Request body must be UTF-8 encoded JSON')
            return None

        if not isinstance(body, dict):
            self.send_error(400, 'Request body must be a JSON object')
            return None
        return body

    @staticmethod
    def _generate_params(body):
        """Map KoboldAI-style request fields onto generate_reply() settings.

        Raises:
            TypeError, ValueError: if a field cannot be converted to the
                expected type.
        """
        return {
            'max_new_tokens': int(body.get('max_length', 200)),
            'do_sample': bool(body.get('do_sample', True)),
            'temperature': float(body.get('temperature', 0.5)),
            'top_p': float(body.get('top_p', 1)),
            'typical_p': float(body.get('typical', 1)),
            'repetition_penalty': float(body.get('rep_pen', 1.1)),
            'encoder_repetition_penalty': 1,
            'top_k': int(body.get('top_k', 0)),
            'min_length': int(body.get('min_length', 0)),
            'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
            'num_beams': int(body.get('num_beams', 1)),
            'penalty_alpha': float(body.get('penalty_alpha', 0)),
            'length_penalty': float(body.get('length_penalty', 1)),
            'early_stopping': bool(body.get('early_stopping', False)),
            'seed': int(body.get('seed', -1)),
            'add_bos_token': bool(body.get('add_bos_token', True)),
            'custom_stopping_strings': body.get('custom_stopping_strings', []),
            'truncation_length': int(body.get('truncation_length', 2048)),
            'ban_eos_token': bool(body.get('ban_eos_token', False)),
            'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
        }

    def _handle_generate(self, body):
        """Serve POST /api/v1/generate: run generation and return the new text."""
        prompt = body.get('prompt')
        if not isinstance(prompt, str):
            self.send_error(400, "Field 'prompt' must be a string")
            return
        try:
            max_context = int(body.get('max_context_length', 2048))
            generate_params = self._generate_params(body)
        except (TypeError, ValueError):
            self.send_error(400, 'Invalid generation parameter')
            return

        # Drop whole lines from the front until the prompt fits in max_context
        # tokens. encode() appears to return a batch of token-id rows (the
        # token-count route below indexes [0]). The row length is therefore the
        # token count, while len() of the outer dimension would be the batch size.
        # NOTE(review): confirm encode() is 2-D for every model backend.
        prompt_lines = [line.strip() for line in prompt.split('\n')]
        while prompt_lines and len(encode('\n'.join(prompt_lines))[0]) > max_context:
            prompt_lines.pop(0)
        prompt = '\n'.join(prompt_lines)

        # Exhaust the generator and keep only its last item. Items are either a
        # str or an indexable whose first element is the text.
        answer = ''
        for item in generate_reply(prompt, generate_params):
            answer = item if isinstance(item, str) else item[0]

        # NOTE(review): slicing assumes generate_reply's output begins with the
        # prompt; confirm against modules.text_generation.
        self._send_json({'results': [{'text': answer[len(prompt):]}]})

    def _handle_token_count(self, body):
        """Serve POST /api/v1/token-count: return the token count of 'prompt'."""
        prompt = body.get('prompt')
        if not isinstance(prompt, str):
            self.send_error(400, "Field 'prompt' must be a string")
            return
        tokens = encode(prompt)[0]
        self._send_json({'results': [{'tokens': len(tokens)}]})

    def do_GET(self):
        """Serve GET /api/v1/model; any other path gets 404."""
        if self.path == '/api/v1/model':
            self._send_json({'result': shared.model_name})
        else:
            self.send_error(404)

    def do_POST(self):
        """Dispatch POST requests; unknown paths get 404 without reading the body."""
        if self.path not in self._POST_ROUTES:
            self.send_error(404)
            return

        body = self._read_json_object()
        if body is None:
            return  # 400 already sent

        if self.path == '/api/v1/generate':
            self._handle_generate(body)
        else:
            self._handle_token_count(body)


def run_server():
    """Serve the API on params['port'] until the process exits.

    Binds 0.0.0.0 when shared.args.listen is set, else 127.0.0.1.

    With shared.args.share, it tries to open a public tunnel through
    flask_cloudflared and prints that URL. If the package is missing, the
    local URL is printed instead, so the API address is always reported.
    """
    server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
    server = ThreadingHTTPServer(server_addr, Handler)
    local_url = f'http://{server_addr[0]}:{server_addr[1]}/api'

    if shared.args.share:
        try:
            # NOTE(review): _run_cloudflared is a private helper of
            # flask_cloudflared and may change between releases.
            from flask_cloudflared import _run_cloudflared
        except ImportError:
            print('You should install flask_cloudflared manually')
            print(f'Starting KoboldAI compatible api at {local_url}')
        else:
            public_url = _run_cloudflared(params['port'], params['port'] + 1)
            print(f'Starting KoboldAI compatible api at {public_url}/api')
    else:
        print(f'Starting KoboldAI compatible api at {local_url}')

    try:
        server.serve_forever()
    finally:
        # Release the listening socket if serving ever stops.
        server.server_close()


def setup():
    """Launch run_server() on a background daemon thread and return at once.

    The daemon flag means the thread will not keep the interpreter alive on exit.
    """
    server_thread = Thread(target=run_server, daemon=True)
    server_thread.start()