Salt commited on
Commit
a13affa
·
1 Parent(s): 61200c0

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +130 -279
server.py CHANGED
@@ -1,101 +1,39 @@
1
- from functools import wraps
2
  from flask import (
3
  Flask,
4
  jsonify,
5
  request,
6
  render_template_string,
7
  abort,
8
- send_from_directory,
9
- send_file,
10
  )
11
  from flask_cors import CORS
12
- import markdown
13
- import argparse
14
- from transformers import AutoTokenizer, AutoProcessor, pipeline
15
- from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
16
- from transformers import BlipForConditionalGeneration, GPT2Tokenizer
17
  import unicodedata
18
- import torch
19
  import time
20
  import os
21
  import gc
22
- from PIL import Image
23
  import base64
24
  from io import BytesIO
25
  from random import randint
26
- import webuiapi
27
  import hashlib
 
 
 
 
 
 
 
 
 
 
 
28
  from constants import *
29
  from colorama import Fore, Style, init as colorama_init
30
 
31
  colorama_init()
32
 
33
-
34
- class SplitArgs(argparse.Action):
35
- def __call__(self, parser, namespace, values, option_string=None):
36
- setattr(
37
- namespace, self.dest, values.replace('"', "").replace("'", "").split(",")
38
- )
39
-
40
-
41
- # Script arguments
42
- parser = argparse.ArgumentParser(
43
- prog="TavernAI Extras", description="Web API for transformers models"
44
- )
45
- parser.add_argument(
46
- "--port", type=int, help="Specify the port on which the application is hosted"
47
- )
48
- parser.add_argument(
49
- "--listen", action="store_true", help="Host the app on the local network"
50
- )
51
- parser.add_argument(
52
- "--share", action="store_true", help="Share the app on CloudFlare tunnel"
53
- )
54
- parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
55
- parser.add_argument("--summarization-model", help="Load a custom summarization model")
56
  parser.add_argument(
57
  "--classification-model", help="Load a custom text classification model"
58
  )
59
- parser.add_argument("--captioning-model", help="Load a custom captioning model")
60
- parser.add_argument(
61
- "--keyphrase-model", help="Load a custom keyphrase extraction model"
62
- )
63
- parser.add_argument("--prompt-model", help="Load a custom prompt generation model")
64
- parser.add_argument("--embedding-model", help="Load a custom text embedding model")
65
-
66
- sd_group = parser.add_mutually_exclusive_group()
67
-
68
- local_sd = sd_group.add_argument_group("sd-local")
69
- local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")
70
- local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU")
71
-
72
- remote_sd = sd_group.add_argument_group("sd-remote")
73
- remote_sd.add_argument(
74
- "--sd-remote", action="store_true", help="Use a remote backend for SD"
75
- )
76
- remote_sd.add_argument(
77
- "--sd-remote-host", type=str, help="Specify the host of the remote SD backend"
78
- )
79
- remote_sd.add_argument(
80
- "--sd-remote-port", type=int, help="Specify the port of the remote SD backend"
81
- )
82
- remote_sd.add_argument(
83
- "--sd-remote-ssl", action="store_true", help="Use SSL for the remote SD backend"
84
- )
85
- remote_sd.add_argument(
86
- "--sd-remote-auth",
87
- type=str,
88
- help="Specify the username:password for the remote SD backend (if required)",
89
- )
90
-
91
- parser.add_argument(
92
- "--enable-modules",
93
- action=SplitArgs,
94
- default=[],
95
- help="Override a list of enabled modules",
96
- )
97
-
98
- args = parser.parse_args()
99
 
100
  port = 7860
101
  host = "0.0.0.0"
@@ -111,31 +49,17 @@ classification_model = (
111
  else DEFAULT_CLASSIFICATION_MODEL
112
  )
113
 
114
- modules = (
115
- args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else []
116
- )
117
 
