from functools import wraps
from flask import (
    Flask,
    jsonify,
    request,
    render_template_string,
    abort,
    send_from_directory,
    send_file,
)
from flask_cors import CORS
import unicodedata
import argparse
import markdown
import time
import os
import gc
import base64
from io import BytesIO
from random import randint
import hashlib
import chromadb
import posthog
import torch
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from werkzeug.middleware.proxy_fix import ProxyFix
from transformers import AutoTokenizer, AutoProcessor, pipeline
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers import BlipForConditionalGeneration, GPT2Tokenizer
from PIL import Image
import webuiapi
from colorama import Fore, Style, init as colorama_init

colorama_init()

port = 7860
host = "0.0.0.0"


class SplitArgs(argparse.Action):
    """argparse action that splits a comma-separated value into a list,
    stripping any stray quote characters."""

    def __call__(self, parser, namespace, values, option_string=None):
        setattr(
            namespace, self.dest, values.replace('"', "").replace("'", "").split(",")
        )


# Script arguments
parser = argparse.ArgumentParser(
    prog="TavernAI Extras", description="Web API for transformers models"
)
parser.add_argument("--summarization-model", help="Load a custom summarization model")
parser.add_argument(
    "--classification-model", help="Load a custom text classification model"
)
parser.add_argument(
    "--enable-modules",
    action=SplitArgs,
    default=[],
    help="Override a list of enabled modules",
)
args = parser.parse_args()
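
# Example launch (the script filename is assumed; the module names mirror the
# list returned by the /api/modules endpoint below):
#   python server.py \
#       --summarization-model Qiliang/bart-large-cnn-samsum-ChatGPT_v3 \
#       --classification-model nateraw/bert-base-uncased-emotion \
#       --enable-modules "chromadb,summarize,classify"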

summarization_model = (
    args.summarization_model
    if args.summarization_model
    else "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
)
classification_model = (
    args.classification_model
    if args.classification_model
    else "nateraw/bert-base-uncased-emotion"
)

device_string = "cpu"
device = torch.device(device_string)
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
embedding_model = "sentence-transformers/all-mpnet-base-v2"
print("Initializing a text summarization model...") | |
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model) | |
summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained( | |
summarization_model, torch_dtype=torch_dtype).to(device) | |
print("Initializing a sentiment classification pipeline...") | |
classification_pipe = pipeline( | |
"text-classification", | |
model=classification_model, | |
top_k=None, | |
device=device, | |
torch_dtype=torch_dtype, | |
) | |
print("Initializing ChromaDB") | |
# disable chromadb telemetry | |
posthog.capture = lambda *args, **kwargs: None | |
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False)) | |
chromadb_embedder = SentenceTransformer(embedding_model) | |
chromadb_embed_fn = chromadb_embedder.encode | |

# Flask init
app = Flask(__name__)
CORS(app)  # allow cross-domain requests
app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024  # 100 MiB request limit
app.wsgi_app = ProxyFix(
    app.wsgi_app, x_for=2, x_proto=1, x_host=1, x_prefix=1
)


def get_real_ip():
    # ProxyFix above rewrites remote_addr from the X-Forwarded-For header,
    # so this is the client address even behind a reverse proxy.
    return request.remote_addr
def classify_text(text: str) -> list: | |
output = classification_pipe( | |
text, | |
truncation=True, | |
max_length=classification_pipe.model.config.max_position_embeddings, | |
)[0] | |
return sorted(output, key=lambda x: x["score"], reverse=True) | |
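
# Shape of classify_text() output (values illustrative; the label set comes
# from the classification model's own config):
#   [{"label": "joy", "score": 0.92}, {"label": "sadness", "score": 0.04}, ...]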
def summarize_chunks(text: str, params: dict) -> str: | |
try: | |
return summarize(text, params) | |
except IndexError: | |
print( | |
"Sequence length too large for model, cutting text in half and calling again" | |
) | |
new_params = params.copy() | |
new_params["max_length"] = new_params["max_length"] // 2 | |
new_params["min_length"] = new_params["min_length"] // 2 | |
return summarize_chunks( | |
text[: (len(text) // 2)], new_params | |
) + summarize_chunks(text[(len(text) // 2) :], new_params) | |
def summarize(text: str, params: dict) -> str: | |
# Tokenize input | |
inputs = summarization_tokenizer(text, return_tensors="pt").to(device) | |
token_count = len(inputs[0]) | |
bad_words_ids = [ | |
summarization_tokenizer(bad_word, add_special_tokens=False).input_ids | |
for bad_word in params["bad_words"] | |
] | |
summary_ids = summarization_transformer.generate( | |
inputs["input_ids"], | |
num_beams=2, | |
max_new_tokens=max(token_count, int(params["max_length"])), | |
min_new_tokens=min(token_count, int(params["min_length"])), | |
repetition_penalty=float(params["repetition_penalty"]), | |
temperature=float(params["temperature"]), | |
length_penalty=float(params["length_penalty"]), | |
bad_words_ids=bad_words_ids, | |
) | |
summary = summarization_tokenizer.batch_decode( | |
summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True | |
)[0] | |
summary = normalize_string(summary) | |
return summary | |
def normalize_string(input: str) -> str: | |
output = " ".join(unicodedata.normalize("NFKC", input).strip().split()) | |
return output | |


# Request time measuring
@app.before_request
def before_request():
    request.start_time = time.time()


@app.after_request
def after_request(response):
    duration = time.time() - request.start_time
    response.headers["X-Request-Duration"] = str(duration)
    return response


# NOTE: the route paths below are reconstructed from the handler names and the
# module list returned by get_modules(); adjust them if your client expects
# different endpoints.
@app.route("/", methods=["GET"])
def index():
    with open("./README.md", "r", encoding="utf8") as f:
        content = f.read()
    return render_template_string(markdown.markdown(content, extensions=["tables"]))


@app.route("/api/modules", methods=["GET"])
def get_modules():
    return jsonify({"modules": ["chromadb", "summarize", "classify"]})


@app.route("/api/chromadb", methods=["POST"])
def chromadb_add_messages():
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')
    if "messages" not in data or not isinstance(data["messages"], list):
        abort(400, '"messages" is required')

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    documents = [m["content"] for m in data["messages"]]
    ids = [m["id"] for m in data["messages"]]
    metadatas = [
        {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
        for m in data["messages"]
    ]

    if len(ids) > 0:
        collection.upsert(
            ids=ids,
            documents=documents,
            metadatas=metadatas,
        )
    return jsonify({"count": len(ids)})


@app.route("/api/chromadb/query", methods=["POST"])
def chromadb_query():
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')
    if "query" not in data or not isinstance(data["query"], str):
        abort(400, '"query" is required')

    if "n_results" not in data or not isinstance(data["n_results"], int):
        n_results = 1
    else:
        n_results = data["n_results"]

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    # Never request more results than the collection holds
    n_results = min(collection.count(), n_results)

    messages = []
    if n_results > 0:
        query_result = collection.query(
            query_texts=[data["query"]],
            n_results=n_results,
        )
        documents = query_result["documents"][0]
        ids = query_result["ids"][0]
        metadatas = query_result["metadatas"][0]
        distances = query_result["distances"][0]

        messages = [
            {
                "id": ids[i],
                "date": metadatas[i]["date"],
                "role": metadatas[i]["role"],
                "meta": metadatas[i]["meta"],
                "content": documents[i],
                "distance": distances[i],
            }
            for i in range(len(ids))
        ]
    return jsonify(messages)
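
# Example request body for /api/chromadb/query ("n_results" is optional and
# defaults to 1); the response is a JSON list of the stored messages closest
# to the query, each carrying its embedding "distance":
#   {"chat_id": "my-chat", "query": "What did we decide about the trip?", "n_results": 5}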


@app.route("/api/chromadb/purge", methods=["POST"])
def chromadb_purge():
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    # NOTE: relies on older chromadb releases where delete() returns the
    # deleted IDs; newer releases return None.
    deleted = collection.delete()
    print("ChromaDB embeddings deleted", len(deleted))
    return "Ok", 200


@app.route("/api/summarize", methods=["POST"])
def api_summarize():
    data = request.get_json()
    if "text" not in data or not isinstance(data["text"], str):
        abort(400, '"text" is required')

    params = {
        "temperature": 1.0,
        "repetition_penalty": 1.0,
        "max_length": 500,
        "min_length": 200,
        "length_penalty": 1.5,
        "bad_words": [
            "\n",
            '"',
            "*",
            "[",
            "]",
            "{",
            "}",
            ":",
            "(",
            ")",
            "<",
            ">",
            "Â",
            "The text ends",
            "The story ends",
            "The text is",
            "The story is",
        ],
    }

    if "params" in data and isinstance(data["params"], dict):
        params.update(data["params"])

    print("Summary input:", data["text"], sep="\n")
    summary = summarize_chunks(data["text"], params)
    print("Summary output:", summary, sep="\n")
    gc.collect()
    return jsonify({"summary": summary})


@app.route("/api/classify", methods=["POST"])
def api_classify():
    data = request.get_json()
    if "text" not in data or not isinstance(data["text"], str):
        abort(400, '"text" is required')

    print("Classification input:", data["text"], sep="\n")
    classification = classify_text(data["text"])
    print("Classification output:", classification, sep="\n")
    gc.collect()
    return jsonify({"classification": classification})


@app.route("/api/classify/labels", methods=["GET"])
def api_classify_labels():
    classification = classify_text("")
    labels = [x["label"] for x in classification]
    return jsonify({"labels": labels})


if __name__ == "__main__":
    app.run(host=host, port=port)