from flask import Flask, request, Response
import logging
import threading
import os
import json
from datetime import datetime

import huggingface_hub
from huggingface_hub import snapshot_download
from apscheduler.schedulers.background import BackgroundScheduler

from llm_backend import LlmBackend
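
# Flask service that streams chat completions from a Saiga-2 GGUF model through
# the LlmBackend wrapper around llama.cpp. The model is loaded at startup (and
# re-loaded on demand), and a background job unloads it after ~5 minutes of
# inactivity to free memory.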

llm = LlmBackend()

# Guards model loading/unloading against concurrent requests.
_lock = threading.Lock()

# Default system prompt: "You are a Russian-language automatic assistant. You answer
# user requests as accurately as possible, using the Russian language."
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT') or "Ты — русскоязычный автоматический ассистент. Ты максимально точно отвечаешь на запросы пользователя, используя русский язык."

# Environment variables always arrive as strings, so cast them explicitly;
# otherwise e.g. ENABLE_GPU='0' would evaluate as truthy.
CONTEXT_SIZE = int(os.environ.get('CONTEXT_SIZE') or 500)
ENABLE_GPU = os.environ.get('ENABLE_GPU', '').lower() in ('1', 'true', 'yes')
GPU_LAYERS = int(os.environ.get('GPU_LAYERS') or 0)
N_GQA = int(os.environ['N_GQA']) if os.environ.get('N_GQA') else None
CHAT_FORMAT = os.environ.get('CHAT_FORMAT') or 'llama-2'
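
# Example configuration (illustrative values only; tune for your hardware):
#   export CONTEXT_SIZE=2000 ENABLE_GPU=1 GPU_LAYERS=35 CHAT_FORMAT=llama-2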

# Serializes token generation so only one completion streams at a time.
lock = threading.Lock()

# Cooperative stop flag checked by the token generator; set to True to abort an in-flight stream.
stop_generation = False

app = Flask(__name__)
app.logger.setLevel(logging.DEBUG)

# Timestamp of the last completed request; used by the idle-unload job below.
last_request_time = datetime.now()

repo_name = "IlyaGusev/saiga2_7b_gguf"
model_name = "model-q4_K.gguf"
local_dir = '.'

if os.path.isdir('/data'):
    app.logger.info('Persistent storage enabled')

model = None

# Download only the requested GGUF file from the Hub (served from the local
# Hugging Face cache when it has already been downloaded).
MODEL_PATH = os.path.join(snapshot_download(repo_id=repo_name, allow_patterns=model_name), model_name)
app.logger.info('Model path: ' + MODEL_PATH)

# Conversation-log dataset location (defined here, but not used in this file).
DATASET_REPO_URL = "https://huggingface.co/datasets/muryshev/saiga-chat"
DATA_FILENAME = "data-saiga-cuda-release.xml"
DATA_FILE = os.path.join("dataset", DATA_FILENAME)

HF_TOKEN = os.environ.get("HF_TOKEN")
app.logger.info("hfh: " + huggingface_hub.__version__)

def generate_tokens(model, generator):
    """Stream detokenized bytes from a raw llama_cpp token generator.

    Not called by the routes below, which stream through llm.generate_tokens() instead.
    """
    global stop_generation
    app.logger.info('generate_tokens started')
    with lock:
        try:
            for token in generator:
                # An empty chunk signals end of stream to the client.
                if token == model.token_eos() or stop_generation:
                    stop_generation = False
                    app.logger.info('End generating')
                    yield b''
                    break

                yield model.detokenize([token])
        except Exception as e:
            app.logger.info('generator exception')
            app.logger.info(e)
            yield b''

@app.route('/change_context_size', methods=['GET'])
def handler_change_context_size():
    # Abort any in-flight generation, then reload the model with the new context size.
    global stop_generation
    stop_generation = True

    new_size = int(request.args.get('size', CONTEXT_SIZE))
    init_model(context_size=new_size)

    return Response('Size changed', content_type='text/plain')
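
# Example: GET /change_context_size?size=2000 reloads the model with a 2000-token context.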

@app.route('/stop_generation', methods=['GET'])
def handler_stop_generation():
    global stop_generation
    stop_generation = True
    return Response('Stopped', content_type='text/plain')

@app.route('/', methods=['GET', 'PUT', 'DELETE', 'PATCH'])
def generate_unknown_response():
    app.logger.info('unknown method: ' + request.method)
    try:
        request_payload = request.get_json()
        # get_json() returns a dict, so serialize it before concatenating.
        app.logger.info('payload: ' + json.dumps(request_payload, ensure_ascii=False))
    except Exception:
        app.logger.info('payload empty')

    return Response('What do you want?', content_type='text/plain')

# Accumulates the bytes of the response currently being streamed.
response_tokens = bytearray()


def generate_and_log_tokens(user_request, generator):
    global response_tokens, last_request_time
    for token in llm.generate_tokens(generator):
        if token == b'':  # end-of-stream marker emitted by the backend
            last_request_time = datetime.now()
            response_tokens = bytearray()
            break
        response_tokens.extend(token)
        yield token

@app.route('/', methods=['POST'])
def generate_response():
    app.logger.info('generate_response')
    with _lock:
        if not llm.is_model_loaded():
            app.logger.info('model loading')
            init_model()

        data = request.get_json()
        app.logger.info(data)
        messages = data.get("messages", [])
        preprompt = data.get("preprompt", "")
        parameters = data.get("parameters", {})

        # Generation parameters with conservative defaults.
        p = {
            'temperature': parameters.get("temperature", 0.01),
            'truncate': parameters.get("truncate", 1000),
            'max_new_tokens': parameters.get("max_new_tokens", 1024),
            'top_p': parameters.get("top_p", 0.85),
            'repetition_penalty': parameters.get("repetition_penalty", 1.2),
            'top_k': parameters.get("top_k", 30),
            'return_full_text': parameters.get("return_full_text", False)
        }

        generator = llm.create_chat_generator_for_saiga(messages=messages, parameters=p)
        app.logger.info('Generator created')

        # Stream tokens back as plain text as they are produced.
        return Response(generate_and_log_tokens(user_request='1', generator=generator),
                        content_type='text/plain', status=200, direct_passthrough=True)
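
# Example request (illustrative; the exact message schema is whatever
# llm.create_chat_generator_for_saiga expects, assumed here to be
# chat-ui style {"role": ..., "content": ...} objects):
#
#   curl -N -X POST http://localhost:7860/ \
#        -H 'Content-Type: application/json' \
#        -d '{"messages": [{"role": "user", "content": "Привет!"}],
#             "parameters": {"temperature": 0.2, "max_new_tokens": 256}}'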

def init_model(context_size=CONTEXT_SIZE):
    llm.load_model(model_path=MODEL_PATH, context_size=context_size, enable_gpu=ENABLE_GPU,
                   gpu_layer_number=GPU_LAYERS, n_gqa=N_GQA)

def check_last_request_time():
    """Unload the model after 5 minutes without requests to free memory."""
    global last_request_time
    current_time = datetime.now()
    if llm.is_model_loaded() and (current_time - last_request_time).total_seconds() > 300:
        llm.unload_model()
        app.logger.info(f"Model unloaded at {current_time}")
    else:
        app.logger.info(f"No action needed at {current_time}")

if __name__ == "__main__":
    # Periodically check whether the model has been idle long enough to unload.
    scheduler = BackgroundScheduler()
    scheduler.add_job(check_last_request_time, trigger='interval', minutes=1)
    scheduler.start()

    init_model()

    # use_reloader=False: the debug reloader would otherwise start a second
    # process and load the model (and start the scheduler) twice.
    app.run(host="0.0.0.0", port=7860, debug=True, threaded=True, use_reloader=False)