118
- if len(modules) == 0:
119
- print(
120
- f"{Fore.RED}{Style.BRIGHT}You did not select any modules to run! Choose them by adding an --enable-modules option"
121
- )
122
- print(f"Example: --enable-modules=caption,summarize{Style.RESET_ALL}")
123
 
124
- # Models init
125
- device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu"
126
- device = torch.device(device_string)
127
- torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
128
 
129
- if "summarize" in modules:
130
- print("Initializing a text summarization model...")
131
- summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
132
- summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
133
- summarization_model, torch_dtype=torch_dtype
134
- ).to(device)
135
 
136
- if "classify" in modules:
137
- print("Initializing a sentiment classification pipeline...")
138
- classification_pipe = pipeline(
139
  "text-classification",
140
  model=classification_model,
141
  top_k=None,
@@ -143,40 +67,34 @@ if "classify" in modules:
143
  torch_dtype=torch_dtype,
144
  )
145
 
146
- if "chromadb" in modules:
147
- print("Initializing ChromaDB")
148
- import chromadb
149
- import posthog
150
- from chromadb.config import Settings
151
- from sentence_transformers import SentenceTransformer
152
 
153
- # disable chromadb telemetry
154
- posthog.capture = lambda *args, **kwargs: None
155
- chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
156
- chromadb_embedder = SentenceTransformer(embedding_model)
157
- chromadb_embed_fn = chromadb_embedder.encode
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  # Flask init
161
  app = Flask(__name__)
162
  CORS(app) # allow cross-domain requests
163
  app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
164
 
 
 
 
165
 
166
- def require_module(name):
167
- def wrapper(fn):
168
- @wraps(fn)
169
- def decorated_view(*args, **kwargs):
170
- if name not in modules:
171
- abort(403, "Module is disabled by config")
172
- return fn(*args, **kwargs)
173
-
174
- return decorated_view
175
-
176
- return wrapper
177
-
178
 
179
- # AI stuff
180
  def classify_text(text: str) -> list:
