"""TavernAI Extras: a Flask web API in front of transformers models.

Serves text summarization, text classification and a per-chat ChromaDB
vector store over HTTP.
"""
from functools import wraps
from flask import (
    Flask,
    jsonify,
    request,
    render_template_string,
    abort,
    send_from_directory,
    send_file,
)
from flask_cors import CORS
import unicodedata
import argparse
import markdown
import time
import os
import gc
import base64
from io import BytesIO
from random import randint
import hashlib
import chromadb
import posthog
import torch
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from werkzeug.middleware.proxy_fix import ProxyFix
from transformers import AutoTokenizer, AutoProcessor, pipeline
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers import BlipForConditionalGeneration, GPT2Tokenizer
from PIL import Image
import webuiapi
from constants import *
from colorama import Fore, Style, init as colorama_init

colorama_init()

port = 7860
host = "0.0.0.0"


class SplitArgs(argparse.Action):
    """argparse action that stores a comma-separated value as a list of strings.

    Single and double quotes are stripped from the raw value before splitting.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(
            namespace, self.dest, values.replace('"', "").replace("'", "").split(",")
        )


parser = argparse.ArgumentParser(
    prog="TavernAI Extras", description="Web API for transformers models"
)
parser.add_argument("--summarization-model", help="Load a custom summarization model")
parser.add_argument(
    "--classification-model", help="Load a custom text classification model"
)
args = parser.parse_args()

# Fall back to the defaults from `constants` when no model was given.
summarization_model = (
    args.summarization_model
    if args.summarization_model
    else DEFAULT_SUMMARIZATION_MODEL
)
classification_model = (
    args.classification_model
    if args.classification_model
    else DEFAULT_CLASSIFICATION_MODEL
)
embedding_model = 'sentence-transformers/all-mpnet-base-v2'

# The compute device and dtype must exist before any model is created: both
# are passed to the model constructors below, and `device` is also used when
# tokenizing summarization input.
device_string = "cpu"
device = torch.device(device_string)
# Half precision is only used when running off-CPU.
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16

print("Initializing a text summarization model...")
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
    summarization_model, torch_dtype=torch_dtype
).to(device)

print("Initializing a sentiment classification pipeline...")
classification_pipe = pipeline(
    "text-classification",
    model=classification_model,
    top_k=None,
    device=device,
    torch_dtype=torch_dtype,
)

print("Initializing ChromaDB")

# disable chromadb telemetry
posthog.capture = lambda *args, **kwargs: None
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
chromadb_embedder = SentenceTransformer(embedding_model)
chromadb_embed_fn = chromadb_embedder.encode

# Flask init
app = Flask(__name__)
CORS(app)  # allow cross-domain requests
app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024  # 100 MiB request cap
# Honour X-Forwarded-* headers from upstream proxies so request.remote_addr
# reflects the client address (two forwarding hops are trusted).
app.wsgi_app = ProxyFix(
    app.wsgi_app, x_for=2, x_proto=1, x_host=1, x_prefix=1
)


def get_real_ip():
    """Return the client address (already rewritten by ProxyFix)."""
    return request.remote_addr


def _get_json_object() -> dict:
    """Return the request's JSON body, aborting with 400 unless it is an object."""
    data = request.get_json()
    if not isinstance(data, dict):
        abort(400, "A JSON object body is required")
    return data


def _classification_max_length() -> int:
    """Return the longest token sequence the classification model accepts."""
    # Some configs (e.g. RoBERTa-style) report more position embeddings than
    # the tokenizer's usable maximum (514 vs 512), and some tokenizers report
    # a huge sentinel model_max_length; the smaller of the two is safe.
    return min(
        classification_pipe.tokenizer.model_max_length,
        classification_pipe.model.config.max_position_embeddings,
    )


def classify_text(text: str) -> list:
    """Classify *text* and return every label's score, highest score first.

    Input longer than the model accepts is truncated rather than rejected.
    """
    output = classification_pipe(
        text,
        truncation=True,
        max_length=_classification_max_length(),
    )[0]
    return sorted(output, key=lambda x: x["score"], reverse=True)


