File size: 6,280 Bytes
7def60a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#!/usr/bin/env python3
import grpc
from concurrent import futures
import time
import backend_pb2
import backend_pb2_grpc
import argparse
import signal
import sys
import os, glob

from pathlib import Path
import torch
import torch.nn.functional as F
from torch import version as torch_version

from source.tokenizer import ExLlamaTokenizer
from source.generator import ExLlamaGenerator
from source.model import ExLlama, ExLlamaCache, ExLlamaConfig

# Number of seconds in a day; serve() sleeps in intervals of this length so the
# main thread stays alive while the gRPC thread pool handles requests.
_ONE_DAY_IN_SECONDS = 24 * 60 * 60

# Size of the gRPC worker thread pool: taken from PYTHON_GRPC_MAX_WORKERS when
# that environment variable is set, otherwise a single worker.
MAX_WORKERS = int(os.getenv('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer that generates text with an ExLlama model.

    ``LoadModel`` builds the model, tokenizer, cache and generator and stores
    them on the instance; ``Predict`` and ``PredictStream`` use them and
    therefore require a successful ``LoadModel`` call first.
    """

    def generate(self, prompt, max_new_tokens):
        """Sample up to ``max_new_tokens`` tokens continuing ``prompt``.

        Args:
            prompt: Input text; it is tokenized and fed to the generator.
            max_new_tokens: Upper bound on the number of sampled tokens.
                Sampling stops early once the tokenizer's EOS token appears.

        Returns:
            The decoded text of the newly generated tokens only (the prompt
            tokens are excluded), or ``''`` when ``max_new_tokens <= 0``.
            A single space is prepended when the first generated token's
            SentencePiece piece starts with '▁' (presumably because decoding
            the slice on its own drops that leading space).
        """
        self.generator.end_beam_search()

        # Tokenizing the input
        ids = self.generator.tokenizer.encode(prompt)

        self.generator.gen_begin_reuse(ids)
        # Everything past this index in the sequence is newly generated text.
        initial_len = self.generator.sequence[0].shape[0]

        if max_new_tokens <= 0:
            return ''

        has_leading_space = False
        for i in range(max_new_tokens):
            token = self.generator.gen_single_token()
            if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
                has_leading_space = True
            if token.item() == self.generator.tokenizer.eos_token_id:
                break

        # Decode once after sampling. Re-decoding the whole growing tail on
        # every step was O(n^2) in the generated length and only the last
        # result was ever used.
        decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
        if has_leading_space:
            decoded_text = ' ' + decoded_text
        return decoded_text

    def Health(self, request, context):
        """Liveness probe: always replies with the bytes ``b"OK"``."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """Load an ExLlama model from the directory ``request.ModelFile``.

        The directory must contain ``tokenizer.model``, ``config.json`` and at
        least one ``*.safetensors`` weights file (the first one in sorted
        order is used). A non-zero ``request.ContextSize`` overrides the
        maximum sequence length; a non-zero ``request.RopeFreqScale`` is used
        as the RoPE alpha value.

        Returns:
            ``backend_pb2.Result`` with ``success=True`` on success, otherwise
            ``success=False`` and a message describing the failure.
        """
        try:
            # https://github.com/turboderp/exllama/blob/master/example_cfg.py
            model_directory = request.ModelFile

            # Locate files we need within that directory
            tokenizer_path = os.path.join(model_directory, "tokenizer.model")
            model_config_path = os.path.join(model_directory, "config.json")
            st_pattern = os.path.join(model_directory, "*.safetensors")
            # glob() result order depends on the filesystem; sort it so the
            # same weights file is chosen on every run.
            st_files = sorted(glob.glob(st_pattern))
            if not st_files:
                return backend_pb2.Result(
                    success=False,
                    message=f"No .safetensors weights file found in {model_directory!r}",
                )
            model_path = st_files[0]

            # Create config, model, tokenizer and generator

            config = ExLlamaConfig(model_config_path)               # create config from config.json
            config.model_path = model_path                          # supply path to model weights file
            if request.ContextSize:
                config.max_seq_len = request.ContextSize            # override max sequence length
                config.max_attention_size = request.ContextSize**2  # Should be set to context_size^2.
                # https://github.com/turboderp/exllama/issues/220#issuecomment-1720324163

            # Set Rope scaling.
            if request.RopeFreqScale:
                # Alpha value for Rope scaling.
                # Higher value increases context but adds perplexity.
                # alpha_value and compress_pos_emb are mutually exclusive.
                # https://github.com/turboderp/exllama/issues/115
                config.alpha_value = request.RopeFreqScale
                config.calculate_rotary_embedding_base()

            model = ExLlama(config)                                 # create ExLlama instance and load the weights
            tokenizer = ExLlamaTokenizer(tokenizer_path)            # create tokenizer from tokenizer model file

            cache = ExLlamaCache(model, batch_size = 2)             # create cache for inference
            generator = ExLlamaGenerator(model, tokenizer, cache)   # create generator

            self.generator = generator
            self.model = model
            self.tokenizer = tokenizer
            self.cache = cache
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def Predict(self, request, context):
        """Generate text for ``request.Prompt`` and return it as UTF-8 bytes.

        Request fields used: ``Penalty`` (repetition penalty, 1.15 when 0),
        ``Temperature``, ``TopK``, ``TopP`` and ``Tokens`` (maximum number of
        new tokens, 512 when 0). Aborts the RPC with FAILED_PRECONDITION when
        no model has been loaded.
        """
        if getattr(self, "generator", None) is None:
            # Otherwise the attribute accesses below raise AttributeError,
            # which gRPC surfaces to the client as an opaque UNKNOWN error.
            context.abort(grpc.StatusCode.FAILED_PRECONDITION,
                          "Model is not loaded; call LoadModel first")

        penalty = 1.15
        if request.Penalty != 0.0:
            penalty = request.Penalty
        self.generator.settings.token_repetition_penalty_max = penalty
        self.generator.settings.temperature = request.Temperature
        self.generator.settings.top_k = request.TopK
        self.generator.settings.top_p = request.TopP

        tokens = 512
        if request.Tokens != 0:
            tokens = request.Tokens

        # LoadModel always creates a batch_size=2 cache; if a batch-size-1
        # cache is in place, rebuild the cache and generator with batch_size=2.
        if self.cache.batch_size == 1:
            del self.cache
            self.cache = ExLlamaCache(self.model, batch_size=2)
            self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)

        t = self.generate(request.Prompt, tokens)

        # Remove prompt from response if present
        if request.Prompt in t:
            t = t.replace(request.Prompt, "")

        # NOTE(review): Health builds a Reply with a bytes message and LoadModel
        # a Result with a str message; confirm in backend.proto which message
        # type the Predict RPC declares.
        return backend_pb2.Result(message=bytes(t, encoding='utf-8'))

    def PredictStream(self, request, context):
        """Streaming variant of ``Predict``.

        Token-by-token streaming is not implemented yet: the complete
        ``Predict`` result is yielded as a single message. This is a generator
        because a server-streaming gRPC handler must return an iterator of
        responses; returning the bare message object cannot be iterated.
        (Assumes PredictStream is declared as a server-streaming RPC, as the
        original ``yield reply`` sketch implies — confirm in backend.proto.)
        """
        yield self.Predict(request, context)


def serve(address):
    """Start the gRPC backend server on ``address`` and block until shutdown.

    SIGINT and SIGTERM are handled by stopping the server immediately and
    exiting the process with status 0.

    Args:
        address: ``host:port`` string to bind (e.g. ``"localhost:50051"``);
            the port is opened without TLS.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        # Log to stderr, the same stream as the startup message above.
        print("Received termination signal. Shutting down...", file=sys.stderr)
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Requests are served from the thread pool; the main thread just sleeps
    # so the process stays alive and can still receive signals.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == "__main__":
    # Command-line entry point: read the bind address and run the server.
    cli = argparse.ArgumentParser(description="Run the gRPC server.")
    cli.add_argument(
        "--addr",
        default="localhost:50051",
        help="The address to bind the server to.",
    )
    serve(cli.parse_args().addr)