Salt
committed on
Commit
·
a13affa
1
Parent(s):
61200c0
Update server.py
Browse files
server.py
CHANGED
@@ -1,101 +1,39 @@
|
|
1 |
-
from functools import wraps
|
2 |
from flask import (
|
3 |
Flask,
|
4 |
jsonify,
|
5 |
request,
|
6 |
render_template_string,
|
7 |
abort,
|
8 |
-
send_from_directory,
|
9 |
-
send_file,
|
10 |
)
|
11 |
from flask_cors import CORS
|
12 |
-
import markdown
|
13 |
-
import argparse
|
14 |
-
from transformers import AutoTokenizer, AutoProcessor, pipeline
|
15 |
-
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
16 |
-
from transformers import BlipForConditionalGeneration, GPT2Tokenizer
|
17 |
import unicodedata
|
18 |
-
import torch
|
19 |
import time
|
20 |
import os
|
21 |
import gc
|
22 |
-
from PIL import Image
|
23 |
import base64
|
24 |
from io import BytesIO
|
25 |
from random import randint
|
26 |
-
import webuiapi
|
27 |
import hashlib
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
from constants import *
|
29 |
from colorama import Fore, Style, init as colorama_init
|
30 |
|
31 |
colorama_init()
|
32 |
|
33 |
-
|
34 |
-
class SplitArgs(argparse.Action):
|
35 |
-
def __call__(self, parser, namespace, values, option_string=None):
|
36 |
-
setattr(
|
37 |
-
namespace, self.dest, values.replace('"', "").replace("'", "").split(",")
|
38 |
-
)
|
39 |
-
|
40 |
-
|
41 |
-
# Script arguments
|
42 |
-
parser = argparse.ArgumentParser(
|
43 |
-
prog="TavernAI Extras", description="Web API for transformers models"
|
44 |
-
)
|
45 |
-
parser.add_argument(
|
46 |
-
"--port", type=int, help="Specify the port on which the application is hosted"
|
47 |
-
)
|
48 |
-
parser.add_argument(
|
49 |
-
"--listen", action="store_true", help="Host the app on the local network"
|
50 |
-
)
|
51 |
-
parser.add_argument(
|
52 |
-
"--share", action="store_true", help="Share the app on CloudFlare tunnel"
|
53 |
-
)
|
54 |
-
parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
|
55 |
-
parser.add_argument("--summarization-model", help="Load a custom summarization model")
|
56 |
parser.add_argument(
|
57 |
"--classification-model", help="Load a custom text classification model"
|
58 |
)
|
59 |
-
parser.add_argument("--captioning-model", help="Load a custom captioning model")
|
60 |
-
parser.add_argument(
|
61 |
-
"--keyphrase-model", help="Load a custom keyphrase extraction model"
|
62 |
-
)
|
63 |
-
parser.add_argument("--prompt-model", help="Load a custom prompt generation model")
|
64 |
-
parser.add_argument("--embedding-model", help="Load a custom text embedding model")
|
65 |
-
|
66 |
-
sd_group = parser.add_mutually_exclusive_group()
|
67 |
-
|
68 |
-
local_sd = sd_group.add_argument_group("sd-local")
|
69 |
-
local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")
|
70 |
-
local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU")
|
71 |
-
|
72 |
-
remote_sd = sd_group.add_argument_group("sd-remote")
|
73 |
-
remote_sd.add_argument(
|
74 |
-
"--sd-remote", action="store_true", help="Use a remote backend for SD"
|
75 |
-
)
|
76 |
-
remote_sd.add_argument(
|
77 |
-
"--sd-remote-host", type=str, help="Specify the host of the remote SD backend"
|
78 |
-
)
|
79 |
-
remote_sd.add_argument(
|
80 |
-
"--sd-remote-port", type=int, help="Specify the port of the remote SD backend"
|
81 |
-
)
|
82 |
-
remote_sd.add_argument(
|
83 |
-
"--sd-remote-ssl", action="store_true", help="Use SSL for the remote SD backend"
|
84 |
-
)
|
85 |
-
remote_sd.add_argument(
|
86 |
-
"--sd-remote-auth",
|
87 |
-
type=str,
|
88 |
-
help="Specify the username:password for the remote SD backend (if required)",
|
89 |
-
)
|
90 |
-
|
91 |
-
parser.add_argument(
|
92 |
-
"--enable-modules",
|
93 |
-
action=SplitArgs,
|
94 |
-
default=[],
|
95 |
-
help="Override a list of enabled modules",
|
96 |
-
)
|
97 |
-
|
98 |
-
args = parser.parse_args()
|
99 |
|
100 |
port = 7860
|
101 |
host = "0.0.0.0"
|
@@ -111,31 +49,17 @@ classification_model = (
|
|
111 |
else DEFAULT_CLASSIFICATION_MODEL
|
112 |
)
|
113 |
|
114 |
-
modules = (
|
115 |
-
args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else []
|
116 |
-
)
|
117 |
|
118 |
-
|
119 |
-
print(
|
120 |
-
f"{Fore.RED}{Style.BRIGHT}You did not select any modules to run! Choose them by adding an --enable-modules option"
|
121 |
-
)
|
122 |
-
print(f"Example: --enable-modules=caption,summarize{Style.RESET_ALL}")
|
123 |
|
124 |
-
|
125 |
-
device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu"
|
126 |
-
device = torch.device(device_string)
|
127 |
-
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
|
128 |
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
|
133 |
-
summarization_model, torch_dtype=torch_dtype
|
134 |
-
).to(device)
|
135 |
|
136 |
-
|
137 |
-
|
138 |
-
classification_pipe = pipeline(
|
139 |
"text-classification",
|
140 |
model=classification_model,
|
141 |
top_k=None,
|
@@ -143,40 +67,34 @@ if "classify" in modules:
|
|
143 |
torch_dtype=torch_dtype,
|
144 |
)
|
145 |
|
146 |
-
if "chromadb" in modules:
|
147 |
-
print("Initializing ChromaDB")
|
148 |
-
import chromadb
|
149 |
-
import posthog
|
150 |
-
from chromadb.config import Settings
|
151 |
-
from sentence_transformers import SentenceTransformer
|
152 |
|
153 |
-
# disable chromadb telemetry
|
154 |
-
posthog.capture = lambda *args, **kwargs: None
|
155 |
-
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
156 |
-
chromadb_embedder = SentenceTransformer(embedding_model)
|
157 |
-
chromadb_embed_fn = chromadb_embedder.encode
|
158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
|
160 |
# Flask init
|
161 |
app = Flask(__name__)
|
162 |
CORS(app) # allow cross-domain requests
|
163 |
app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
|
164 |
|
|
|
|
|
|
|
165 |
|
166 |
-
def require_module(name):
|
167 |
-
def wrapper(fn):
168 |
-
@wraps(fn)
|
169 |
-
def decorated_view(*args, **kwargs):
|
170 |
-
if name not in modules:
|
171 |
-
abort(403, "Module is disabled by config")
|
172 |
-
return fn(*args, **kwargs)
|
173 |
-
|
174 |
-
return decorated_view
|
175 |
-
|
176 |
-
return wrapper
|
177 |
-
|
178 |
|
179 |
-
# AI stuff
|
180 |
def classify_text(text: str) -> list:
|
181 |
output = classification_pipe(
|
182 |
text,
|
@@ -243,7 +161,6 @@ def after_request(response):
|
|
243 |
response.headers["X-Request-Duration"] = str(duration)
|
244 |
return response
|
245 |
|
246 |
-
|
247 |
@app.route("/", methods=["GET"])
|
248 |
def index():
|
249 |
with open("./README.md", "r", encoding="utf8") as f:
|
@@ -251,42 +168,104 @@ def index():
|
|
251 |
return render_template_string(markdown.markdown(content, extensions=["tables"]))
|
252 |
|
253 |
|
254 |
-
@app.route("/api/
|
255 |
-
def
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
|
|
269 |
)
|
270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
271 |
|
272 |
|
273 |
-
@app.route("/api/
|
274 |
-
|
275 |
-
def api_caption():
|
276 |
data = request.get_json()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
277 |
|
278 |
-
|
279 |
-
|
|
|
|
|
|
|
280 |
|
281 |
-
|
282 |
-
image = image.convert("RGB")
|
283 |
-
image.thumbnail((512, 512))
|
284 |
-
caption = caption_image(image)
|
285 |
-
thumbnail = image_to_base64(image)
|
286 |
-
print("Caption:", caption, sep="\n")
|
287 |
-
gc.collect()
|
288 |
-
return jsonify({"caption": caption, "thumbnail": thumbnail})
|
289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
290 |
|
291 |
@app.route("/api/summarize", methods=["POST"])
|
292 |
@require_module("summarize")
|
@@ -308,8 +287,8 @@ def api_summarize():
|
|
308 |
return jsonify({"summary": summary})
|
309 |
|
310 |
|
|
|
311 |
@app.route("/api/classify", methods=["POST"])
|
312 |
-
@require_module("classify")
|
313 |
def api_classify():
|
314 |
data = request.get_json()
|
315 |
|
@@ -324,138 +303,10 @@ def api_classify():
|
|
324 |
|
325 |
|
326 |
@app.route("/api/classify/labels", methods=["GET"])
|
327 |
-
@require_module("classify")
|
328 |
def api_classify_labels():
|
329 |
classification = classify_text("")
|
330 |
labels = [x["label"] for x in classification]
|
331 |
return jsonify({"labels": labels})
|
332 |
|
333 |
|
334 |
-
|
335 |
-
@require_module("keywords")
|
336 |
-
def api_keywords():
|
337 |
-
data = request.get_json()
|
338 |
-
|
339 |
-
if "text" not in data or not isinstance(data["text"], str):
|
340 |
-
abort(400, '"text" is required')
|
341 |
-
|
342 |
-
print("Keywords input:", data["text"], sep="\n")
|
343 |
-
keywords = extract_keywords(data["text"])
|
344 |
-
print("Keywords output:", keywords, sep="\n")
|
345 |
-
return jsonify({"keywords": keywords})
|
346 |
-
|
347 |
-
|
348 |
-
@app.route("/api/prompt", methods=["POST"])
|
349 |
-
@require_module("prompt")
|
350 |
-
def api_prompt():
|
351 |
-
data = request.get_json()
|
352 |
-
|
353 |
-
if "text" not in data or not isinstance(data["text"], str):
|
354 |
-
abort(400, '"text" is required')
|
355 |
-
|
356 |
-
keywords = extract_keywords(data["text"])
|
357 |
-
|
358 |
-
if "name" in data and isinstance(data["name"], str):
|
359 |
-
keywords.insert(0, data["name"])
|
360 |
-
|
361 |
-
print("Prompt input:", data["text"], sep="\n")
|
362 |
-
prompts = generate_prompt(keywords)
|
363 |
-
print("Prompt output:", prompts, sep="\n")
|
364 |
-
return jsonify({"prompts": prompts})
|
365 |
-
|
366 |
-
@app.route("/api/modules", methods=["GET"])
|
367 |
-
def get_modules():
|
368 |
-
return jsonify({"modules": modules})
|
369 |
-
|
370 |
-
@app.route("/api/chromadb", methods=["POST"])
|
371 |
-
@require_module("chromadb")
|
372 |
-
def chromadb_add_messages():
|
373 |
-
data = request.get_json()
|
374 |
-
if "chat_id" not in data or not isinstance(data["chat_id"], str):
|
375 |
-
abort(400, '"chat_id" is required')
|
376 |
-
if "messages" not in data or not isinstance(data["messages"], list):
|
377 |
-
abort(400, '"messages" is required')
|
378 |
-
|
379 |
-
chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
|
380 |
-
collection = chromadb_client.get_or_create_collection(
|
381 |
-
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
|
382 |
-
)
|
383 |
-
|
384 |
-
documents = [m["content"] for m in data["messages"]]
|
385 |
-
ids = [m["id"] for m in data["messages"]]
|
386 |
-
metadatas = [
|
387 |
-
{"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
|
388 |
-
for m in data["messages"]
|
389 |
-
]
|
390 |
-
|
391 |
-
collection.upsert(
|
392 |
-
ids=ids,
|
393 |
-
documents=documents,
|
394 |
-
metadatas=metadatas,
|
395 |
-
)
|
396 |
-
|
397 |
-
return jsonify({"count": len(ids)})
|
398 |
-
|
399 |
-
|
400 |
-
@app.route("/api/chromadb/purge", methods=["POST"])
|
401 |
-
@require_module("chromadb")
|
402 |
-
def chromadb_purge():
|
403 |
-
data = request.get_json()
|
404 |
-
if "chat_id" not in data or not isinstance(data["chat_id"], str):
|
405 |
-
abort(400, '"chat_id" is required')
|
406 |
-
|
407 |
-
chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
|
408 |
-
collection = chromadb_client.get_or_create_collection(
|
409 |
-
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
|
410 |
-
)
|
411 |
-
|
412 |
-
deleted = collection.delete()
|
413 |
-
print("ChromaDB embeddings deleted", len(deleted))
|
414 |
-
return 'Ok', 200
|
415 |
-
|
416 |
-
|
417 |
-
@app.route("/api/chromadb/query", methods=["POST"])
|
418 |
-
@require_module("chromadb")
|
419 |
-
def chromadb_query():
|
420 |
-
data = request.get_json()
|
421 |
-
if "chat_id" not in data or not isinstance(data["chat_id"], str):
|
422 |
-
abort(400, '"chat_id" is required')
|
423 |
-
if "query" not in data or not isinstance(data["query"], str):
|
424 |
-
abort(400, '"query" is required')
|
425 |
-
|
426 |
-
if "n_results" not in data or not isinstance(data["n_results"], int):
|
427 |
-
n_results = 1
|
428 |
-
else:
|
429 |
-
n_results = data["n_results"]
|
430 |
-
|
431 |
-
chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
|
432 |
-
collection = chromadb_client.get_or_create_collection(
|
433 |
-
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
|
434 |
-
)
|
435 |
-
|
436 |
-
n_results = min(collection.count(), n_results)
|
437 |
-
query_result = collection.query(
|
438 |
-
query_texts=[data["query"]],
|
439 |
-
n_results=n_results,
|
440 |
-
)
|
441 |
-
|
442 |
-
documents = query_result["documents"][0]
|
443 |
-
ids = query_result["ids"][0]
|
444 |
-
metadatas = query_result["metadatas"][0]
|
445 |
-
distances = query_result["distances"][0]
|
446 |
-
|
447 |
-
messages = [
|
448 |
-
{
|
449 |
-
"id": ids[i],
|
450 |
-
"date": metadatas[i]["date"],
|
451 |
-
"role": metadatas[i]["role"],
|
452 |
-
"meta": metadatas[i]["meta"],
|
453 |
-
"content": documents[i],
|
454 |
-
"distance": distances[i],
|
455 |
-
}
|
456 |
-
for i in range(len(ids))
|
457 |
-
]
|
458 |
-
|
459 |
-
return jsonify(messages)
|
460 |
-
|
461 |
-
app.run(host=host, port=port)
|
|
|
|
|
1 |
from flask import (
|
2 |
Flask,
|
3 |
jsonify,
|
4 |
request,
|
5 |
render_template_string,
|
6 |
abort,
|
|
|
|
|
7 |
)
|
8 |
from flask_cors import CORS
|
|
|
|
|
|
|
|
|
|
|
9 |
import unicodedata
|
10 |
+
import markdown
|
11 |
import time
|
12 |
import os
|
13 |
import gc
|
|
|
14 |
import base64
|
15 |
from io import BytesIO
|
16 |
from random import randint
|
|
|
17 |
import hashlib
|
18 |
+
import chromadb
|
19 |
+
import posthog
|
20 |
+
from chromadb.config import Settings
|
21 |
+
from sentence_transformers import SentenceTransformer
|
22 |
+
from werkzeug.middleware.proxy_fix import ProxyFix
|
23 |
+
import argparse
|
24 |
+
from transformers import AutoTokenizer, AutoProcessor, pipeline
|
25 |
+
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
26 |
+
from transformers import BlipForConditionalGeneration, GPT2Tokenizer
|
27 |
+
from PIL import Image
|
28 |
+
import webuiapi
|
29 |
from constants import *
|
30 |
from colorama import Fore, Style, init as colorama_init
|
31 |
|
32 |
colorama_init()
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
parser.add_argument(
|
35 |
"--classification-model", help="Load a custom text classification model"
|
36 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
port = 7860
|
39 |
host = "0.0.0.0"
|
|
|
49 |
else DEFAULT_CLASSIFICATION_MODEL
|
50 |
)
|
51 |
|
|
|
|
|
|
|
52 |
|
53 |
+
embedding_model = 'sentence-transformers/all-mpnet-base-v2'
|
|
|
|
|
|
|
|
|
54 |
|
55 |
+
print("Initializing a text summarization model...")
|
|
|
|
|
|
|
56 |
|
57 |
+
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
|
58 |
+
summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
|
59 |
+
summarization_model, torch_dtype=torch_dtype).to(device)
|
|
|
|
|
|
|
60 |
|
61 |
+
print("Initializing a sentiment classification pipeline...")
|
62 |
+
classification_pipe = pipeline(
|
|
|
63 |
"text-classification",
|
64 |
model=classification_model,
|
65 |
top_k=None,
|
|
|
67 |
torch_dtype=torch_dtype,
|
68 |
)
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
+
print("Initializing ChromaDB")
|
73 |
+
|
74 |
+
device_string = "cpu"
|
75 |
+
device = torch.device(device_string)
|
76 |
+
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
|
77 |
+
|
78 |
+
|
79 |
+
|
80 |
+
# disable chromadb telemetry
|
81 |
+
posthog.capture = lambda *args, **kwargs: None
|
82 |
+
chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
|
83 |
+
chromadb_embedder = SentenceTransformer(embedding_model)
|
84 |
+
chromadb_embed_fn = chromadb_embedder.encode
|
85 |
|
86 |
# Flask init
|
87 |
app = Flask(__name__)
|
88 |
CORS(app) # allow cross-domain requests
|
89 |
app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
|
90 |
|
91 |
+
app.wsgi_app = ProxyFix(
|
92 |
+
app.wsgi_app, x_for=2, x_proto=1, x_host=1, x_prefix=1
|
93 |
+
)
|
94 |
|
95 |
+
def get_real_ip():
|
96 |
+
return request.remote_addr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
|
|
98 |
def classify_text(text: str) -> list:
|
99 |
output = classification_pipe(
|
100 |
text,
|
|
|
161 |
response.headers["X-Request-Duration"] = str(duration)
|
162 |
return response
|
163 |
|
|
|
164 |
@app.route("/", methods=["GET"])
|
165 |
def index():
|
166 |
with open("./README.md", "r", encoding="utf8") as f:
|
|
|
168 |
return render_template_string(markdown.markdown(content, extensions=["tables"]))
|
169 |
|
170 |
|
171 |
+
@app.route("/api/modules", methods=["GET"])
|
172 |
+
def get_modules():
|
173 |
+
return jsonify({"modules": ['chromadb']})
|
174 |
+
|
175 |
+
@app.route("/api/chromadb", methods=["POST"])
|
176 |
+
def chromadb_add_messages():
|
177 |
+
data = request.get_json()
|
178 |
+
if "chat_id" not in data or not isinstance(data["chat_id"], str):
|
179 |
+
abort(400, '"chat_id" is required')
|
180 |
+
if "messages" not in data or not isinstance(data["messages"], list):
|
181 |
+
abort(400, '"messages" is required')
|
182 |
+
|
183 |
+
ip = get_real_ip()
|
184 |
+
chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
|
185 |
+
collection = chromadb_client.get_or_create_collection(
|
186 |
+
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
|
187 |
)
|
188 |
+
|
189 |
+
documents = [m["content"] for m in data["messages"]]
|
190 |
+
ids = [m["id"] for m in data["messages"]]
|
191 |
+
metadatas = [
|
192 |
+
{"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
|
193 |
+
for m in data["messages"]
|
194 |
+
]
|
195 |
+
|
196 |
+
if len(ids) > 0:
|
197 |
+
collection.upsert(
|
198 |
+
ids=ids,
|
199 |
+
documents=documents,
|
200 |
+
metadatas=metadatas,
|
201 |
+
)
|
202 |
+
|
203 |
+
return jsonify({"count": len(ids)})
|
204 |
|
205 |
|
206 |
+
@app.route("/api/chromadb/query", methods=["POST"])
|
207 |
+
def chromadb_query():
|
|
|
208 |
data = request.get_json()
|
209 |
+
if "chat_id" not in data or not isinstance(data["chat_id"], str):
|
210 |
+
abort(400, '"chat_id" is required')
|
211 |
+
if "query" not in data or not isinstance(data["query"], str):
|
212 |
+
abort(400, '"query" is required')
|
213 |
+
|
214 |
+
if "n_results" not in data or not isinstance(data["n_results"], int):
|
215 |
+
n_results = 1
|
216 |
+
else:
|
217 |
+
n_results = data["n_results"]
|
218 |
|
219 |
+
ip = get_real_ip()
|
220 |
+
chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
|
221 |
+
collection = chromadb_client.get_or_create_collection(
|
222 |
+
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
|
223 |
+
)
|
224 |
|
225 |
+
n_results = min(collection.count(), n_results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
|
227 |
+
messages = []
|
228 |
+
if n_results > 0:
|
229 |
+
query_result = collection.query(
|
230 |
+
query_texts=[data["query"]],
|
231 |
+
n_results=n_results,
|
232 |
+
)
|
233 |
+
|
234 |
+
documents = query_result["documents"][0]
|
235 |
+
ids = query_result["ids"][0]
|
236 |
+
metadatas = query_result["metadatas"][0]
|
237 |
+
distances = query_result["distances"][0]
|
238 |
+
|
239 |
+
messages = [
|
240 |
+
{
|
241 |
+
"id": ids[i],
|
242 |
+
"date": metadatas[i]["date"],
|
243 |
+
"role": metadatas[i]["role"],
|
244 |
+
"meta": metadatas[i]["meta"],
|
245 |
+
"content": documents[i],
|
246 |
+
"distance": distances[i],
|
247 |
+
}
|
248 |
+
for i in range(len(ids))
|
249 |
+
]
|
250 |
+
|
251 |
+
return jsonify(messages)
|
252 |
+
|
253 |
+
@app.route("/api/chromadb/purge", methods=["POST"])
|
254 |
+
def chromadb_purge():
|
255 |
+
data = request.get_json()
|
256 |
+
if "chat_id" not in data or not isinstance(data["chat_id"], str):
|
257 |
+
abort(400, '"chat_id" is required')
|
258 |
+
|
259 |
+
ip = get_real_ip()
|
260 |
+
chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
|
261 |
+
collection = chromadb_client.get_or_create_collection(
|
262 |
+
name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
|
263 |
+
)
|
264 |
+
|
265 |
+
deleted = collection.delete()
|
266 |
+
print("ChromaDB embeddings deleted", len(deleted))
|
267 |
+
|
268 |
+
return 'Ok', 200
|
269 |
|
270 |
@app.route("/api/summarize", methods=["POST"])
|
271 |
@require_module("summarize")
|
|
|
287 |
return jsonify({"summary": summary})
|
288 |
|
289 |
|
290 |
+
|
291 |
@app.route("/api/classify", methods=["POST"])
|
|
|
292 |
def api_classify():
|
293 |
data = request.get_json()
|
294 |
|
|
|
303 |
|
304 |
|
305 |
@app.route("/api/classify/labels", methods=["GET"])
|
|
|
306 |
def api_classify_labels():
|
307 |
classification = classify_text("")
|
308 |
labels = [x["label"] for x in classification]
|
309 |
return jsonify({"labels": labels})
|
310 |
|
311 |
|
312 |
+
app.run(host=host, port=port)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|