def summarize_chunks(text: str, params: dict) -> str:
    """Summarize *text*, recursively halving it when it is too long.

    The IndexError handler assumes the error means the input exceeded the
    model's sequence length. In that case each half of the text is summarized
    with halved length limits, and the two summaries are joined with a space.

    Raises:
        IndexError: if the error persists even for text too short to split.
    """
    try:
        return summarize(text, params)
    except IndexError:
        if len(text) < 2:
            # Splitting cannot shrink the input any further (text[0:] == text),
            # so recursing again would never terminate.
            raise
        print(
            "Sequence length too large for model, cutting text in half and calling again"
        )
        new_params = params.copy()
        # int(): limits may arrive as strings from the JSON request, matching
        # the int() coercion summarize() applies to them.
        new_params["max_length"] = int(new_params["max_length"]) // 2
        new_params["min_length"] = int(new_params["min_length"]) // 2
        middle = len(text) // 2
        first = summarize_chunks(text[:middle], new_params)
        second = summarize_chunks(text[middle:], new_params)
        # Each half is already stripped, so add a separator; normalize_string
        # collapses it away if either half is empty.
        return normalize_string(f"{first} {second}")


def summarize(text: str, params: dict) -> str:
    """Summarize *text* in one pass of the seq2seq model.

    *params* must provide "bad_words" (iterable of strings to suppress),
    "max_length", "min_length", "repetition_penalty", "temperature" and
    "length_penalty". Numeric values are coerced with int()/float().
    The returned summary is whitespace-normalized.
    """
    # Tokenize input
    inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
    token_count = len(inputs[0])

    bad_words_ids = [
        summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
        for bad_word in params["bad_words"]
    ]
    summary_ids = summarization_transformer.generate(
        inputs["input_ids"],
        num_beams=2,
        # Allow at least as many new tokens as the input has, and never
        # demand more than the input has.
        max_new_tokens=max(token_count, int(params["max_length"])),
        min_new_tokens=min(token_count, int(params["min_length"])),
        repetition_penalty=float(params["repetition_penalty"]),
        temperature=float(params["temperature"]),
        length_penalty=float(params["length_penalty"]),
        bad_words_ids=bad_words_ids,
    )
    summary = summarization_tokenizer.batch_decode(
        summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    summary = normalize_string(summary)
    return summary


def normalize_string(input: str) -> str:
    """NFKC-normalize *input* and collapse all whitespace runs to single spaces."""
    output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
    return output


@app.before_request
# Request time measuring
def before_request():
    request.start_time = time.time()


@app.after_request
def after_request(response):
    """Report the request's wall-clock duration (seconds) in a response header."""
    duration = time.time() - request.start_time
    response.headers["X-Request-Duration"] = str(duration)
    return response


@app.route("/", methods=["GET"])
def index():
    """Render README.md as HTML.

    The rendered markdown goes through render_template_string, so any Jinja
    syntax in the README is evaluated as a template.
    """
    with open("./README.md", "r", encoding="utf8") as f:
        content = f.read()
    return render_template_string(markdown.markdown(content, extensions=["tables"]))


@app.route("/api/modules", methods=["GET"])
def get_modules():
    # NOTE(review): only "chromadb" is advertised although the summarize and
    # classify endpoints below are also served — confirm this is intended.
    return jsonify({"modules": ['chromadb']})


def _get_chat_collection(chat_id: str):
    """Return (creating it if needed) the ChromaDB collection for *chat_id*.

    The collection name hashes the client IP together with the chat id, so
    clients at different addresses using the same chat id stay separate.
    """
    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{chat_id}'.encode()).hexdigest()
    return chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )


