from flask import (
    Flask,
    jsonify,
    request,
    render_template_string,
    abort,
)
from flask_cors import CORS
import unicodedata
import markdown
import time
import os
import gc
import base64
from io import BytesIO
from random import randint
import hashlib
import torch
import chromadb
import posthog
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from werkzeug.middleware.proxy_fix import ProxyFix
import argparse
from transformers import AutoTokenizer, AutoProcessor, pipeline
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers import BlipForConditionalGeneration, GPT2Tokenizer
from PIL import Image
import webuiapi
from constants import *
from colorama import Fore, Style, init as colorama_init

colorama_init()

# Command-line options. The parser construction and the --summarization-model
# option are not shown in this section but are required by the code below, so a
# minimal version is assumed here.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--summarization-model", help="Load a custom summarization model"
)
parser.add_argument(
    "--classification-model", help="Load a custom text classification model"
)
args = parser.parse_args()

port = 7860
host = "0.0.0.0"

summarization_model = (
    args.summarization_model
    if args.summarization_model
    else DEFAULT_SUMMARIZATION_MODEL
)
classification_model = (
    args.classification_model
    if args.classification_model
    else DEFAULT_CLASSIFICATION_MODEL
)

embedding_model = 'sentence-transformers/all-mpnet-base-v2'

# Select the device and dtype before any of the models below are loaded.
device_string = "cpu"
device = torch.device(device_string)
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16

print("Initializing a text summarization model...")
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
    summarization_model, torch_dtype=torch_dtype
).to(device)

print("Initializing a sentiment classification pipeline...")
classification_pipe = pipeline(
    "text-classification",
    model=classification_model,
    top_k=None,
    device=device,
    torch_dtype=torch_dtype,
)

print("Initializing ChromaDB")

# Disable ChromaDB's PostHog telemetry before creating the client.
posthog.capture = lambda *args, **kwargs: None
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
chromadb_embedder = SentenceTransformer(embedding_model)
chromadb_embed_fn = chromadb_embedder.encode

app = Flask(__name__)
CORS(app)
app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024  # 100 MiB request limit

# Honor X-Forwarded-* headers set by up to two reverse proxies in front of the app.
app.wsgi_app = ProxyFix(
    app.wsgi_app, x_for=2, x_proto=1, x_host=1, x_prefix=1
)
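
# The @require_module decorator is used below, but its definition is not part of
# this section. The following is a minimal, assumed stand-in: it rejects requests
# for any module name not present in `loaded_modules` (a hypothetical registry;
# adjust it to however modules are actually enabled in the full application).
from functools import wraps

loaded_modules = {"chromadb", "summarize", "classify"}


def require_module(name):
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            if name not in loaded_modules:
                abort(403, f'Module "{name}" is not loaded')
            return fn(*args, **kwargs)

        return wrapper

    return decorator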

def get_real_ip():
    return request.remote_addr


def classify_text(text: str) -> list:
    output = classification_pipe(
        text,
        truncation=True,
        max_length=classification_pipe.model.config.max_position_embeddings,
    )[0]
    return sorted(output, key=lambda x: x["score"], reverse=True)
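
# summarize_chunks() works around model context-length limits: if the tokenized
# input is too long for the model (an IndexError during generation), the text is
# split in half and each half is summarized recursively with proportionally
# smaller length targets, then the partial summaries are concatenated.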
def summarize_chunks(text: str, params: dict) -> str:
    try:
        return summarize(text, params)
    except IndexError:
        print(
            "Sequence length too large for model, cutting text in half and calling again"
        )
        new_params = params.copy()
        new_params["max_length"] = new_params["max_length"] // 2
        new_params["min_length"] = new_params["min_length"] // 2
        return summarize_chunks(
            text[: (len(text) // 2)], new_params
        ) + summarize_chunks(text[(len(text) // 2) :], new_params)


def summarize(text: str, params: dict) -> str:
    inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
    token_count = len(inputs[0])

    # Token id sequences that must not appear in the generated summary.
    bad_words_ids = [
        summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
        for bad_word in params["bad_words"]
    ]
    summary_ids = summarization_transformer.generate(
        inputs["input_ids"],
        num_beams=2,
        max_new_tokens=max(token_count, int(params["max_length"])),
        min_new_tokens=min(token_count, int(params["min_length"])),
        repetition_penalty=float(params["repetition_penalty"]),
        temperature=float(params["temperature"]),
        length_penalty=float(params["length_penalty"]),
        bad_words_ids=bad_words_ids,
    )
    summary = summarization_tokenizer.batch_decode(
        summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    summary = normalize_string(summary)
    return summary


def normalize_string(input: str) -> str:
    output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
    return output
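

# Request timing: record the start time of every request and report the elapsed
# time back to the caller in an X-Request-Duration response header (seconds).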
@app.before_request
def before_request():
    request.start_time = time.time()


@app.after_request
def after_request(response):
    duration = time.time() - request.start_time
    response.headers["X-Request-Duration"] = str(duration)
    return response


@app.route("/", methods=["GET"])
def index():
    with open("./README.md", "r", encoding="utf8") as f:
        content = f.read()
    return render_template_string(markdown.markdown(content, extensions=["tables"]))


@app.route("/api/modules", methods=["GET"])
def get_modules():
    return jsonify({"modules": ['chromadb']})
@app.route("/api/chromadb", methods=["POST"])
def chromadb_add_messages():
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')
    if "messages" not in data or not isinstance(data["messages"], list):
        abort(400, '"messages" is required')

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    documents = [m["content"] for m in data["messages"]]
    ids = [m["id"] for m in data["messages"]]
    metadatas = [
        {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
        for m in data["messages"]
    ]

    if len(ids) > 0:
        collection.upsert(
            ids=ids,
            documents=documents,
            metadatas=metadatas,
        )

    return jsonify({"count": len(ids)})
@app.route("/api/chromadb/query", methods=["POST"])
def chromadb_query():
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')
    if "query" not in data or not isinstance(data["query"], str):
        abort(400, '"query" is required')

    if "n_results" not in data or not isinstance(data["n_results"], int):
        n_results = 1
    else:
        n_results = data["n_results"]

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    n_results = min(collection.count(), n_results)

    messages = []
    if n_results > 0:
        query_result = collection.query(
            query_texts=[data["query"]],
            n_results=n_results,
        )

        documents = query_result["documents"][0]
        ids = query_result["ids"][0]
        metadatas = query_result["metadatas"][0]
        distances = query_result["distances"][0]

        messages = [
            {
                "id": ids[i],
                "date": metadatas[i]["date"],
                "role": metadatas[i]["role"],
                "meta": metadatas[i]["meta"],
                "content": documents[i],
                "distance": distances[i],
            }
            for i in range(len(ids))
        ]

    return jsonify(messages)
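

# /api/chromadb/purge deletes all stored embeddings for the caller's chat
# collection (identified by the same IP + chat_id hash used when adding
# messages).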
@app.route("/api/chromadb/purge", methods=["POST"])
def chromadb_purge():
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')

    ip = get_real_ip()
    chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    # Count before deleting; collection.delete() is not guaranteed to return the
    # deleted ids in every chromadb version.
    count = collection.count()
    collection.delete()
    print("ChromaDB embeddings deleted", count)

    return 'Ok', 200
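

# /api/summarize condenses the submitted text with the seq2seq summarization
# model. Optional generation parameters in "params" override the defaults from
# DEFAULT_SUMMARIZE_PARAMS (e.g. max_length, min_length, temperature).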
@app.route("/api/summarize", methods=["POST"])
@require_module("summarize")
def api_summarize():
    data = request.get_json()

    if "text" not in data or not isinstance(data["text"], str):
        abort(400, '"text" is required')

    params = DEFAULT_SUMMARIZE_PARAMS.copy()

    if "params" in data and isinstance(data["params"], dict):
        params.update(data["params"])

    print("Summary input:", data["text"], sep="\n")
    summary = summarize_chunks(data["text"], params)
    print("Summary output:", summary, sep="\n")
    gc.collect()
    return jsonify({"summary": summary})
@app.route("/api/classify", methods=["POST"])
def api_classify():
    data = request.get_json()

    if "text" not in data or not isinstance(data["text"], str):
        abort(400, '"text" is required')

    print("Classification input:", data["text"], sep="\n")
    classification = classify_text(data["text"])
    print("Classification output:", classification, sep="\n")
    gc.collect()
    return jsonify({"classification": classification})


@app.route("/api/classify/labels", methods=["GET"])
def api_classify_labels():
    classification = classify_text("")
    labels = [x["label"] for x in classification]
    return jsonify({"labels": labels})


app.run(host=host, port=port)