from flask import Flask, request, Response
import logging
from llama_cpp import Llama
import threading
from huggingface_hub import snapshot_download  # , Repository (only needed for the commented-out dataset logging below)
import huggingface_hub
import gc
import os.path
import xml.etree.ElementTree as ET
from apscheduler.schedulers.background import BackgroundScheduler
from datetime import datetime, timedelta
from llm_backend import LlmBackend
import json
llm = LlmBackend()
_lock = threading.Lock()
# Default system prompt (Russian). English translation: "You are a Russian-language automatic
# assistant. You answer user requests as accurately as possible, in Russian."
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT') or "Ты — русскоязычный автоматический ассистент. Ты максимально точно отвечаешь на запросы пользователя, используя русский язык."
CONTEXT_SIZE = int(os.environ.get('CONTEXT_SIZE') or 500)
ENABLE_GPU = os.environ.get('ENABLE_GPU', '').lower() in ('1', 'true', 'yes')
GPU_LAYERS = int(os.environ.get('GPU_LAYERS') or 0)
N_GQA = int(os.environ['N_GQA']) if os.environ.get('N_GQA') else None  # must be set to 8 for 70B models
CHAT_FORMAT = os.environ.get('CHAT_FORMAT') or 'llama-2'
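
# Example configuration via environment variables (illustrative values only, not part of this repo):
#   export CONTEXT_SIZE=2000 ENABLE_GPU=1 GPU_LAYERS=35 N_GQA=8
# N_GQA=8 is only required for the 70B checkpoints; the smaller models can leave it unset.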

# Locks guarding model access, and a flag used to interrupt streaming generation
lock = threading.Lock()
stop_generation = False

app = Flask(__name__)

# Configure Flask logging
app.logger.setLevel(logging.DEBUG)

# Time of the last request, checked by the idle-unload job below
last_request_time = datetime.now()

# Model selection; the commented-out entries are alternative Saiga2 GGUF checkpoints
#model_path = "../models/model-q4_K.gguf" # Replace with the actual model path
#model_name = "model/ggml-model-q4_K.gguf"
#repo_name = "IlyaGusev/saiga2_13b_gguf"
#model_name = "model-q4_K.gguf"
#repo_name = "IlyaGusev/saiga2_70b_gguf"
#model_name = "ggml-model-q4_1.gguf"
repo_name = "IlyaGusev/saiga2_7b_gguf"
model_name = "model-q4_K.gguf"
local_dir = '.'

if os.path.isdir('/data'):
    app.logger.info('Persistent storage enabled')

model = None

MODEL_PATH = snapshot_download(repo_id=repo_name, allow_patterns=model_name) + '/' + model_name
app.logger.info('Model path: ' + MODEL_PATH)
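
# For reference (assumed default Hub cache layout, not guaranteed): snapshot_download() returns the
# local snapshot directory, so MODEL_PATH typically resolves to something like
#   ~/.cache/huggingface/hub/models--IlyaGusev--saiga2_7b_gguf/snapshots/<revision>/model-q4_K.gguf
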
DATASET_REPO_URL = "https://huggingface.co/datasets/muryshev/saiga-chat"
DATA_FILENAME = "data-saiga-cuda-release.xml"
DATA_FILE = os.path.join("dataset", DATA_FILENAME)
HF_TOKEN = os.environ.get("HF_TOKEN")
app.logger.info("hfh: "+huggingface_hub.__version__)
# repo = Repository(
# local_dir="dataset", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
# )
# def log(req: str = '', resp: str = ''):
#     if req or resp:
#         element = ET.Element("row", {"time": str(datetime.now())})
#         req_element = ET.SubElement(element, "request")
#         req_element.text = req
#         resp_element = ET.SubElement(element, "response")
#         resp_element.text = resp
#
#         with open(DATA_FILE, "ab+") as xml_file:
#             xml_file.write(ET.tostring(element, encoding="utf-8"))
#
#         commit_url = repo.push_to_hub()
#         app.logger.info(commit_url)

def generate_tokens(model, generator):
    global stop_generation
    app.logger.info('generate_tokens started')
    with lock:
        try:
            for token in generator:
                if token == model.token_eos() or stop_generation:
                    stop_generation = False
                    app.logger.info('End generating')
                    yield b''  # End of chunk
                    break
                token_str = model.detokenize([token])  # .decode("utf-8", errors="ignore")
                yield token_str
        except Exception as e:
            app.logger.info('generator exception')
            app.logger.info(e)
            yield b''  # End of chunk

@app.route('/change_context_size', methods=['GET'])
def handler_change_context_size():
    global stop_generation, model
    stop_generation = True

    new_size = int(request.args.get('size', CONTEXT_SIZE))
    init_model(new_size, ENABLE_GPU, GPU_LAYERS)
    return Response('Size changed', content_type='text/plain')

@app.route('/stop_generation', methods=['GET'])
def handler_stop_generation():
    global stop_generation
    stop_generation = True
    return Response('Stopped', content_type='text/plain')
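
# Illustrative usage of the two GET endpoints above (assumes a local run on port 7860, see app.run below):
#   curl 'http://localhost:7860/stop_generation'
#   curl 'http://localhost:7860/change_context_size?size=2000'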

@app.route('/', methods=['GET', 'PUT', 'DELETE', 'PATCH'])
def generate_unknown_response():
    app.logger.info('unknown method: ' + request.method)
    try:
        request_payload = request.get_json()
        app.logger.info('payload: ' + json.dumps(request_payload))
    except Exception as e:
        app.logger.info('payload empty')
    return Response('What do you want?', content_type='text/plain')

response_tokens = bytearray()

def generate_and_log_tokens(user_request, generator):
    global response_tokens, last_request_time
    for token in llm.generate_tokens(generator):
        if token == b'':  # end-of-stream marker yielded by the backend
            last_request_time = datetime.now()
            # log(json.dumps(user_request), response_tokens.decode("utf-8", errors="ignore"))
            response_tokens = bytearray()
            break
        response_tokens.extend(token)
        yield token

@app.route('/', methods=['POST'])
def generate_response():
    app.logger.info('generate_response')

    with _lock:
        if not llm.is_model_loaded():
            app.logger.info('model loading')
            init_model()

    data = request.get_json()
    app.logger.info(data)

    messages = data.get("messages", [])
    preprompt = data.get("preprompt", "")
    parameters = data.get("parameters", {})

    # Extract generation parameters from the request
    p = {
        'temperature': parameters.get("temperature", 0.01),
        'truncate': parameters.get("truncate", 1000),
        'max_new_tokens': parameters.get("max_new_tokens", 1024),
        'top_p': parameters.get("top_p", 0.85),
        'repetition_penalty': parameters.get("repetition_penalty", 1.2),
        'top_k': parameters.get("top_k", 30),
        'return_full_text': parameters.get("return_full_text", False)
    }

    generator = llm.create_chat_generator_for_saiga(messages=messages, parameters=p)
    app.logger.info('Generator created')

    # Use Response to stream tokens back to the client
    return Response(generate_and_log_tokens(user_request='1', generator=generator),
                    content_type='text/plain', status=200, direct_passthrough=True)
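
# Example request against this endpoint (illustrative only; the exact message schema is whatever
# LlmBackend.create_chat_generator_for_saiga expects):
#   curl -N -X POST http://localhost:7860/ \
#        -H 'Content-Type: application/json' \
#        -d '{"messages": [{"role": "user", "content": "Привет!"}],
#             "parameters": {"temperature": 0.2, "max_new_tokens": 256}}'
# The reply is streamed back as plain text, token by token.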

def init_model(context_size=CONTEXT_SIZE, enable_gpu=ENABLE_GPU, gpu_layers=GPU_LAYERS):
    llm.load_model(model_path=MODEL_PATH, context_size=context_size, enable_gpu=enable_gpu,
                   gpu_layer_number=gpu_layers, n_gqa=N_GQA)

# Unload the model if no request has been made in the last 5 minutes
def check_last_request_time():
    global last_request_time
    current_time = datetime.now()
    if (current_time - last_request_time).total_seconds() > 300:  # 5 minutes in seconds
        # Free the model to release memory while the service is idle
        llm.unload_model()
        app.logger.info(f"Model unloaded at {current_time}")
    else:
        app.logger.info(f"No action needed at {current_time}")

if __name__ == "__main__":
    scheduler = BackgroundScheduler()
    scheduler.add_job(check_last_request_time, trigger='interval', minutes=1)
    scheduler.start()

    init_model()
    app.run(host="0.0.0.0", port=7860, debug=True, threaded=True)