@app.route("/api/chromadb", methods=["POST"])
def chromadb_add_messages():
    """Upsert chat messages into the caller's per-chat collection.

    Expects {"chat_id": str, "messages": [{"id", "content", "role", "date",
    optional "meta"}, ...]} and returns {"count": <number of messages>}.
    """
    data = _get_json_object()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')
    if "messages" not in data or not isinstance(data["messages"], list):
        abort(400, '"messages" is required')

    messages = data["messages"]
    # Reject malformed messages up front instead of failing with a KeyError.
    for message in messages:
        if not isinstance(message, dict) or any(
            key not in message for key in ("id", "content", "role", "date")
        ):
            abort(400, 'each message requires "id", "content", "role" and "date"')

    collection = _get_chat_collection(data["chat_id"])

    documents = [m["content"] for m in messages]
    ids = [m["id"] for m in messages]
    metadatas = [
        {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
        for m in messages
    ]

    if ids:
        collection.upsert(
            ids=ids,
            documents=documents,
            metadatas=metadatas,
        )

    return jsonify({"count": len(ids)})


@app.route("/api/chromadb/query", methods=["POST"])
def chromadb_query():
    """Return the stored messages nearest to "query" in the caller's chat.

    "n_results" defaults to 1 when missing or not an int, and is capped at the
    collection size. Each result carries its id, date, role, meta, content and
    distance.
    """
    data = _get_json_object()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')
    if "query" not in data or not isinstance(data["query"], str):
        abort(400, '"query" is required')

    n_results = data["n_results"] if isinstance(data.get("n_results"), int) else 1

    collection = _get_chat_collection(data["chat_id"])
    # Never ask for more results than the collection holds.
    n_results = min(collection.count(), n_results)

    messages = []
    if n_results > 0:
        query_result = collection.query(
            query_texts=[data["query"]],
            n_results=n_results,
        )
        # Results are grouped per query text; a single query was sent, so
        # index 0 holds all of its matches.
        messages = [
            {
                "id": message_id,
                "date": metadata["date"],
                "role": metadata["role"],
                "meta": metadata.get("meta", ""),
                "content": document,
                "distance": distance,
            }
            for message_id, document, metadata, distance in zip(
                query_result["ids"][0],
                query_result["documents"][0],
                query_result["metadatas"][0],
                query_result["distances"][0],
            )
        ]

    return jsonify(messages)


@app.route("/api/chromadb/purge", methods=["POST"])
def chromadb_purge():
    """Delete every embedding stored for the caller's chat."""
    data = _get_json_object()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')

    collection = _get_chat_collection(data["chat_id"])

    # Count before deleting: Collection.delete() is not guaranteed to return
    # the deleted ids (recent chromadb versions return None).
    count = collection.count()
    collection.delete()
    print("ChromaDB embeddings deleted", count)
    return 'Ok', 200


@app.route("/api/summarize", methods=["POST"])
def api_summarize():
    """Summarize "text" using DEFAULT_SUMMARIZE_PARAMS overlaid with "params"."""
    data = _get_json_object()

    if "text" not in data or not isinstance(data["text"], str):
        abort(400, '"text" is required')

    params = DEFAULT_SUMMARIZE_PARAMS.copy()

    if "params" in data and isinstance(data["params"], dict):
        params.update(data["params"])

    print("Summary input:", data["text"], sep="\n")
    summary = summarize_chunks(data["text"], params)
    print("Summary output:", summary, sep="\n")
    gc.collect()
    return jsonify({"summary": summary})


@app.route("/api/classify", methods=["POST"])
def api_classify():
    """Classify "text" and return all label scores, highest first."""
    data = _get_json_object()

    if "text" not in data or not isinstance(data["text"], str):
        abort(400, '"text" is required')

    print("Classification input:", data["text"], sep="\n")
    classification = classify_text(data["text"])
    print("Classification output:", classification, sep="\n")
    gc.collect()
    return jsonify({"classification": classification})


@app.route("/api/classify/labels", methods=["GET"])
def api_classify_labels():
    """List the classifier's labels by classifying an empty string."""
    classification = classify_text("")
    labels = [x["label"] for x in classification]
    return jsonify({"labels": labels})


app.run(host=host, port=port)