Kannon committed
Commit bae8bb5 · 1 Parent(s): df9340c

Update server.py

Files changed (1)
  1. server.py +343 -357
server.py CHANGED
@@ -1,378 +1,364 @@
- from functools import wraps
- from flask import (
-     Flask,
-     jsonify,
-     request,
-     render_template_string,
-     abort,
-     send_from_directory,
-     send_file,
  )
- from flask_cors import CORS
- import unicodedata
- import markdown
- import time
- import os
- import gc
- import base64
- from io import BytesIO
- from random import randint
- import hashlib
- import chromadb
- import posthog
- import torch
- from chromadb.config import Settings
- from sentence_transformers import SentenceTransformer
- from werkzeug.middleware.proxy_fix import ProxyFix
- from transformers import AutoTokenizer, AutoProcessor, pipeline
- from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
- from transformers import BlipForConditionalGeneration, GPT2Tokenizer
- from PIL import Image
- import webuiapi
- from colorama import Fore, Style, init as colorama_init

-
-
-
- colorama_init()
-
- port = 7860
- host = "0.0.0.0"
-
-
-
- args = parser.parse_args()
-
-
- summarization_model = (
-     "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
- )
- classification_model = (
-     "joeddav/distilbert-base-uncased-go-emotions-student"
  )
-
- captioning_model = (
-     "Salesforce/blip-image-captioning-large"
  )

- print("Initializing an image captioning model...")
59
- captioning_processor = AutoProcessor.from_pretrained(captioning_model)
60
- if "blip" in captioning_model:
61
- captioning_transformer = BlipForConditionalGeneration.from_pretrained(
62
- captioning_model, torch_dtype=torch_dtype
63
- ).to(device)
64
- else:
65
- captioning_transformer = AutoModelForCausalLM.from_pretrained(
66
- captioning_model, torch_dtype=torch_dtype
67
- ).to(device)
68
-
69
-
70
-
71
- device_string = "cpu"
72
- device = torch.device(device_string)
73
- torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
74
-
75
- embedding_model = 'sentence-transformers/all-mpnet-base-v2'
76
 
77
- print("Initializing a text summarization model...")
 
 
78
 
79
- summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
80
- summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
81
- summarization_model, torch_dtype=torch_dtype).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- print("Initializing a sentiment classification pipeline...")
84
- classification_pipe = pipeline(
85
- "text-classification",
86
- model=classification_model,
87
- top_k=None,
88
- device=device,
89
- torch_dtype=torch_dtype,
 
 
 
 
 
90
  )
91
 
 
92
 
93
 
94
- print("Initializing ChromaDB")
 
 
 
 
 
 
95
 
96
- # disable chromadb telemetry
97
- posthog.capture = lambda *args, **kwargs: None
98
- chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
99
- chromadb_embedder = SentenceTransformer(embedding_model)
100
- chromadb_embed_fn = chromadb_embedder.encode
101
 
102
- # Flask init
103
- app = Flask(__name__)
104
- CORS(app) # allow cross-domain requests
105
- app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
106
-
107
- app.wsgi_app = ProxyFix(
108
- app.wsgi_app, x_for=2, x_proto=1, x_host=1, x_prefix=1
109
  )
110
 
111
- def get_real_ip():
112
- return request.remote_addr
113
-
114
- def classify_text(text: str) -> list:
115
- output = classification_pipe(
116
- text,
117
- truncation=True,
118
- max_length=classification_pipe.model.config.max_position_embeddings,
119
- )[0]
120
- return sorted(output, key=lambda x: x["score"], reverse=True)
121
-
122
- def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
123
- inputs = captioning_processor(raw_image.convert("RGB"), return_tensors="pt").to(
124
- device, torch_dtype
125
- )
126
- outputs = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
127
- caption = captioning_processor.decode(outputs[0], skip_special_tokens=True)
128
- return caption
129
-
130
-
131
-
132
- def summarize_chunks(text: str, params: dict) -> str:
133
- try:
134
- return summarize(text, params)
135
- except IndexError:
136
- print(
137
- "Sequence length too large for model, cutting text in half and calling again"
138
- )
139
- new_params = params.copy()
140
- new_params["max_length"] = new_params["max_length"] // 2
141
- new_params["min_length"] = new_params["min_length"] // 2
142
- return summarize_chunks(
143
- text[: (len(text) // 2)], new_params
144
- ) + summarize_chunks(text[(len(text) // 2) :], new_params)
145
-
146
-
147
- def summarize(text: str, params: dict) -> str:
148
- # Tokenize input
149
- inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
150
- token_count = len(inputs[0])
151
-
152
- bad_words_ids = [
153
- summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
154
- for bad_word in params["bad_words"]
155
- ]
156
- summary_ids = summarization_transformer.generate(
157
- inputs["input_ids"],
158
- num_beams=2,
159
- max_new_tokens=max(token_count, int(params["max_length"])),
160
- min_new_tokens=min(token_count, int(params["min_length"])),
161
- repetition_penalty=float(params["repetition_penalty"]),
162
- temperature=float(params["temperature"]),
163
- length_penalty=float(params["length_penalty"]),
164
- bad_words_ids=bad_words_ids,
165
- )
166
- summary = summarization_tokenizer.batch_decode(
167
- summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
168
- )[0]
169
- summary = normalize_string(summary)
170
- return summary
171
-
172
-
173
- def normalize_string(input: str) -> str:
174
- output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
175
- return output
176
 
177
- @app.before_request
- # Request time measuring
- def before_request():
-     request.start_time = time.time()
-
-
- @app.after_request
- def after_request(response):
-     duration = time.time() - request.start_time
-     response.headers["X-Request-Duration"] = str(duration)
-     return response
-
- @app.route("/", methods=["GET"])
- def index():
-     with open("./README.md", "r", encoding="utf8") as f:
-         content = f.read()
-     return render_template_string(markdown.markdown(content, extensions=["tables"]))
-
-
- @app.route("/api/modules", methods=["GET"])
- def get_modules():
-     return jsonify({"modules": ['chromadb','summarize','classify']})
-
- @app.route("/api/chromadb", methods=["POST"])
- def chromadb_add_messages():
-     data = request.get_json()
-     if "chat_id" not in data or not isinstance(data["chat_id"], str):
-         abort(400, '"chat_id" is required')
-     if "messages" not in data or not isinstance(data["messages"], list):
-         abort(400, '"messages" is required')
-
-     ip = get_real_ip()
-     chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
-     collection = chromadb_client.get_or_create_collection(
-         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
      )
-
-     documents = [m["content"] for m in data["messages"]]
-     ids = [m["id"] for m in data["messages"]]
-     metadatas = [
-         {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
-         for m in data["messages"]
      ]

-     if len(ids) > 0:
-         collection.upsert(
-             ids=ids,
-             documents=documents,
-             metadatas=metadatas,
-         )
-
-     return jsonify({"count": len(ids)})
-
-
- @app.route("/api/chromadb/query", methods=["POST"])
- def chromadb_query():
-     data = request.get_json()
-     if "chat_id" not in data or not isinstance(data["chat_id"], str):
-         abort(400, '"chat_id" is required')
-     if "query" not in data or not isinstance(data["query"], str):
-         abort(400, '"query" is required')
-
-     if "n_results" not in data or not isinstance(data["n_results"], int):
-         n_results = 1
-     else:
-         n_results = data["n_results"]
-
-     ip = get_real_ip()
-     chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
-     collection = chromadb_client.get_or_create_collection(
-         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
-     )
-
-     n_results = min(collection.count(), n_results)
-
-     messages = []
-     if n_results > 0:
-         query_result = collection.query(
-             query_texts=[data["query"]],
-             n_results=n_results,
-         )
-
-         documents = query_result["documents"][0]
-         ids = query_result["ids"][0]
-         metadatas = query_result["metadatas"][0]
-         distances = query_result["distances"][0]
-
-         messages = [
-             {
-                 "id": ids[i],
-                 "date": metadatas[i]["date"],
-                 "role": metadatas[i]["role"],
-                 "meta": metadatas[i]["meta"],
-                 "content": documents[i],
-                 "distance": distances[i],
-             }
-             for i in range(len(ids))
-         ]
-
-     return jsonify(messages)
-
- @app.route("/api/chromadb/purge", methods=["POST"])
- def chromadb_purge():
-     data = request.get_json()
-     if "chat_id" not in data or not isinstance(data["chat_id"], str):
-         abort(400, '"chat_id" is required')
-
-     ip = get_real_ip()
-     chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
-     collection = chromadb_client.get_or_create_collection(
-         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
-     )

-     deleted = collection.delete()
-     print("ChromaDB embeddings deleted", len(deleted))
-
-     return 'Ok', 200
-
- @app.route("/api/caption", methods=["POST"])
- def api_caption():
-     data = request.get_json()
-
-     if "image" not in data or not isinstance(data["image"], str):
-         abort(400, '"image" is required')
-
-     image = Image.open(BytesIO(base64.b64decode(data["image"])))
-     image = image.convert("RGB")
-     image.thumbnail((512, 512))
-     caption = caption_image(image)
-     thumbnail = image_to_base64(image)
-     print("Caption:", caption, sep="\n")
-     gc.collect()
-     return jsonify({"caption": caption, "thumbnail": thumbnail})
-
-
- @app.route("/api/summarize", methods=["POST"])
- def api_summarize():
-     data = request.get_json()
-
-     if "text" not in data or not isinstance(data["text"], str):
-         abort(400, '"text" is required')
-
-     params = {
-         "temperature": 1.0,
-         "repetition_penalty": 1.0,
-         "max_length": 500,
-         "min_length": 200,
-         "length_penalty": 1.5,
-         "bad_words": [
-             "\n",
-             '"',
-             "*",
-             "[",
-             "]",
-             "{",
-             "}",
-             ":",
-             "(",
-             ")",
-             "<",
-             ">",
-             "Â",
-             "The text ends",
-             "The story ends",
-             "The text is",
-             "The story is",
-         ],
-     }
-
-     if "params" in data and isinstance(data["params"], dict):
-         params.update(data["params"])
-
-     print("Summary input:", data["text"], sep="\n")
-     summary = summarize_chunks(data["text"], params)
-     print("Summary output:", summary, sep="\n")
-     gc.collect()
-     return jsonify({"summary": summary})
-
-
-
- @app.route("/api/classify", methods=["POST"])
- def api_classify():
-     data = request.get_json()
-
-     if "text" not in data or not isinstance(data["text"], str):
-         abort(400, '"text" is required')
-
-     print("Classification input:", data["text"], sep="\n")
-     classification = classify_text(data["text"])
-     print("Classification output:", classification, sep="\n")
-     gc.collect()
-     return jsonify({"classification": classification})
-
-
- @app.route("/api/classify/labels", methods=["GET"])
- def api_classify_labels():
-     classification = classify_text("")
-     labels = [x["label"] for x in classification]
-     return jsonify({"labels": labels})
-
-
- app.run(host=host, port=port)
+ from functools import wraps
+ from flask import (
+     Flask,
+     jsonify,
+     request,
+     render_template_string,
+     abort,
+     send_from_directory,
+     send_file,
+ )
+ from flask_cors import CORS
+ import unicodedata
+ import markdown
+ import time
+ import os
+ import gc
+ import base64
+ from io import BytesIO
+ from random import randint
+ import hashlib
+ import chromadb
+ import posthog
+ import torch
+ from chromadb.config import Settings
+ from sentence_transformers import SentenceTransformer
+ from werkzeug.middleware.proxy_fix import ProxyFix
+ from transformers import AutoTokenizer, AutoProcessor, pipeline
+ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
+ from transformers import BlipForConditionalGeneration, GPT2Tokenizer
+ from PIL import Image
+ import webuiapi
+ from colorama import Fore, Style, init as colorama_init
+
+
+ colorama_init()
+
+ port = 7860
+ host = "0.0.0.0"
+
+
+ summarization_model = (
+     "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
+ )
+ classification_model = (
+     "joeddav/distilbert-base-uncased-go-emotions-student"
+ )
+ captioning_model = (
+     "Salesforce/blip-image-captioning-large"
+ )
+
+ device_string = "cpu"
+ device = torch.device(device_string)
+ torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
+
+ embedding_model = 'sentence-transformers/all-mpnet-base-v2'
+
+ print("Initializing a text summarization model...")
+
+ summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
+ summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
+     summarization_model, torch_dtype=torch_dtype).to(device)
+
+ print("Initializing a sentiment classification pipeline...")
+ classification_pipe = pipeline(
+     "text-classification",
+     model=classification_model,
+     top_k=None,
+     device=device,
+     torch_dtype=torch_dtype,
  )

+ print("Initializing an image captioning model...")
73
+ captioning_processor = AutoProcessor.from_pretrained(captioning_model)
74
+ if "blip" in captioning_model:
75
+ captioning_transformer = BlipForConditionalGeneration.from_pretrained(
76
+ captioning_model, torch_dtype=torch_dtype
77
+ ).to(device)
78
+ else:
79
+ captioning_transformer = AutoModelForCausalLM.from_pretrained(
80
+ captioning_model, torch_dtype=torch_dtype
81
+ ).to(device)
82
+
83
+
84
+ print("Initializing ChromaDB")
85
+
86
+ # disable chromadb telemetry
87
+ posthog.capture = lambda *args, **kwargs: None
88
+ chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
89
+ chromadb_embedder = SentenceTransformer(embedding_model)
90
+ chromadb_embed_fn = chromadb_embedder.encode
91
+
92
+ # Flask init
93
+ app = Flask(__name__)
94
+ CORS(app) # allow cross-domain requests
95
+ app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
96
+
97
+ app.wsgi_app = ProxyFix(
98
+ app.wsgi_app, x_for=2, x_proto=1, x_host=1, x_prefix=1
99
+ )
100
+
101
+ def get_real_ip():
102
+ return request.remote_addr
103
+
104
+ def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
105
+ inputs = captioning_processor(raw_image.convert("RGB"), return_tensors="pt").to(
106
+ device, torch_dtype
107
  )
108
+ outputs = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
109
+ caption = captioning_processor.decode(outputs[0], skip_special_tokens=True)
110
+ return caption
111
+
112
+
113
+ def classify_text(text: str) -> list:
114
+ output = classification_pipe(
115
+ text,
116
+ truncation=True,
117
+ max_length=classification_pipe.model.config.max_position_embeddings,
118
+ )[0]
119
+ return sorted(output, key=lambda x: x["score"], reverse=True)
120
+
121
+
122
+ def summarize_chunks(text: str, params: dict) -> str:
123
+ try:
124
+ return summarize(text, params)
125
+ except IndexError:
126
+ print(
127
+ "Sequence length too large for model, cutting text in half and calling again"
128
+ )
129
+ new_params = params.copy()
130
+ new_params["max_length"] = new_params["max_length"] // 2
131
+ new_params["min_length"] = new_params["min_length"] // 2
132
+ return summarize_chunks(
133
+ text[: (len(text) // 2)], new_params
134
+ ) + summarize_chunks(text[(len(text) // 2) :], new_params)
135
+
136
+
137
+ def summarize(text: str, params: dict) -> str:
138
+ # Tokenize input
139
+ inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
140
+ token_count = len(inputs[0])
141
+
142
+ bad_words_ids = [
143
+ summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
144
+ for bad_word in params["bad_words"]
145
+ ]
146
+ summary_ids = summarization_transformer.generate(
147
+ inputs["input_ids"],
148
+ num_beams=2,
149
+ max_new_tokens=max(token_count, int(params["max_length"])),
150
+ min_new_tokens=min(token_count, int(params["min_length"])),
151
+ repetition_penalty=float(params["repetition_penalty"]),
152
+ temperature=float(params["temperature"]),
153
+ length_penalty=float(params["length_penalty"]),
154
+ bad_words_ids=bad_words_ids,
155
  )
156
+ summary = summarization_tokenizer.batch_decode(
157
+ summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
158
+ )[0]
159
+ summary = normalize_string(summary)
160
+ return summary
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
+ def normalize_string(input: str) -> str:
164
+ output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
165
+ return output
166
 
167
+ @app.before_request
+ def before_request():
+     # Request time measuring
+     request.start_time = time.time()
+
+     # Checks if an API key is present and valid, otherwise return unauthorized
+     # The options check is required so CORS doesn't get angry
+     try:
+         if request.method != 'OPTIONS' and getattr(request.authorization, 'token', '') != os.environ['sekrit_password']:
+             print(f"WARNING: Unauthorized API key access from {request.remote_addr}")
+             response = jsonify({ 'error': '401: Invalid API key' })
+             response.status_code = 401
+             return response
+     except Exception as e:
+         print(f"API key check error: {e}")
+         return "401 Unauthorized\n{}\n\n".format(e), 401
+
+
+
+ @app.after_request
+ def after_request(response):
+     duration = time.time() - request.start_time
+     response.headers["X-Request-Duration"] = str(duration)
+     return response
+
+ @app.route("/", methods=["GET"])
+ def index():
+     with open("./README.md", "r", encoding="utf8") as f:
+         content = f.read()
+     return render_template_string(markdown.markdown(content, extensions=["tables"]))
+
+
+ @app.route("/api/modules", methods=["GET"])
+ def get_modules():
+     return jsonify({"modules": ['chromadb','summarize','classify']})
+
+ @app.route("/api/chromadb", methods=["POST"])
+ def chromadb_add_messages():
+     data = request.get_json()
+     if "chat_id" not in data or not isinstance(data["chat_id"], str):
+         abort(400, '"chat_id" is required')
+     if "messages" not in data or not isinstance(data["messages"], list):
+         abort(400, '"messages" is required')
+
+     ip = get_real_ip()
+     chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
+     collection = chromadb_client.get_or_create_collection(
+         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
+     )

+     documents = [m["content"] for m in data["messages"]]
+     ids = [m["id"] for m in data["messages"]]
+     metadatas = [
+         {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
+         for m in data["messages"]
+     ]
+
+     if len(ids) > 0:
+         collection.upsert(
+             ids=ids,
+             documents=documents,
+             metadatas=metadatas,
          )

+     return jsonify({"count": len(ids)})


+ @app.route("/api/chromadb/query", methods=["POST"])
+ def chromadb_query():
+     data = request.get_json()
+     if "chat_id" not in data or not isinstance(data["chat_id"], str):
+         abort(400, '"chat_id" is required')
+     if "query" not in data or not isinstance(data["query"], str):
+         abort(400, '"query" is required')

+     if "n_results" not in data or not isinstance(data["n_results"], int):
+         n_results = 1
+     else:
+         n_results = data["n_results"]

+     ip = get_real_ip()
+     chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
+     collection = chromadb_client.get_or_create_collection(
+         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
      )

+     n_results = min(collection.count(), n_results)

+     messages = []
+     if n_results > 0:
+         query_result = collection.query(
+             query_texts=[data["query"]],
+             n_results=n_results,
          )
+
+         documents = query_result["documents"][0]
+         ids = query_result["ids"][0]
+         metadatas = query_result["metadatas"][0]
+         distances = query_result["distances"][0]
+
+         messages = [
+             {
+                 "id": ids[i],
+                 "date": metadatas[i]["date"],
+                 "role": metadatas[i]["role"],
+                 "meta": metadatas[i]["meta"],
+                 "content": documents[i],
+                 "distance": distances[i],
+             }
+             for i in range(len(ids))
          ]

+     return jsonify(messages)
+
+ @app.route("/api/chromadb/purge", methods=["POST"])
+ def chromadb_purge():
+     data = request.get_json()
+     if "chat_id" not in data or not isinstance(data["chat_id"], str):
+         abort(400, '"chat_id" is required')
+
+     ip = get_real_ip()
+     chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
+     collection = chromadb_client.get_or_create_collection(
+         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
+     )

+     deleted = collection.delete()
+     print("ChromaDB embeddings deleted", len(deleted))
+
+     return 'Ok', 200
+
+ @app.route("/api/summarize", methods=["POST"])
+ def api_summarize():
+     data = request.get_json()
+
+     if "text" not in data or not isinstance(data["text"], str):
+         abort(400, '"text" is required')
+
+     params = {
+         "temperature": 1.0,
+         "repetition_penalty": 1.0,
+         "max_length": 500,
+         "min_length": 200,
+         "length_penalty": 1.5,
+         "bad_words": [
+             "\n",
+             '"',
+             "*",
+             "[",
+             "]",
+             "{",
+             "}",
+             ":",
+             "(",
+             ")",
+             "<",
+             ">",
+             "Â",
+             "The text ends",
+             "The story ends",
+             "The text is",
+             "The story is",
+         ],
+     }
+
+     if "params" in data and isinstance(data["params"], dict):
+         params.update(data["params"])
+
+     print("Summary input:", data["text"], sep="\n")
+     summary = summarize_chunks(data["text"], params)
+     print("Summary output:", summary, sep="\n")
+     gc.collect()
+     return jsonify({"summary": summary})
+
+
+
+ @app.route("/api/classify", methods=["POST"])
+ def api_classify():
+     data = request.get_json()
+
+     if "text" not in data or not isinstance(data["text"], str):
+         abort(400, '"text" is required')
+
+     print("Classification input:", data["text"], sep="\n")
+     classification = classify_text(data["text"])
+     print("Classification output:", classification, sep="\n")
+     gc.collect()
+     return jsonify({"classification": classification})
+
+
+ @app.route("/api/classify/labels", methods=["GET"])
+ def api_classify_labels():
+     classification = classify_text("")
+     labels = [x["label"] for x in classification]
+     return jsonify({"labels": labels})
+
+
+ app.run(host=host, port=port)
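
The functional change in this commit is the API-key check added to before_request: every non-OPTIONS request must carry a token matching the sekrit_password environment variable, read via Flask's request.authorization.token (typically sent as an Authorization: Bearer header). A minimal client sketch under those assumptions follows; the URL, key value, and request payload are illustrative and not part of the commit.

# Hypothetical client for the updated server.py (not part of the commit).
# Assumes the server runs on the default host/port from server.py and was
# started with the sekrit_password environment variable set.
import requests

API_URL = "http://localhost:7860"   # assumption: host = "0.0.0.0", port = 7860
API_KEY = "my-secret-key"           # assumption: same value as sekrit_password on the server

# The server reads request.authorization.token, so send the key as a Bearer token.
headers = {"Authorization": f"Bearer {API_KEY}"}

# Example call to /api/chromadb/query, using the JSON fields that chromadb_query expects.
response = requests.post(
    f"{API_URL}/api/chromadb/query",
    json={"chat_id": "example-chat", "query": "what did we discuss?", "n_results": 3},
    headers=headers,
)
response.raise_for_status()
print(response.json())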