181
  output = classification_pipe(
182
  text,
@@ -243,7 +161,6 @@ def after_request(response):
243
  response.headers["X-Request-Duration"] = str(duration)
244
  return response
245
 
246
-
247
  @app.route("/", methods=["GET"])
248
  def index():
249
  with open("./README.md", "r", encoding="utf8") as f:
@@ -251,42 +168,104 @@ def index():
251
  return render_template_string(markdown.markdown(content, extensions=["tables"]))
252
 
253
 
254
- @app.route("/api/extensions", methods=["GET"])
255
- def get_extensions():
256
- extensions = dict(
257
- {
258
- "extensions": [
259
- {
260
- "name": "not-supported",
261
- "metadata": {
262
- "display_name": """<span style="white-space:break-spaces;">Extensions serving using Extensions API is no longer supported. Please update the mod from: <a href="https://github.com/Cohee1207/SillyTavern">https://github.com/Cohee1207/SillyTavern</a></span>""",
263
- "requires": [],
264
- "assets": [],
265
- },
266
- }
267
- ]
268
- }
 
269
  )
270
- return jsonify(extensions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
 
273
- @app.route("/api/caption", methods=["POST"])
274
- @require_module("caption")
275
- def api_caption():
276
  data = request.get_json()
 
 
 
 
 
 
 
 
 
277
 
278
- if "image" not in data or not isinstance(data["image"], str):
279
- abort(400, '"image" is required')
 
 
 
280
 
281
- image = Image.open(BytesIO(base64.b64decode(data["image"])))
282
- image = image.convert("RGB")
283
- image.thumbnail((512, 512))
284
- caption = caption_image(image)
285
- thumbnail = image_to_base64(image)
286
- print("Caption:", caption, sep="\n")
287
- gc.collect()
288
- return jsonify({"caption": caption, "thumbnail": thumbnail})
289
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
291
  @app.route("/api/summarize", methods=["POST"])
292
  @require_module("summarize")
@@ -308,8 +287,8 @@ def api_summarize():
308
  return jsonify({"summary": summary})
309
 
310
 
 
311
  @app.route("/api/classify", methods=["POST"])
312
- @require_module("classify")
313
  def api_classify():
314
  data = request.get_json()
315
 
@@ -324,138 +303,10 @@ def api_classify():
324
 
325
 
326
  @app.route("/api/classify/labels", methods=["GET"])
327
- @require_module("classify")
328
  def api_classify_labels():
329
  classification = classify_text("")
330
  labels = [x["label"] for x in classification]
331
  return jsonify({"labels": labels})
332
 
333
 
334
- @app.route("/api/keywords", methods=["POST"])
335
- @require_module("keywords")
336
- def api_keywords():
337
- data = request.get_json()
338
-
339
- if "text" not in data or not isinstance(data["text"], str):
340
- abort(400, '"text" is required')
341
-
342
- print("Keywords input:", data["text"], sep="\n")
343
- keywords = extract_keywords(data["text"])
344
- print("Keywords output:", keywords, sep="\n")
345
- return jsonify({"keywords": keywords})
346
-
347
-
348
- @app.route("/api/prompt", methods=["POST"])
349
- @require_module("prompt")
350
- def api_prompt():
351
- data = request.get_json()
352
-
353
- if "text" not in data or not isinstance(data["text"], str):
354
- abort(400, '"text" is required')
355
-
356
- keywords = extract_keywords(data["text"])
357
-
358
- if "name" in data and isinstance(data["name"], str):
359
- keywords.insert(0, data["name"])
360
-
361
- print("Prompt input:", data["text"], sep="\n")
362
- prompts = generate_prompt(keywords)
363
- print("Prompt output:", prompts, sep="\n")
364
- return jsonify({"prompts": prompts})
365
-
366
- @app.route("/api/modules", methods=["GET"])
367
- def get_modules():
368
- return jsonify({"modules": modules})
369
-
370
- @app.route("/api/chromadb", methods=["POST"])
371
- @require_module("chromadb")
372
- def chromadb_add_messages():
373
- data = request.get_json()
374
- if "chat_id" not in data or not isinstance(data["chat_id"], str):
375
- abort(400, '"chat_id" is required')
376
- if "messages" not in data or not isinstance(data["messages"], list):
377
- abort(400, '"messages" is required')
378
-
379
- chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
380
- collection = chromadb_client.get_or_create_collection(
381
- name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
382
- )
383
-
384
- documents = [m["content"] for m in data["messages"]]
385
- ids = [m["id"] for m in data["messages"]]
386
- metadatas = [
387
- {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
388
- for m in data["messages"]
389
- ]
390
-
391
- collection.upsert(
392
- ids=ids,
393
- documents=documents,
394
- metadatas=metadatas,
395
- )
396
-
397
- return jsonify({"count": len(ids)})
398
-
399
-
400
- @app.route("/api/chromadb/purge", methods=["POST"])
401
- @require_module("chromadb")
402
- def chromadb_purge():
403
- data = request.get_json()
404
- if "chat_id" not in data or not isinstance(data["chat_id"], str):
405
- abort(400, '"chat_id" is required')
406
-
407
- chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
408
- collection = chromadb_client.get_or_create_collection(
409
- name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
410
- )
411
-
412
- deleted = collection.delete()
413
- print("ChromaDB embeddings deleted", len(deleted))
414
- return 'Ok', 200
415
-
416
-
417
- @app.route("/api/chromadb/query", methods=["POST"])
418
- @require_module("chromadb")
419
- def chromadb_query():
420
- data = request.get_json()
421
- if "chat_id" not in data or not isinstance(data["chat_id"], str):
422
- abort(400, '"chat_id" is required')
423
- if "query" not in data or not isinstance(data["query"], str):
424
- abort(400, '"query" is required')
425
-
426
- if "n_results" not in data or not isinstance(data["n_results"], int):
427
- n_results = 1
428
- else:
429
- n_results = data["n_results"]
430
-
431
- chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
432
- collection = chromadb_client.get_or_create_collection(
433
- name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
434
- )
435
-
436
- n_results = min(collection.count(), n_results)
437
- query_result = collection.query(
438
- query_texts=[data["query"]],
439
- n_results=n_results,
440
- )
441
-
442
- documents = query_result["documents"][0]
443
- ids = query_result["ids"][0]
444
- metadatas = query_result["metadatas"][0]
445
- distances = query_result["distances"][0]
446
-
447
- messages = [
448
- {
449
- "id": ids[i],
450
- "date": metadatas[i]["date"],
451
- "role": metadatas[i]["role"],
452
- "meta": metadatas[i]["meta"],
453
- "content": documents[i],
454
- "distance": distances[i],
455
- }
456
- for i in range(len(ids))
457
- ]
458
-
459
- return jsonify(messages)
460
-
461
- app.run(host=host, port=port)
 
 
1
  from flask import (
2
  Flask,
3
  jsonify,
4
  request,
5
  render_template_string,
6
  abort,
 
 
7
  )
8
  from flask_cors import CORS
 
 
 
 
 
9
  import unicodedata
10
+ import markdown
11
  import time
12
  import os
13
  import gc
 
14
  import base64
15
  from io import BytesIO
16
  from random import randint
 
17
  import hashlib
18
+ import chromadb
19
+ import posthog
20
+ from chromadb.config import Settings
21
+ from sentence_transformers import SentenceTransformer
22
+ from werkzeug.middleware.proxy_fix import ProxyFix
23
+ import argparse
24
+ from transformers import AutoTokenizer, AutoProcessor, pipeline
25
+ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
26
+ from transformers import BlipForConditionalGeneration, GPT2Tokenizer
27
+ from PIL import Image
28
+ import webuiapi
29
  from constants import *
30
  from colorama import Fore, Style, init as colorama_init
31
 
32
  colorama_init()
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  parser.add_argument(
35
  "--classification-model", help="Load a custom text classification model"
36
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  port = 7860
39
  host = "0.0.0.0"
 
49
  else DEFAULT_CLASSIFICATION_MODEL
50
  )
51
 
 
 
 
52
 
53
+ embedding_model = 'sentence-transformers/all-mpnet-base-v2'
 
 
 
 
54
 
55
+ print("Initializing a text summarization model...")
 
 
 
56
 
57
+ summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
58
+ summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
59
+ summarization_model, torch_dtype=torch_dtype).to(device)
 
 
 
60
 
61
+ print("Initializing a sentiment classification pipeline...")
62
+ classification_pipe = pipeline(
 
63
  "text-classification",
64
  model=classification_model,
65
  top_k=None,
 
67
  torch_dtype=torch_dtype,
68
  )
69
 
 
 
 
 
 
 
70
 
 
 
 
 
 
71
 
72
+ print("Initializing ChromaDB")
73
+
74
+ device_string = "cpu"
75
+ device = torch.device(device_string)
76
+ torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
77
+
78
+
79
+
80
+ # disable chromadb telemetry
81
+ posthog.capture = lambda *args, **kwargs: None
82
+ chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
83
+ chromadb_embedder = SentenceTransformer(embedding_model)
84
+ chromadb_embed_fn = chromadb_embedder.encode
85
 
86
  # Flask init
87
  app = Flask(__name__)
88
  CORS(app) # allow cross-domain requests
89
  app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
90
 
91
+ app.wsgi_app = ProxyFix(
92
+ app.wsgi_app, x_for=2, x_proto=1, x_host=1, x_prefix=1
93
+ )
94
 
95
+ def get_real_ip():
96
+ return request.remote_addr
 
 
 
 
 
 
 
 
 
 
97
 
 
98
  def classify_text(text: str) -> list:
99
  output = classification_pipe(
100
  text,
 
161
  response.headers["X-Request-Duration"] = str(duration)
162
  return response
163
 
 
164
  @app.route("/", methods=["GET"])
165
  def index():
166
  with open("./README.md", "r", encoding="utf8") as f:
 
168
  return render_template_string(markdown.markdown(content, extensions=["tables"]))
169
 
170
 
171
+ @app.route("/api/modules", methods=["GET"])
172
+ def get_modules():
173
+ return jsonify({"modules": ['chromadb']})
174
+
175
+ @app.route("/api/chromadb", methods=["POST"])
176
+ def chromadb_add_messages():
177
+ data = request.get_json()
178
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
179
+ abort(400, '"chat_id" is required')
180
+ if "messages" not in data or not isinstance(data["messages"], list):
181
+ abort(400, '"messages" is required')
182
+
183
+ ip = get_real_ip()
184
+ chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
185
+ collection = chromadb_client.get_or_create_collection(
186
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
187
  )
188
+
189
+ documents = [m["content"] for m in data["messages"]]
190
+ ids = [m["id"] for m in data["messages"]]
191
+ metadatas = [
192
+ {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
193
+ for m in data["messages"]
194
+ ]
195
+
196
+ if len(ids) > 0:
197
+ collection.upsert(
198
+ ids=ids,
199
+ documents=documents,
200
+ metadatas=metadatas,
201
+ )
202
+
203
+ return jsonify({"count": len(ids)})
204
 
205
 
206
+ @app.route("/api/chromadb/query", methods=["POST"])
207
+ def chromadb_query():
 
208
  data = request.get_json()
209
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
210
+ abort(400, '"chat_id" is required')
211
+ if "query" not in data or not isinstance(data["query"], str):
212
+ abort(400, '"query" is required')
213
+
214
+ if "n_results" not in data or not isinstance(data["n_results"], int):
215
+ n_results = 1
216
+ else:
217
+ n_results = data["n_results"]
218
 
219
+ ip = get_real_ip()
220
+ chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
221
+ collection = chromadb_client.get_or_create_collection(
222
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
223
+ )
224
 
225
+ n_results = min(collection.count(), n_results)
 
 
 
 
 
 
 
226
 
227
+ messages = []
228
+ if n_results > 0:
229
+ query_result = collection.query(
230
+ query_texts=[data["query"]],
231
+ n_results=n_results,
232
+ )
233
+
234
+ documents = query_result["documents"][0]
235
+ ids = query_result["ids"][0]
236
+ metadatas = query_result["metadatas"][0]
237
+ distances = query_result["distances"][0]
238
+
239
+ messages = [
240
+ {
241
+ "id": ids[i],
242
+ "date": metadatas[i]["date"],
243
+ "role": metadatas[i]["role"],
244
+ "meta": metadatas[i]["meta"],
245
+ "content": documents[i],
246
+ "distance": distances[i],
247
+ }
248
+ for i in range(len(ids))
249
+ ]
250
+
251
+ return jsonify(messages)
252
+
253
+ @app.route("/api/chromadb/purge", methods=["POST"])
254
+ def chromadb_purge():
255
+ data = request.get_json()
256
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
257
+ abort(400, '"chat_id" is required')
258
+
259
+ ip = get_real_ip()
260
+ chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
261
+ collection = chromadb_client.get_or_create_collection(
262
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
263
+ )
264
+
265
+ deleted = collection.delete()
266
+ print("ChromaDB embeddings deleted", len(deleted))
267
+
268
+ return 'Ok', 200
269
 
270
  @app.route("/api/summarize", methods=["POST"])
271
  @require_module("summarize")
 
287
  return jsonify({"summary": summary})
288
 
289
 
290
+
291
  @app.route("/api/classify", methods=["POST"])
 
292
  def api_classify():
293
  data = request.get_json()
294
 
 
303
 
304
 
305
  @app.route("/api/classify/labels", methods=["GET"])
 
306
  def api_classify_labels():
307
  classification = classify_text("")
308
  labels = [x["label"] for x in classification]
309
  return jsonify({"labels": labels})
310
 
311
 
312
+ app.run(host=host, port=port)