doctord98 committed on
Commit c818fcf · 1 Parent(s): 5ed778c

Upload 6 files

Files changed (6)
  1. Dockerfile +20 -0
  2. README.md +10 -0
  3. constants.py +50 -0
  4. requirements.txt +18 -0
  5. server.py +844 -0
  6. tts_edge.py +34 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
+ FROM python:3.10
+
+ WORKDIR /app
+
+ COPY requirements.txt .
+ RUN pip install -r requirements.txt
+
+ RUN mkdir /.cache && chmod -R 777 /.cache
+ RUN mkdir .chroma && chmod -R 777 .chroma
+
+ COPY . .
+
+ RUN chmod -R 777 /app
+
+ RUN --mount=type=secret,id=password,mode=0444,required=true \
+     cat /run/secrets/password > /test
+
+ EXPOSE 7860
+
+ CMD ["python", "server.py", "--cpu", "--enable-modules=caption,summarize,classify,silero-tts,edge-tts,chromadb"]
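Note: the RUN --mount=type=secret step requires BuildKit and a build-time secret named "password"; on Hugging Face Spaces, repository secrets are exposed to the Docker build this way and as environment variables at runtime, which is how server.py picks up its API key. For a local build, something along the lines of DOCKER_BUILDKIT=1 docker build --secret id=password,src=password.txt . should be the equivalent invocation (the secret file path here is a placeholder).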
README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ title: smut
+ emoji: 🚀
+ colorFrom: blue
+ colorTo: green
+ sdk: docker
+ pinned: false
+ license: mit
+ ---
+ doctord98 is your lord and savior
constants.py ADDED
@@ -0,0 +1,50 @@
+ # Constants
+ # Also try: 'slauw87/bart-large-cnn-samsum'
+ DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ElectrifAi_v14"
+ # Also try: 'nateraw/bert-base-uncased-emotion'
+ DEFAULT_CLASSIFICATION_MODEL = "joeddav/distilbert-base-uncased-go-emotions-student"
+ # Also try: 'Salesforce/blip-image-captioning-base'
+ DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
+ # Also try: 'ckpt/anything-v4.5-vae-swapped'
+ DEFAULT_SD_MODEL = "sinkinai/MeinaHentai-v3-baked-vae"
+ DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
+ DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
+ DEFAULT_REMOTE_SD_PORT = 7860
+ DEFAULT_CHROMA_PORT = 8000
+ SILERO_SAMPLES_PATH = "tts_samples"
+ SILERO_SAMPLE_TEXT = "Doctor D 98 is your lord and savior"
+ # ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
+ DEFAULT_SUMMARIZE_PARAMS = {
+     "temperature": 1.0,
+     "repetition_penalty": 1.0,
+     "max_length": 500,
+     "min_length": 200,
+     "length_penalty": 1.5,
+     "bad_words": [
+         "\n",
+         '"',
+         "*",
+         "[",
+         "]",
+         "{",
+         "}",
+         ":",
+         "(",
+         ")",
+         "<",
+         ">",
+         "Â",
+         "The text ends",
+         "The story ends",
+         "The text is",
+         "The story is",
+     ],
+ }
+
+ PROMPT_PREFIX = "best quality, absurdres, "
+ NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
+ error hands, bad hands, error fingers, bad fingers, missing fingers,
+ error legs, bad legs, multiple legs, missing legs, error lighting,
+ error shadow, error reflection, text, error, extra digit, fewer digits,
+ cropped, worst quality, low quality, normal quality, jpeg artifacts,
+ signature, watermark, username, blurry"""
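For reference, server.py converts the bad_words list above into token-id sequences before calling generate(). A minimal sketch of that conversion (the tokenizer matches DEFAULT_SUMMARIZATION_MODEL; the shortened word list is illustrative):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qiliang/bart-large-cnn-samsum-ElectrifAi_v14")
    bad_words = ["\n", '"', "*"]  # shortened from DEFAULT_SUMMARIZE_PARAMS["bad_words"]
    # Each banned string becomes a list of token ids, with no BOS/EOS added.
    bad_words_ids = [
        tokenizer(word, add_special_tokens=False).input_ids for word in bad_words
    ]
    print(bad_words_ids)  # later passed as generate(bad_words_ids=...) in server.py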
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ flask
+ flask-cors
+ flask-compress
+ markdown
+ Pillow
+ colorama
+ webuiapi
+ --extra-index-url https://download.pytorch.org/whl/cu117
+ torch==2.0.0+cu117
+ torchvision==0.15.1
+ torchaudio==2.0.1+cu117
+ accelerate
+ transformers==4.28.1
+ diffusers==0.16.1
+ silero-api-server
+ chromadb
+ sentence_transformers
+ edge-tts
server.py ADDED
@@ -0,0 +1,844 @@
+ from functools import wraps
+ from flask import (
+     Flask,
+     jsonify,
+     request,
+     Response,
+     render_template_string,
+     abort,
+     send_from_directory,
+     send_file,
+ )
+ from flask_cors import CORS
+ from flask_compress import Compress
+ import markdown
+ import argparse
+ from transformers import AutoTokenizer, AutoProcessor, pipeline
+ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
+ from transformers import BlipForConditionalGeneration
+ import unicodedata
+ import torch
+ import time
+ import os
+ import gc
+ import secrets
+ from PIL import Image
+ import base64
+ from io import BytesIO
+ from random import randint
+ import webuiapi
+ import hashlib
+ from constants import *
+ from colorama import Fore, Style, init as colorama_init
+
+ colorama_init()
+
+
+ class SplitArgs(argparse.Action):
+     def __call__(self, parser, namespace, values, option_string=None):
+         setattr(
+             namespace, self.dest, values.replace('"', "").replace("'", "").split(",")
+         )
+
+
+ # Script arguments
+ parser = argparse.ArgumentParser(
+     prog="SillyTavern Extras", description="Web API for transformers models"
+ )
+ parser.add_argument(
+     "--port", type=int, help="Specify the port on which the application is hosted"
+ )
+ parser.add_argument(
+     "--listen", action="store_true", help="Host the app on the local network"
+ )
+ parser.add_argument(
+     "--share", action="store_true", help="Share the app on CloudFlare tunnel"
+ )
+ parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
+ parser.add_argument("--cuda", action="store_false", dest="cpu", help="Run the models on the GPU")
+ parser.set_defaults(cpu=True)
+ parser.add_argument("--summarization-model", help="Load a custom summarization model")
+ parser.add_argument(
+     "--classification-model", help="Load a custom text classification model"
+ )
+ parser.add_argument("--captioning-model", help="Load a custom captioning model")
+ parser.add_argument("--embedding-model", help="Load a custom text embedding model")
+ parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
+ parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
+ parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db')
+ parser.add_argument('--chroma-persist', help="Chromadb persistence", default=True, action=argparse.BooleanOptionalAction)
+ parser.add_argument(
+     "--secure", action="store_true", help="Enforces the use of an API key"
+ )
+
+ sd_group = parser.add_mutually_exclusive_group()
+
+ local_sd = sd_group.add_argument_group("sd-local")
+ local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")
+ local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU", action="store_true")
+
+ remote_sd = sd_group.add_argument_group("sd-remote")
+ remote_sd.add_argument(
+     "--sd-remote", action="store_true", help="Use a remote backend for SD"
+ )
+ remote_sd.add_argument(
+     "--sd-remote-host", type=str, help="Specify the host of the remote SD backend"
+ )
+ remote_sd.add_argument(
+     "--sd-remote-port", type=int, help="Specify the port of the remote SD backend"
+ )
+ remote_sd.add_argument(
+     "--sd-remote-ssl", action="store_true", help="Use SSL for the remote SD backend"
+ )
+ remote_sd.add_argument(
+     "--sd-remote-auth",
+     type=str,
+     help="Specify the username:password for the remote SD backend (if required)",
+ )
+
+ parser.add_argument(
+     "--enable-modules",
+     action=SplitArgs,
+     default=[],
+     help="Override a list of enabled modules",
+ )
+
+ args = parser.parse_args()
+
+ port = 7860
+ host = "0.0.0.0"
+ summarization_model = (
+     args.summarization_model
+     if args.summarization_model
+     else DEFAULT_SUMMARIZATION_MODEL
+ )
+ classification_model = (
+     args.classification_model
+     if args.classification_model
+     else DEFAULT_CLASSIFICATION_MODEL
+ )
+ captioning_model = (
+     args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
+ )
+ embedding_model = (
+     args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
+ )
+
+ sd_use_remote = False if args.sd_model else True
+ sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
+ sd_remote_host = args.sd_remote_host if args.sd_remote_host else DEFAULT_REMOTE_SD_HOST
+ sd_remote_port = args.sd_remote_port if args.sd_remote_port else DEFAULT_REMOTE_SD_PORT
+ sd_remote_ssl = args.sd_remote_ssl
+ sd_remote_auth = args.sd_remote_auth
+
+ modules = (
+     args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else []
+ )
+
+ if len(modules) == 0:
+     print(
+         f"{Fore.RED}{Style.BRIGHT}You did not select any modules to run! Choose them by adding an --enable-modules option"
+     )
+     print(f"Example: --enable-modules=caption,summarize{Style.RESET_ALL}")
+
+ # Models init
+ device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu"
+ device = torch.device(device_string)
+ torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
+
+ if not torch.cuda.is_available() and not args.cpu:
+     print(f"{Fore.YELLOW}{Style.BRIGHT}torch-cuda is not supported on this device. Defaulting to CPU mode.{Style.RESET_ALL}")
+
+ print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
+
+ if "caption" in modules:
+     print("Initializing an image captioning model...")
+     captioning_processor = AutoProcessor.from_pretrained(captioning_model)
+     if "blip" in captioning_model:
+         captioning_transformer = BlipForConditionalGeneration.from_pretrained(
+             captioning_model, torch_dtype=torch_dtype
+         ).to(device)
+     else:
+         captioning_transformer = AutoModelForCausalLM.from_pretrained(
+             captioning_model, torch_dtype=torch_dtype
+         ).to(device)
+
+ if "summarize" in modules:
+     print("Initializing a text summarization model...")
+     summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
+     summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
+         summarization_model, torch_dtype=torch_dtype
+     ).to(device)
+
+ if "classify" in modules:
+     print("Initializing a sentiment classification pipeline...")
+     classification_pipe = pipeline(
+         "text-classification",
+         model=classification_model,
+         top_k=None,
+         device=device,
+         torch_dtype=torch_dtype,
+     )
+
+ if "sd" in modules and not sd_use_remote:
+     from diffusers import StableDiffusionPipeline
+     from diffusers import EulerAncestralDiscreteScheduler
+
+     print("Initializing Stable Diffusion pipeline")
+     sd_device_string = (
+         "cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
+     )
+     sd_device = torch.device(sd_device_string)
+     sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
+     sd_pipe = StableDiffusionPipeline.from_pretrained(
+         sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
+     ).to(sd_device)
+     sd_pipe.safety_checker = lambda images, clip_input: (images, False)
+     sd_pipe.enable_attention_slicing()
+     # pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
+     sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
+         sd_pipe.scheduler.config
+     )
+ elif "sd" in modules and sd_use_remote:
+     print("Initializing Stable Diffusion connection")
+     try:
+         sd_remote = webuiapi.WebUIApi(
+             host=sd_remote_host, port=sd_remote_port, use_https=sd_remote_ssl
+         )
+         if sd_remote_auth:
+             username, password = sd_remote_auth.split(":")
+             sd_remote.set_auth(username, password)
+         sd_remote.util_wait_for_ready()
+     except Exception as e:
+         # remove sd from modules
+         print(
+             f"{Fore.RED}{Style.BRIGHT}Could not connect to remote SD backend at http{'s' if sd_remote_ssl else ''}://{sd_remote_host}:{sd_remote_port}! Disabling SD module...{Style.RESET_ALL}"
+         )
+         modules.remove("sd")
+
+ if "tts" in modules:
+     print("tts module is deprecated. Please use silero-tts instead.")
+     modules.remove("tts")
+     modules.append("silero-tts")
+
+
+ if "silero-tts" in modules:
+     if not os.path.exists(SILERO_SAMPLES_PATH):
+         os.makedirs(SILERO_SAMPLES_PATH)
+     print("Initializing Silero TTS server")
+     from silero_api_server import tts
+
+     tts_service = tts.SileroTtsService(SILERO_SAMPLES_PATH)
+     if len(os.listdir(SILERO_SAMPLES_PATH)) == 0:
+         print("Generating Silero TTS samples...")
+         tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
+         tts_service.generate_samples()
+
+
+ if "edge-tts" in modules:
+     print("Initializing Edge TTS client")
+     import tts_edge as edge
+
+
+ if "chromadb" in modules:
+     print("Initializing ChromaDB")
+     import chromadb
+     import posthog
+     from chromadb.config import Settings
+     from sentence_transformers import SentenceTransformer
+
+     # Assume that the user wants in-memory unless a host is specified
+     # Also disable chromadb telemetry
+     posthog.capture = lambda *args, **kwargs: None
+     if args.chroma_host is None:
+         if args.chroma_persist:
+             chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False, persist_directory=args.chroma_folder, chroma_db_impl='duckdb+parquet'))
+             print(f"ChromaDB is running in-memory with persistence. Persistence is stored in {args.chroma_folder}. Can be cleared by deleting the folder or purging db.")
+         else:
+             chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
+             print("ChromaDB is running in-memory without persistence.")
+     else:
+         chroma_port = (
+             args.chroma_port if args.chroma_port else DEFAULT_CHROMA_PORT
+         )
+         chromadb_client = chromadb.Client(
+             Settings(
+                 anonymized_telemetry=False,
+                 chroma_api_impl="rest",
+                 chroma_server_host=args.chroma_host,
+                 chroma_server_http_port=chroma_port,
+             )
+         )
+         print(f"ChromaDB is remotely configured at {args.chroma_host}:{chroma_port}")
+
+     chromadb_embedder = SentenceTransformer(embedding_model)
+     chromadb_embed_fn = lambda *args, **kwargs: chromadb_embedder.encode(*args, **kwargs).tolist()
+
+     # Check if the db is connected and running, otherwise tell the user
+     try:
+         chromadb_client.heartbeat()
+         print("Successfully pinged ChromaDB! Your client is successfully connected.")
+     except:
+         print("Could not ping ChromaDB! If you are running remotely, please check your host and port!")
+
+ # Flask init
+ app = Flask(__name__)
+ CORS(app)  # allow cross-domain requests
+ Compress(app)  # compress responses
+ app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
+
+
+ def require_module(name):
+     def wrapper(fn):
+         @wraps(fn)
+         def decorated_view(*args, **kwargs):
+             if name not in modules:
+                 abort(403, "Module is disabled by config")
+             return fn(*args, **kwargs)
+
+         return decorated_view
+
+     return wrapper
+
+
+ # AI stuff
+ def classify_text(text: str) -> list:
+     output = classification_pipe(
+         text,
+         truncation=True,
+         max_length=classification_pipe.model.config.max_position_embeddings,
+     )[0]
+     return sorted(output, key=lambda x: x["score"], reverse=True)
+
+
+ def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
+     inputs = captioning_processor(raw_image.convert("RGB"), return_tensors="pt").to(
+         device, torch_dtype
+     )
+     outputs = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
+     caption = captioning_processor.decode(outputs[0], skip_special_tokens=True)
+     return caption
+
+
+ def summarize_chunks(text: str, params: dict) -> str:
+     try:
+         return summarize(text, params)
+     except IndexError:
+         print(
+             "Sequence length too large for model, cutting text in half and calling again"
+         )
+         new_params = params.copy()
+         new_params["max_length"] = new_params["max_length"] // 2
+         new_params["min_length"] = new_params["min_length"] // 2
+         return summarize_chunks(
+             text[: (len(text) // 2)], new_params
+         ) + summarize_chunks(text[(len(text) // 2) :], new_params)
+
+
+ def summarize(text: str, params: dict) -> str:
+     # Tokenize input
+     inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
+     token_count = len(inputs[0])
+
+     bad_words_ids = [
+         summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
+         for bad_word in params["bad_words"]
+     ]
+     summary_ids = summarization_transformer.generate(
+         inputs["input_ids"],
+         num_beams=2,
+         max_new_tokens=max(token_count, int(params["max_length"])),
+         min_new_tokens=min(token_count, int(params["min_length"])),
+         repetition_penalty=float(params["repetition_penalty"]),
+         temperature=float(params["temperature"]),
+         length_penalty=float(params["length_penalty"]),
+         bad_words_ids=bad_words_ids,
+     )
+     summary = summarization_tokenizer.batch_decode(
+         summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+     )[0]
+     summary = normalize_string(summary)
+     return summary
+
+
+ def normalize_string(input: str) -> str:
+     output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
+     return output
+
+
+ def generate_image(data: dict) -> Image:
+     prompt = normalize_string(f'{data["prompt_prefix"]} {data["prompt"]}')
+
+     if sd_use_remote:
+         image = sd_remote.txt2img(
+             prompt=prompt,
+             negative_prompt=data["negative_prompt"],
+             sampler_name=data["sampler"],
+             steps=data["steps"],
+             cfg_scale=data["scale"],
+             width=data["width"],
+             height=data["height"],
+             restore_faces=data["restore_faces"],
+             enable_hr=data["enable_hr"],
+             save_images=True,
+             send_images=True,
+             do_not_save_grid=False,
+             do_not_save_samples=False,
+         ).image
+     else:
+         image = sd_pipe(
+             prompt=prompt,
+             negative_prompt=data["negative_prompt"],
+             num_inference_steps=data["steps"],
+             guidance_scale=data["scale"],
+             width=data["width"],
+             height=data["height"],
+         ).images[0]
+
+     image.save("./debug.png")
+     return image
+
+
+ def image_to_base64(image: Image, quality: int = 75) -> str:
+     buffer = BytesIO()
+     image = image.convert("RGB")  # convert() returns a new image; assign it so non-RGB inputs don't break JPEG saving
+     image.save(buffer, format="JPEG", quality=quality)
+     img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
+     return img_str
+
+ api_key = os.environ.get("password")
+
+ @app.before_request
+ def before_request():
+     # Request time measuring
+     request.start_time = time.time()
+
+     # Checks if an API key is present and valid, otherwise return unauthorized
+     # The options check is required so CORS doesn't get angry
+     try:
+         if request.method != 'OPTIONS' and request.authorization.token != api_key:
+             print(f"WARNING: Unauthorized API key access from {request.remote_addr}")
+             response = jsonify({'error': '401: Invalid API key'})
+             response.status_code = 401
+             return response
+     except Exception as e:
+         print(f"API key check error: {e}")
+         return "401 Unauthorized\n{}\n\n".format(e), 401
+
+
+ @app.after_request
+ def after_request(response):
+     duration = time.time() - request.start_time
+     response.headers["X-Request-Duration"] = str(duration)
+     return response
+
+
+ @app.route("/", methods=["GET"])
+ def index():
+     with open("./README.md", "r", encoding="utf8") as f:
+         content = f.read()
+     return render_template_string(markdown.markdown(content, extensions=["tables"]))
+
+
+ @app.route("/api/extensions", methods=["GET"])
+ def get_extensions():
+     extensions = dict(
+         {
+             "extensions": [
+                 {
+                     "name": "not-supported",
+                     "metadata": {
+                         "display_name": """<span style="white-space:break-spaces;">Extensions serving using Extensions API is no longer supported. Please update the mod from: <a href="https://github.com/Cohee1207/SillyTavern">https://github.com/Cohee1207/SillyTavern</a></span>""",
+                         "requires": [],
+                         "assets": [],
+                     },
+                 }
+             ]
+         }
+     )
+     return jsonify(extensions)
+
+
+ @app.route("/api/caption", methods=["POST"])
+ @require_module("caption")
+ def api_caption():
+     data = request.get_json()
+
+     if "image" not in data or not isinstance(data["image"], str):
+         abort(400, '"image" is required')
+
+     image = Image.open(BytesIO(base64.b64decode(data["image"])))
+     image = image.convert("RGB")
+     image.thumbnail((512, 512))
+     caption = caption_image(image)
+     thumbnail = image_to_base64(image)
+     print("Caption:", caption, sep="\n")
+     gc.collect()
+     return jsonify({"caption": caption, "thumbnail": thumbnail})
+
+
+ @app.route("/api/summarize", methods=["POST"])
+ @require_module("summarize")
+ def api_summarize():
+     data = request.get_json()
+
+     if "text" not in data or not isinstance(data["text"], str):
+         abort(400, '"text" is required')
+
+     params = DEFAULT_SUMMARIZE_PARAMS.copy()
+
+     if "params" in data and isinstance(data["params"], dict):
+         params.update(data["params"])
+
+     print("Summary input:", data["text"], sep="\n")
+     summary = summarize_chunks(data["text"], params)
+     print("Summary output:", summary, sep="\n")
+     gc.collect()
+     return jsonify({"summary": summary})
+
+
+ @app.route("/api/classify", methods=["POST"])
+ @require_module("classify")
+ def api_classify():
+     data = request.get_json()
+
+     if "text" not in data or not isinstance(data["text"], str):
+         abort(400, '"text" is required')
+
+     print("Classification input:", data["text"], sep="\n")
+     classification = classify_text(data["text"])
+     print("Classification output:", classification, sep="\n")
+     gc.collect()
+     return jsonify({"classification": classification})
+
+
+ @app.route("/api/classify/labels", methods=["GET"])
+ @require_module("classify")
+ def api_classify_labels():
+     classification = classify_text("")
+     labels = [x["label"] for x in classification]
+     return jsonify({"labels": labels})
+
+
+ @app.route("/api/image", methods=["POST"])
+ @require_module("sd")
+ def api_image():
+     required_fields = {
+         "prompt": str,
+     }
+
+     optional_fields = {
+         "steps": 30,
+         "scale": 6,
+         "sampler": "DDIM",
+         "width": 512,
+         "height": 512,
+         "restore_faces": False,
+         "enable_hr": False,
+         "prompt_prefix": PROMPT_PREFIX,
+         "negative_prompt": NEGATIVE_PROMPT,
+     }
+
+     data = request.get_json()
+
+     # Check required fields
+     for field, field_type in required_fields.items():
+         if field not in data or not isinstance(data[field], field_type):
+             abort(400, f'"{field}" is required')
+
+     # Set optional fields to default values if not provided
+     for field, default_value in optional_fields.items():
+         type_match = (
+             (int, float)
+             if isinstance(default_value, (int, float))
+             else type(default_value)
+         )
+         if field not in data or not isinstance(data[field], type_match):
+             data[field] = default_value
+
+     try:
+         print("SD inputs:", data, sep="\n")
+         image = generate_image(data)
+         base64image = image_to_base64(image, quality=90)
+         return jsonify({"image": base64image})
+     except RuntimeError as e:
+         abort(400, str(e))
+
+
+ @app.route("/api/image/model", methods=["POST"])
+ @require_module("sd")
+ def api_image_model_set():
+     data = request.get_json()
+
+     if not sd_use_remote:
+         abort(400, "Changing model for local sd is not supported.")
+     if "model" not in data or not isinstance(data["model"], str):
+         abort(400, '"model" is required')
+
+     old_model = sd_remote.util_get_current_model()
+     sd_remote.util_set_model(data["model"], find_closest=False)
+     # sd_remote.util_set_model(data['model'])
+     sd_remote.util_wait_for_ready()
+     new_model = sd_remote.util_get_current_model()
+
+     return jsonify({"previous_model": old_model, "current_model": new_model})
+
+
+ @app.route("/api/image/model", methods=["GET"])
+ @require_module("sd")
+ def api_image_model_get():
+     model = sd_model
+
+     if sd_use_remote:
+         model = sd_remote.util_get_current_model()
+
+     return jsonify({"model": model})
+
+
+ @app.route("/api/image/models", methods=["GET"])
+ @require_module("sd")
+ def api_image_models():
+     models = [sd_model]
+
+     if sd_use_remote:
+         models = sd_remote.util_get_model_names()
+
+     return jsonify({"models": models})
+
+
+ @app.route("/api/image/samplers", methods=["GET"])
+ @require_module("sd")
+ def api_image_samplers():
+     samplers = ["Euler a"]
+
+     if sd_use_remote:
+         samplers = [sampler["name"] for sampler in sd_remote.get_samplers()]
+
+     return jsonify({"samplers": samplers})
+
+
+ @app.route("/api/modules", methods=["GET"])
+ def get_modules():
+     return jsonify({"modules": modules})
+
+
+ @app.route("/api/tts/speakers", methods=["GET"])
+ @require_module("silero-tts")
+ def tts_speakers():
+     voices = [
+         {
+             "name": speaker,
+             "voice_id": speaker,
+             "preview_url": f"{str(request.url_root)}api/tts/sample/{speaker}",
+         }
+         for speaker in tts_service.get_speakers()
+     ]
+     return jsonify(voices)
+
+
+ @app.route("/api/tts/generate", methods=["POST"])
+ @require_module("silero-tts")
+ def tts_generate():
+     voice = request.get_json()
+     if "text" not in voice or not isinstance(voice["text"], str):
+         abort(400, '"text" is required')
+     if "speaker" not in voice or not isinstance(voice["speaker"], str):
+         abort(400, '"speaker" is required')
+     # Remove asterisks
+     voice["text"] = voice["text"].replace("*", "")
+     try:
+         audio = tts_service.generate(voice["speaker"], voice["text"])
+         return send_file(audio, mimetype="audio/x-wav")
+     except Exception as e:
+         print(e)
+         abort(500, voice["speaker"])
+
+
+ @app.route("/api/tts/sample/<speaker>", methods=["GET"])
+ @require_module("silero-tts")
+ def tts_play_sample(speaker: str):
+     return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
+
+
+ @app.route("/api/edge-tts/list", methods=["GET"])
+ @require_module("edge-tts")
+ def edge_tts_list():
+     voices = edge.get_voices()
+     return jsonify(voices)
+
+
+ @app.route("/api/edge-tts/generate", methods=["POST"])
+ @require_module("edge-tts")
+ def edge_tts_generate():
+     data = request.get_json()
+     if "text" not in data or not isinstance(data["text"], str):
+         abort(400, '"text" is required')
+     if "voice" not in data or not isinstance(data["voice"], str):
+         abort(400, '"voice" is required')
+     if "rate" in data and isinstance(data['rate'], int):
+         rate = data['rate']
+     else:
+         rate = 0
+     # Remove asterisks
+     data["text"] = data["text"].replace("*", "")
+     try:
+         audio = edge.generate_audio(text=data["text"], voice=data["voice"], rate=rate)
+         return Response(audio, mimetype="audio/mpeg")
+     except Exception as e:
+         print(e)
+         abort(500, data["voice"])
+
+
+ @app.route("/api/chromadb", methods=["POST"])
+ @require_module("chromadb")
+ def chromadb_add_messages():
+     data = request.get_json()
+     if "chat_id" not in data or not isinstance(data["chat_id"], str):
+         abort(400, '"chat_id" is required')
+     if "messages" not in data or not isinstance(data["messages"], list):
+         abort(400, '"messages" is required')
+
+     chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
+     collection = chromadb_client.get_or_create_collection(
+         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
+     )
+
+     documents = [m["content"] for m in data["messages"]]
+     ids = [m["id"] for m in data["messages"]]
+     metadatas = [
+         {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
+         for m in data["messages"]
+     ]
+
+     collection.upsert(
+         ids=ids,
+         documents=documents,
+         metadatas=metadatas,
+     )
+
+     return jsonify({"count": len(ids)})
+
+
+ @app.route("/api/chromadb/purge", methods=["POST"])
+ @require_module("chromadb")
+ def chromadb_purge():
+     data = request.get_json()
+     if "chat_id" not in data or not isinstance(data["chat_id"], str):
+         abort(400, '"chat_id" is required')
+
+     chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
+     collection = chromadb_client.get_or_create_collection(
+         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
+     )
+
+     count = collection.count()
+     collection.delete()
+     # Write deletion to persistent folder
+     chromadb_client.persist()
+     print("ChromaDB embeddings deleted", count)
+     return 'Ok', 200
+
+
+ @app.route("/api/chromadb/query", methods=["POST"])
+ @require_module("chromadb")
+ def chromadb_query():
+     data = request.get_json()
+     if "chat_id" not in data or not isinstance(data["chat_id"], str):
+         abort(400, '"chat_id" is required')
+     if "query" not in data or not isinstance(data["query"], str):
+         abort(400, '"query" is required')
+
+     if "n_results" not in data or not isinstance(data["n_results"], int):
+         n_results = 1
+     else:
+         n_results = data["n_results"]
+
+     chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
+     collection = chromadb_client.get_or_create_collection(
+         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
+     )
+
+     n_results = min(collection.count(), n_results)
+     query_result = collection.query(
+         query_texts=[data["query"]],
+         n_results=n_results,
+     )
+
+     documents = query_result["documents"][0]
+     ids = query_result["ids"][0]
+     metadatas = query_result["metadatas"][0]
+     distances = query_result["distances"][0]
+
+     messages = [
+         {
+             "id": ids[i],
+             "date": metadatas[i]["date"],
+             "role": metadatas[i]["role"],
+             "meta": metadatas[i]["meta"],
+             "content": documents[i],
+             "distance": distances[i],
+         }
+         for i in range(len(ids))
+     ]
+
+     return jsonify(messages)
+
+
+ @app.route("/api/chromadb/export", methods=["POST"])
+ @require_module("chromadb")
+ def chromadb_export():
+     data = request.get_json()
+     if "chat_id" not in data or not isinstance(data["chat_id"], str):
+         abort(400, '"chat_id" is required')
+
+     chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
+     collection = chromadb_client.get_or_create_collection(
+         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
+     )
+     collection_content = collection.get()
+     documents = collection_content.get('documents', [])
+     ids = collection_content.get('ids', [])
+     metadatas = collection_content.get('metadatas', [])
+
+     unsorted_content = [
+         {
+             "id": ids[i],
+             "metadata": metadatas[i],
+             "document": documents[i],
+         }
+         for i in range(len(ids))
+     ]
+
+     sorted_content = sorted(unsorted_content, key=lambda x: x['metadata']['date'])
+
+     export = {
+         "chat_id": data["chat_id"],
+         "content": sorted_content
+     }
+
+     return jsonify(export)
+
+ @app.route("/api/chromadb/import", methods=["POST"])
+ @require_module("chromadb")
+ def chromadb_import():
+     data = request.get_json()
+ content = data['content']
826
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
827
+ abort(400, '"chat_id" is required')
828
+
829
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
830
+ collection = chromadb_client.get_or_create_collection(
831
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
832
+ )
833
+
834
+ documents = [item['document'] for item in content]
835
+ metadatas = [item['metadata'] for item in content]
836
+ ids = [item['id'] for item in content]
837
+
838
+
839
+ collection.upsert(documents=documents, metadatas=metadatas, ids=ids)
840
+
841
+ return jsonify({"count": len(ids)})
842
+
843
+
844
+ app.run(host=host, port=port)
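For context, here is a minimal client sketch against the API above. It assumes the server is reachable at http://localhost:7860 and that the password environment variable (the key checked in before_request) is set to SECRET; the host, the key value, and the use of the requests library are assumptions of this sketch, not part of the commit.

    import requests

    BASE = "http://localhost:7860"  # placeholder host
    HEADERS = {"Authorization": "Bearer SECRET"}  # token compared against api_key

    # Summarize a chat transcript; "params" optionally overrides
    # DEFAULT_SUMMARIZE_PARAMS from constants.py.
    resp = requests.post(
        f"{BASE}/api/summarize",
        json={"text": "Long chat transcript goes here...", "params": {"max_length": 250}},
        headers=HEADERS,
    )
    print(resp.json()["summary"])

    # Classify the emotional tone of a message; the result is a list of
    # {"label", "score"} dicts sorted by score.
    resp = requests.post(
        f"{BASE}/api/classify",
        json={"text": "I am so happy to see you!"},
        headers=HEADERS,
    )
    print(resp.json()["classification"][0]["label"])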
tts_edge.py ADDED
@@ -0,0 +1,34 @@
+ import io
+ import edge_tts
+ import asyncio
+
+
+ def get_voices():
+     voices = asyncio.run(edge_tts.list_voices())
+     return voices
+
+
+ async def _iterate_chunks(audio):
+     async for chunk in audio.stream():
+         if chunk["type"] == "audio":
+             yield chunk["data"]
+
+
+ async def _async_generator_to_list(async_gen):
+     result = []
+     async for item in async_gen:
+         result.append(item)
+     return result
+
+
+ def generate_audio(text: str, voice: str, rate: int) -> bytes:
+ sign = '+' if rate > 0 else '-'
26
+ rate = f'{sign}{abs(rate)}%'
27
+ audio = edge_tts.Communicate(text=text, voice=voice, rate=rate)
28
+ chunks = asyncio.run(_async_generator_to_list(_iterate_chunks(audio)))
29
+ buffer = io.BytesIO()
30
+
31
+ for chunk in chunks:
32
+ buffer.write(chunk)
33
+
34
+ return buffer.getvalue()
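A minimal usage sketch for this module on its own; the voice name and output filename are placeholders (ShortName is the field returned by edge-tts voice listings):

    import tts_edge

    # List the available Microsoft Edge TTS voices.
    voices = tts_edge.get_voices()
    print(voices[0]["ShortName"])  # e.g. "en-US-AriaNeural"

    # Generate MP3 audio at a +10% speaking rate and write it to disk.
    audio = tts_edge.generate_audio(text="Hello from Edge TTS.", voice="en-US-AriaNeural", rate=10)
    with open("hello.mp3", "wb") as f:
        f.write(audio)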