Salt
committed on
Commit
·
61200c0
1
Parent(s):
5eff76c
Update server.py
Browse files
server.py
CHANGED
@@ -97,8 +97,9 @@ parser.add_argument(
|
|
97 |
|
98 |
args = parser.parse_args()
|
99 |
|
100 |
-
port =
|
101 |
-
host = "0.0.0.0"
|
|
|
102 |
summarization_model = (
|
103 |
args.summarization_model
|
104 |
if args.summarization_model
|
@@ -109,23 +110,6 @@ classification_model = (
|
|
109 |
if args.classification_model
|
110 |
else DEFAULT_CLASSIFICATION_MODEL
|
111 |
)
|
112 |
-
captioning_model = (
|
113 |
-
args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
|
114 |
-
)
|
115 |
-
keyphrase_model = (
|
116 |
-
args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
|
117 |
-
)
|
118 |
-
prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
|
119 |
-
embedding_model = (
|
120 |
-
args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
|
121 |
-
)
|
122 |
-
|
123 |
-
sd_use_remote = False if args.sd_model else True
|
124 |
-
sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
|
125 |
-
sd_remote_host = args.sd_remote_host if args.sd_remote_host else DEFAULT_REMOTE_SD_HOST
|
126 |
-
sd_remote_port = args.sd_remote_port if args.sd_remote_port else DEFAULT_REMOTE_SD_PORT
|
127 |
-
sd_remote_ssl = args.sd_remote_ssl
|
128 |
-
sd_remote_auth = args.sd_remote_auth
|
129 |
|
130 |
modules = (
|
131 |
args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else []
|
@@ -142,18 +126,6 @@ device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu
|
|
142 |
device = torch.device(device_string)
|
143 |
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
|
144 |
|
145 |
-
if "caption" in modules:
|
146 |
-
print("Initializing an image captioning model...")
|
147 |
-
captioning_processor = AutoProcessor.from_pretrained(captioning_model)
|
148 |
-
if "blip" in captioning_model:
|
149 |
-
captioning_transformer = BlipForConditionalGeneration.from_pretrained(
|
150 |
-
captioning_model, torch_dtype=torch_dtype
|
151 |
-
).to(device)
|
152 |
-
else:
|
153 |
-
captioning_transformer = AutoModelForCausalLM.from_pretrained(
|
154 |
-
captioning_model, torch_dtype=torch_dtype
|
155 |
-
).to(device)
|
156 |
-
|
157 |
if "summarize" in modules:
|
158 |
print("Initializing a text summarization model...")
|
159 |
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
|
@@ -171,69 +143,6 @@ if "classify" in modules:
|
|
171 |
torch_dtype=torch_dtype,
|
172 |
)
|
173 |
|
174 |
-
if "keywords" in modules:
|
175 |
-
print("Initializing a keyword extraction pipeline...")
|
176 |
-
import pipelines as pipelines
|
177 |
-
|
178 |
-
keyphrase_pipe = pipelines.KeyphraseExtractionPipeline(keyphrase_model)
|
179 |
-
|
180 |
-
if "prompt" in modules:
|
181 |
-
print("Initializing a prompt generator")
|
182 |
-
gpt_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
|
183 |
-
gpt_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
184 |
-
gpt_model = AutoModelForCausalLM.from_pretrained(prompt_model)
|
185 |
-
prompt_generator = pipeline(
|
186 |
-
"text-generation", model=gpt_model, tokenizer=gpt_tokenizer
|
187 |
-
)
|
188 |
-
|
189 |
-
if "sd" in modules and not sd_use_remote:
|
190 |
-
from diffusers import StableDiffusionPipeline
|
191 |
-
from diffusers import EulerAncestralDiscreteScheduler
|
192 |
-
|
193 |
-
print("Initializing Stable Diffusion pipeline")
|
194 |
-
sd_device_string = (
|
195 |
-
"cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
|
196 |
-
)
|
197 |
-
sd_device = torch.device(sd_device_string)
|
198 |
-
sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
|
199 |
-
sd_pipe = StableDiffusionPipeline.from_pretrained(
|
200 |
-
sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
|
201 |
-
).to(sd_device)
|
202 |
-
sd_pipe.safety_checker = lambda images, clip_input: (images, False)
|
203 |
-
sd_pipe.enable_attention_slicing()
|
204 |
-
# pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
|
205 |
-
sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
206 |
-
sd_pipe.scheduler.config
|
207 |
-
)
|
208 |
-
elif "sd" in modules and sd_use_remote:
|
209 |
-
print("Initializing Stable Diffusion connection")
|
210 |
-
try:
|
211 |
-
sd_remote = webuiapi.WebUIApi(
|
212 |
-
host=sd_remote_host, port=sd_remote_port, use_https=sd_remote_ssl
|
213 |
-
)
|
214 |
-
if sd_remote_auth:
|
215 |
-
username, password = sd_remote_auth.split(":")
|
216 |
-
sd_remote.set_auth(username, password)
|
217 |
-
sd_remote.util_wait_for_ready()
|
218 |
-
except Exception as e:
|
219 |
-
# remote sd from modules
|
220 |
-
print(
|
221 |
-
f"{Fore.RED}{Style.BRIGHT}Could not connect to remote SD backend at http{'s' if sd_remote_ssl else ''}://{sd_remote_host}:{sd_remote_port}! Disabling SD module...{Style.RESET_ALL}"
|
222 |
-
)
|
223 |
-
modules.remove("sd")
|
224 |
-
|
225 |
-
if "tts" in modules:
|
226 |
-
if not os.path.exists(SILERO_SAMPLES_PATH):
|
227 |
-
os.makedirs(SILERO_SAMPLES_PATH)
|
228 |
-
print("Initializing Silero TTS server")
|
229 |
-
from silero_api_server import tts
|
230 |
-
|
231 |
-
tts_service = tts.SileroTtsService(SILERO_SAMPLES_PATH)
|
232 |
-
if len(os.listdir(SILERO_SAMPLES_PATH)) == 0:
|
233 |
-
print("Generating Silero TTS samples...")
|
234 |
-
tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
|
235 |
-
tts_service.generate_samples()
|
236 |
-
|
237 |
if "chromadb" in modules:
|
238 |
print("Initializing ChromaDB")
|
239 |
import chromadb
|
@@ -277,15 +186,6 @@ def classify_text(text: str) -> list:
|
|
277 |
return sorted(output, key=lambda x: x["score"], reverse=True)
|
278 |
|
279 |
|
280 |
-
def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
|
281 |
-
inputs = captioning_processor(raw_image.convert("RGB"), return_tensors="pt").to(
|
282 |
-
device, torch_dtype
|
283 |
-
)
|
284 |
-
outputs = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
|
285 |
-
caption = captioning_processor.decode(outputs[0], skip_special_tokens=True)
|
286 |
-
return caption
|
287 |
-
|
288 |
-
|
289 |
def summarize_chunks(text: str, params: dict) -> str:
|
290 |
try:
|
291 |
return summarize(text, params)
|
@@ -331,70 +231,6 @@ def normalize_string(input: str) -> str:
|
|
331 |
output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
|
332 |
return output
|
333 |
|
334 |
-
|
335 |
-
def extract_keywords(text: str) -> list:
|
336 |
-
punctuation = "(){}[]\n\r<>"
|
337 |
-
trans = str.maketrans(punctuation, " " * len(punctuation))
|
338 |
-
text = text.translate(trans)
|
339 |
-
text = normalize_string(text)
|
340 |
-
return list(keyphrase_pipe(text))
|
341 |
-
|
342 |
-
|
343 |
-
def generate_prompt(keywords: list, length: int = 100, num: int = 4) -> str:
|
344 |
-
prompt = ", ".join(keywords)
|
345 |
-
outs = prompt_generator(
|
346 |
-
prompt,
|
347 |
-
max_length=length,
|
348 |
-
num_return_sequences=num,
|
349 |
-
do_sample=True,
|
350 |
-
repetition_penalty=1.2,
|
351 |
-
temperature=0.7,
|
352 |
-
top_k=4,
|
353 |
-
early_stopping=True,
|
354 |
-
)
|
355 |
-
return [out["generated_text"] for out in outs]
|
356 |
-
|
357 |
-
|
358 |
-
def generate_image(data: dict) -> Image:
|
359 |
-
prompt = normalize_string(f'{data["prompt_prefix"]} {data["prompt"]}')
|
360 |
-
|
361 |
-
if sd_use_remote:
|
362 |
-
image = sd_remote.txt2img(
|
363 |
-
prompt=prompt,
|
364 |
-
negative_prompt=data["negative_prompt"],
|
365 |
-
sampler_name=data["sampler"],
|
366 |
-
steps=data["steps"],
|
367 |
-
cfg_scale=data["scale"],
|
368 |
-
width=data["width"],
|
369 |
-
height=data["height"],
|
370 |
-
restore_faces=data["restore_faces"],
|
371 |
-
enable_hr=data["enable_hr"],
|
372 |
-
save_images=True,
|
373 |
-
send_images=True,
|
374 |
-
do_not_save_grid=False,
|
375 |
-
do_not_save_samples=False,
|
376 |
-
).image
|
377 |
-
else:
|
378 |
-
image = sd_pipe(
|
379 |
-
prompt=prompt,
|
380 |
-
negative_prompt=data["negative_prompt"],
|
381 |
-
num_inference_steps=data["steps"],
|
382 |
-
guidance_scale=data["scale"],
|
383 |
-
width=data["width"],
|
384 |
-
height=data["height"],
|
385 |
-
).images[0]
|
386 |
-
|
387 |
-
image.save("./debug.png")
|
388 |
-
return image
|
389 |
-
|
390 |
-
|
391 |
-
def image_to_base64(image: Image, quality: int = 75) -> str:
|
392 |
-
buffered = BytesIO()
|
393 |
-
image.save(buffered, format="JPEG", quality=quality)
|
394 |
-
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
395 |
-
return img_str
|
396 |
-
|
397 |
-
|
398 |
@app.before_request
|
399 |
# Request time measuring
|
400 |
def before_request():
|
@@ -527,144 +363,10 @@ def api_prompt():
|
|
527 |
print("Prompt output:", prompts, sep="\n")
|
528 |
return jsonify({"prompts": prompts})
|
529 |
|
530 |
-
|
531 |
-
@app.route("/api/image", methods=["POST"])
|
532 |
-
@require_module("sd")
|
533 |
-
def api_image():
|
534 |
-
required_fields = {
|
535 |
-
"prompt": str,
|
536 |
-
}
|
537 |
-
|
538 |
-
optional_fields = {
|
539 |
-
"steps": 30,
|
540 |
-
"scale": 6,
|
541 |
-
"sampler": "DDIM",
|
542 |
-
"width": 512,
|
543 |
-
"height": 512,
|
544 |
-
"restore_faces": False,
|
545 |
-
"enable_hr": False,
|
546 |
-
"prompt_prefix": PROMPT_PREFIX,
|
547 |
-
"negative_prompt": NEGATIVE_PROMPT,
|
548 |
-
}
|
549 |
-
|
550 |
-
data = request.get_json()
|
551 |
-
|
552 |
-
# Check required fields
|
553 |
-
for field, field_type in required_fields.items():
|
554 |
-
if field not in data or not isinstance(data[field], field_type):
|
555 |
-
abort(400, f'"{field}" is required')
|
556 |
-
|
557 |
-
# Set optional fields to default values if not provided
|
558 |
-
for field, default_value in optional_fields.items():
|
559 |
-
type_match = (
|
560 |
-
(int, float)
|
561 |
-
if isinstance(default_value, (int, float))
|
562 |
-
else type(default_value)
|
563 |
-
)
|
564 |
-
if field not in data or not isinstance(data[field], type_match):
|
565 |
-
data[field] = default_value
|
566 |
-
|
567 |
-
try:
|
568 |
-
print("SD inputs:", data, sep="\n")
|
569 |
-
image = generate_image(data)
|
570 |
-
base64image = image_to_base64(image, quality=90)
|
571 |
-
return jsonify({"image": base64image})
|
572 |
-
except RuntimeError as e:
|
573 |
-
abort(400, str(e))
|
574 |
-
|
575 |
-
|
576 |
-
@app.route("/api/image/model", methods=["POST"])
|
577 |
-
@require_module("sd")
|
578 |
-
def api_image_model_set():
|
579 |
-
data = request.get_json()
|
580 |
-
|
581 |
-
if not sd_use_remote:
|
582 |
-
abort(400, "Changing model for local sd is not supported.")
|
583 |
-
if "model" not in data or not isinstance(data["model"], str):
|
584 |
-
abort(400, '"model" is required')
|
585 |
-
|
586 |
-
old_model = sd_remote.util_get_current_model()
|
587 |
-
sd_remote.util_set_model(data["model"], find_closest=False)
|
588 |
-
# sd_remote.util_set_model(data['model'])
|
589 |
-
sd_remote.util_wait_for_ready()
|
590 |
-
new_model = sd_remote.util_get_current_model()
|
591 |
-
|
592 |
-
return jsonify({"previous_model": old_model, "current_model": new_model})
|
593 |
-
|
594 |
-
|
595 |
-
@app.route("/api/image/model", methods=["GET"])
|
596 |
-
@require_module("sd")
|
597 |
-
def api_image_model_get():
|
598 |
-
model = sd_model
|
599 |
-
|
600 |
-
if sd_use_remote:
|
601 |
-
model = sd_remote.util_get_current_model()
|
602 |
-
|
603 |
-
return jsonify({"model": model})
|
604 |
-
|
605 |
-
|
606 |
-
@app.route("/api/image/models", methods=["GET"])
|
607 |
-
@require_module("sd")
|
608 |
-
def api_image_models():
|
609 |
-
models = [sd_model]
|
610 |
-
|
611 |
-
if sd_use_remote:
|
612 |
-
models = sd_remote.util_get_model_names()
|
613 |
-
|
614 |
-
return jsonify({"models": models})
|
615 |
-
|
616 |
-
|
617 |
-
@app.route("/api/image/samplers", methods=["GET"])
|
618 |
-
@require_module("sd")
|
619 |
-
def api_image_samplers():
|
620 |
-
samplers = ["Euler a"]
|
621 |
-
|
622 |
-
if sd_use_remote:
|
623 |
-
samplers = [sampler["name"] for sampler in sd_remote.get_samplers()]
|
624 |
-
|
625 |
-
return jsonify({"samplers": samplers})
|
626 |
-
|
627 |
-
|
628 |
@app.route("/api/modules", methods=["GET"])
|
629 |
def get_modules():
|
630 |
return jsonify({"modules": modules})
|
631 |
|
632 |
-
|
633 |
-
@app.route("/api/tts/speakers", methods=["GET"])
|
634 |
-
def tts_speakers():
|
635 |
-
voices = [
|
636 |
-
{
|
637 |
-
"name": speaker,
|
638 |
-
"voice_id": speaker,
|
639 |
-
"preview_url": f"{str(request.url_root)}api/tts/sample/{speaker}",
|
640 |
-
}
|
641 |
-
for speaker in tts_service.get_speakers()
|
642 |
-
]
|
643 |
-
return jsonify(voices)
|
644 |
-
|
645 |
-
|
646 |
-
@app.route("/api/tts/generate", methods=["POST"])
|
647 |
-
def tts_generate():
|
648 |
-
voice = request.get_json()
|
649 |
-
if "text" not in voice or not isinstance(voice["text"], str):
|
650 |
-
abort(400, '"text" is required')
|
651 |
-
if "speaker" not in voice or not isinstance(voice["speaker"], str):
|
652 |
-
abort(400, '"speaker" is required')
|
653 |
-
# Remove asterisks
|
654 |
-
voice["text"] = voice["text"].replace("*", "")
|
655 |
-
try:
|
656 |
-
audio = tts_service.generate(voice["speaker"], voice["text"])
|
657 |
-
return send_file(audio, mimetype="audio/x-wav")
|
658 |
-
except Exception as e:
|
659 |
-
print(e)
|
660 |
-
abort(500, voice["speaker"])
|
661 |
-
|
662 |
-
|
663 |
-
@app.route("/api/tts/sample/<speaker>", methods=["GET"])
|
664 |
-
def tts_play_sample(speaker: str):
|
665 |
-
return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
|
666 |
-
|
667 |
-
|
668 |
@app.route("/api/chromadb", methods=["POST"])
|
669 |
@require_module("chromadb")
|
670 |
def chromadb_add_messages():
|
@@ -756,22 +458,4 @@ def chromadb_query():
|
|
756 |
|
757 |
return jsonify(messages)
|
758 |
|
759 |
-
|
760 |
-
if args.share:
|
761 |
-
from flask_cloudflared import _run_cloudflared
|
762 |
-
import inspect
|
763 |
-
|
764 |
-
sig = inspect.signature(_run_cloudflared)
|
765 |
-
sum = sum(
|
766 |
-
1
|
767 |
-
for param in sig.parameters.values()
|
768 |
-
if param.kind == param.POSITIONAL_OR_KEYWORD
|
769 |
-
)
|
770 |
-
if sum > 1:
|
771 |
-
metrics_port = randint(8100, 9000)
|
772 |
-
cloudflare = _run_cloudflared(port, metrics_port)
|
773 |
-
else:
|
774 |
-
cloudflare = _run_cloudflared(port)
|
775 |
-
print("Running on", cloudflare)
|
776 |
-
|
777 |
app.run(host=host, port=port)
|
|
|
97 |
|
98 |
args = parser.parse_args()
|
99 |
|
100 |
+
port = 7860
|
101 |
+
host = "0.0.0.0"
|
102 |
+
|
103 |
summarization_model = (
|
104 |
args.summarization_model
|
105 |
if args.summarization_model
|
|
|
110 |
if args.classification_model
|
111 |
else DEFAULT_CLASSIFICATION_MODEL
|
112 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
modules = (
|
115 |
args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else []
|
|
|
126 |
device = torch.device(device_string)
|
127 |
torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
if "summarize" in modules:
|
130 |
print("Initializing a text summarization model...")
|
131 |
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
|
|
|
143 |
torch_dtype=torch_dtype,
|
144 |
)
|
145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
if "chromadb" in modules:
|
147 |
print("Initializing ChromaDB")
|
148 |
import chromadb
|
|
|
186 |
return sorted(output, key=lambda x: x["score"], reverse=True)
|
187 |
|
188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
def summarize_chunks(text: str, params: dict) -> str:
|
190 |
try:
|
191 |
return summarize(text, params)
|
|
|
231 |
output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
|
232 |
return output
|
233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
@app.before_request
|
235 |
# Request time measuring
|
236 |
def before_request():
|
|
|
363 |
print("Prompt output:", prompts, sep="\n")
|
364 |
return jsonify({"prompts": prompts})
|
365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
@app.route("/api/modules", methods=["GET"])
|
367 |
def get_modules():
|
368 |
return jsonify({"modules": modules})
|
369 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
@app.route("/api/chromadb", methods=["POST"])
|
371 |
@require_module("chromadb")
|
372 |
def chromadb_add_messages():
|
|
|
458 |
|
459 |
return jsonify(messages)
|
460 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
461 |
app.run(host=host, port=port)
|