Salt committed on
Commit
61200c0
·
1 Parent(s): 5eff76c

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +3 -319
server.py CHANGED
@@ -97,8 +97,9 @@ parser.add_argument(
97
 
98
  args = parser.parse_args()
99
 
100
- port = args.port if args.port else 5100
101
- host = "0.0.0.0" if args.listen else "localhost"
 
102
  summarization_model = (
103
  args.summarization_model
104
  if args.summarization_model
@@ -109,23 +110,6 @@ classification_model = (
109
  if args.classification_model
110
  else DEFAULT_CLASSIFICATION_MODEL
111
  )
112
- captioning_model = (
113
- args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
114
- )
115
- keyphrase_model = (
116
- args.keyphrase_model if args.keyphrase_model else DEFAULT_KEYPHRASE_MODEL
117
- )
118
- prompt_model = args.prompt_model if args.prompt_model else DEFAULT_PROMPT_MODEL
119
- embedding_model = (
120
- args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
121
- )
122
-
123
- sd_use_remote = False if args.sd_model else True
124
- sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
125
- sd_remote_host = args.sd_remote_host if args.sd_remote_host else DEFAULT_REMOTE_SD_HOST
126
- sd_remote_port = args.sd_remote_port if args.sd_remote_port else DEFAULT_REMOTE_SD_PORT
127
- sd_remote_ssl = args.sd_remote_ssl
128
- sd_remote_auth = args.sd_remote_auth
129
 
130
  modules = (
131
  args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else []
@@ -142,18 +126,6 @@ device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu
142
  device = torch.device(device_string)
143
  torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
144
 
145
- if "caption" in modules:
146
- print("Initializing an image captioning model...")
147
- captioning_processor = AutoProcessor.from_pretrained(captioning_model)
148
- if "blip" in captioning_model:
149
- captioning_transformer = BlipForConditionalGeneration.from_pretrained(
150
- captioning_model, torch_dtype=torch_dtype
151
- ).to(device)
152
- else:
153
- captioning_transformer = AutoModelForCausalLM.from_pretrained(
154
- captioning_model, torch_dtype=torch_dtype
155
- ).to(device)
156
-
157
  if "summarize" in modules:
158
  print("Initializing a text summarization model...")
159
  summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
@@ -171,69 +143,6 @@ if "classify" in modules:
171
  torch_dtype=torch_dtype,
172
  )
173
 
174
- if "keywords" in modules:
175
- print("Initializing a keyword extraction pipeline...")
176
- import pipelines as pipelines
177
-
178
- keyphrase_pipe = pipelines.KeyphraseExtractionPipeline(keyphrase_model)
179
-
180
- if "prompt" in modules:
181
- print("Initializing a prompt generator")
182
- gpt_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
183
- gpt_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
184
- gpt_model = AutoModelForCausalLM.from_pretrained(prompt_model)
185
- prompt_generator = pipeline(
186
- "text-generation", model=gpt_model, tokenizer=gpt_tokenizer
187
- )
188
-
189
- if "sd" in modules and not sd_use_remote:
190
- from diffusers import StableDiffusionPipeline
191
- from diffusers import EulerAncestralDiscreteScheduler
192
-
193
- print("Initializing Stable Diffusion pipeline")
194
- sd_device_string = (
195
- "cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
196
- )
197
- sd_device = torch.device(sd_device_string)
198
- sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
199
- sd_pipe = StableDiffusionPipeline.from_pretrained(
200
- sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
201
- ).to(sd_device)
202
- sd_pipe.safety_checker = lambda images, clip_input: (images, False)
203
- sd_pipe.enable_attention_slicing()
204
- # pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
205
- sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
206
- sd_pipe.scheduler.config
207
- )
208
- elif "sd" in modules and sd_use_remote:
209
- print("Initializing Stable Diffusion connection")
210
- try:
211
- sd_remote = webuiapi.WebUIApi(
212
- host=sd_remote_host, port=sd_remote_port, use_https=sd_remote_ssl
213
- )
214
- if sd_remote_auth:
215
- username, password = sd_remote_auth.split(":")
216
- sd_remote.set_auth(username, password)
217
- sd_remote.util_wait_for_ready()
218
- except Exception as e:
219
- # remote sd from modules
220
- print(
221
- f"{Fore.RED}{Style.BRIGHT}Could not connect to remote SD backend at http{'s' if sd_remote_ssl else ''}://{sd_remote_host}:{sd_remote_port}! Disabling SD module...{Style.RESET_ALL}"
222
- )
223
- modules.remove("sd")
224
-
225
- if "tts" in modules:
226
- if not os.path.exists(SILERO_SAMPLES_PATH):
227
- os.makedirs(SILERO_SAMPLES_PATH)
228
- print("Initializing Silero TTS server")
229
- from silero_api_server import tts
230
-
231
- tts_service = tts.SileroTtsService(SILERO_SAMPLES_PATH)
232
- if len(os.listdir(SILERO_SAMPLES_PATH)) == 0:
233
- print("Generating Silero TTS samples...")
234
- tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
235
- tts_service.generate_samples()
236
-
237
  if "chromadb" in modules:
238
  print("Initializing ChromaDB")
239
  import chromadb
@@ -277,15 +186,6 @@ def classify_text(text: str) -> list:
277
  return sorted(output, key=lambda x: x["score"], reverse=True)
278
 
279
 
280
- def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
281
- inputs = captioning_processor(raw_image.convert("RGB"), return_tensors="pt").to(
282
- device, torch_dtype
283
- )
284
- outputs = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
285
- caption = captioning_processor.decode(outputs[0], skip_special_tokens=True)
286
- return caption
287
-
288
-
289
  def summarize_chunks(text: str, params: dict) -> str:
290
  try:
291
  return summarize(text, params)
@@ -331,70 +231,6 @@ def normalize_string(input: str) -> str:
331
  output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
332
  return output
333
 
334
-
335
- def extract_keywords(text: str) -> list:
336
- punctuation = "(){}[]\n\r<>"
337
- trans = str.maketrans(punctuation, " " * len(punctuation))
338
- text = text.translate(trans)
339
- text = normalize_string(text)
340
- return list(keyphrase_pipe(text))
341
-
342
-
343
- def generate_prompt(keywords: list, length: int = 100, num: int = 4) -> str:
344
- prompt = ", ".join(keywords)
345
- outs = prompt_generator(
346
- prompt,
347
- max_length=length,
348
- num_return_sequences=num,
349
- do_sample=True,
350
- repetition_penalty=1.2,
351
- temperature=0.7,
352
- top_k=4,
353
- early_stopping=True,
354
- )
355
- return [out["generated_text"] for out in outs]
356
-
357
-
358
- def generate_image(data: dict) -> Image:
359
- prompt = normalize_string(f'{data["prompt_prefix"]} {data["prompt"]}')
360
-
361
- if sd_use_remote:
362
- image = sd_remote.txt2img(
363
- prompt=prompt,
364
- negative_prompt=data["negative_prompt"],
365
- sampler_name=data["sampler"],
366
- steps=data["steps"],
367
- cfg_scale=data["scale"],
368
- width=data["width"],
369
- height=data["height"],
370
- restore_faces=data["restore_faces"],
371
- enable_hr=data["enable_hr"],
372
- save_images=True,
373
- send_images=True,
374
- do_not_save_grid=False,
375
- do_not_save_samples=False,
376
- ).image
377
- else:
378
- image = sd_pipe(
379
- prompt=prompt,
380
- negative_prompt=data["negative_prompt"],
381
- num_inference_steps=data["steps"],
382
- guidance_scale=data["scale"],
383
- width=data["width"],
384
- height=data["height"],
385
- ).images[0]
386
-
387
- image.save("./debug.png")
388
- return image
389
-
390
-
391
- def image_to_base64(image: Image, quality: int = 75) -> str:
392
- buffered = BytesIO()
393
- image.save(buffered, format="JPEG", quality=quality)
394
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
395
- return img_str
396
-
397
-
398
  @app.before_request
399
  # Request time measuring
400
  def before_request():
@@ -527,144 +363,10 @@ def api_prompt():
527
  print("Prompt output:", prompts, sep="\n")
528
  return jsonify({"prompts": prompts})
529
 
530
-
531
- @app.route("/api/image", methods=["POST"])
532
- @require_module("sd")
533
- def api_image():
534
- required_fields = {
535
- "prompt": str,
536
- }
537
-
538
- optional_fields = {
539
- "steps": 30,
540
- "scale": 6,
541
- "sampler": "DDIM",
542
- "width": 512,
543
- "height": 512,
544
- "restore_faces": False,
545
- "enable_hr": False,
546
- "prompt_prefix": PROMPT_PREFIX,
547
- "negative_prompt": NEGATIVE_PROMPT,
548
- }
549
-
550
- data = request.get_json()
551
-
552
- # Check required fields
553
- for field, field_type in required_fields.items():
554
- if field not in data or not isinstance(data[field], field_type):
555
- abort(400, f'"{field}" is required')
556
-
557
- # Set optional fields to default values if not provided
558
- for field, default_value in optional_fields.items():
559
- type_match = (
560
- (int, float)
561
- if isinstance(default_value, (int, float))
562
- else type(default_value)
563
- )
564
- if field not in data or not isinstance(data[field], type_match):
565
- data[field] = default_value
566
-
567
- try:
568
- print("SD inputs:", data, sep="\n")
569
- image = generate_image(data)
570
- base64image = image_to_base64(image, quality=90)
571
- return jsonify({"image": base64image})
572
- except RuntimeError as e:
573
- abort(400, str(e))
574
-
575
-
576
- @app.route("/api/image/model", methods=["POST"])
577
- @require_module("sd")
578
- def api_image_model_set():
579
- data = request.get_json()
580
-
581
- if not sd_use_remote:
582
- abort(400, "Changing model for local sd is not supported.")
583
- if "model" not in data or not isinstance(data["model"], str):
584
- abort(400, '"model" is required')
585
-
586
- old_model = sd_remote.util_get_current_model()
587
- sd_remote.util_set_model(data["model"], find_closest=False)
588
- # sd_remote.util_set_model(data['model'])
589
- sd_remote.util_wait_for_ready()
590
- new_model = sd_remote.util_get_current_model()
591
-
592
- return jsonify({"previous_model": old_model, "current_model": new_model})
593
-
594
-
595
- @app.route("/api/image/model", methods=["GET"])
596
- @require_module("sd")
597
- def api_image_model_get():
598
- model = sd_model
599
-
600
- if sd_use_remote:
601
- model = sd_remote.util_get_current_model()
602
-
603
- return jsonify({"model": model})
604
-
605
-
606
- @app.route("/api/image/models", methods=["GET"])
607
- @require_module("sd")
608
- def api_image_models():
609
- models = [sd_model]
610
-
611
- if sd_use_remote:
612
- models = sd_remote.util_get_model_names()
613
-
614
- return jsonify({"models": models})
615
-
616
-
617
- @app.route("/api/image/samplers", methods=["GET"])
618
- @require_module("sd")
619
- def api_image_samplers():
620
- samplers = ["Euler a"]
621
-
622
- if sd_use_remote:
623
- samplers = [sampler["name"] for sampler in sd_remote.get_samplers()]
624
-
625
- return jsonify({"samplers": samplers})
626
-
627
-
628
  @app.route("/api/modules", methods=["GET"])
629
  def get_modules():
630
  return jsonify({"modules": modules})
631
 
632
-
633
- @app.route("/api/tts/speakers", methods=["GET"])
634
- def tts_speakers():
635
- voices = [
636
- {
637
- "name": speaker,
638
- "voice_id": speaker,
639
- "preview_url": f"{str(request.url_root)}api/tts/sample/{speaker}",
640
- }
641
- for speaker in tts_service.get_speakers()
642
- ]
643
- return jsonify(voices)
644
-
645
-
646
- @app.route("/api/tts/generate", methods=["POST"])
647
- def tts_generate():
648
- voice = request.get_json()
649
- if "text" not in voice or not isinstance(voice["text"], str):
650
- abort(400, '"text" is required')
651
- if "speaker" not in voice or not isinstance(voice["speaker"], str):
652
- abort(400, '"speaker" is required')
653
- # Remove asterisks
654
- voice["text"] = voice["text"].replace("*", "")
655
- try:
656
- audio = tts_service.generate(voice["speaker"], voice["text"])
657
- return send_file(audio, mimetype="audio/x-wav")
658
- except Exception as e:
659
- print(e)
660
- abort(500, voice["speaker"])
661
-
662
-
663
- @app.route("/api/tts/sample/<speaker>", methods=["GET"])
664
- def tts_play_sample(speaker: str):
665
- return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
666
-
667
-
668
  @app.route("/api/chromadb", methods=["POST"])
669
  @require_module("chromadb")
670
  def chromadb_add_messages():
@@ -756,22 +458,4 @@ def chromadb_query():
756
 
757
  return jsonify(messages)
758
 
759
-
760
- if args.share:
761
- from flask_cloudflared import _run_cloudflared
762
- import inspect
763
-
764
- sig = inspect.signature(_run_cloudflared)
765
- sum = sum(
766
- 1
767
- for param in sig.parameters.values()
768
- if param.kind == param.POSITIONAL_OR_KEYWORD
769
- )
770
- if sum > 1:
771
- metrics_port = randint(8100, 9000)
772
- cloudflare = _run_cloudflared(port, metrics_port)
773
- else:
774
- cloudflare = _run_cloudflared(port)
775
- print("Running on", cloudflare)
776
-
777
  app.run(host=host, port=port)
 
97
 
98
  args = parser.parse_args()
99
 
100
+ port = 7860
101
+ host = "0.0.0.0"
102
+
103
  summarization_model = (
104
  args.summarization_model
105
  if args.summarization_model
 
110
  if args.classification_model
111
  else DEFAULT_CLASSIFICATION_MODEL
112
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  modules = (
115
  args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else []
 
126
  device = torch.device(device_string)
127
  torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
128
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  if "summarize" in modules:
130
  print("Initializing a text summarization model...")
131
  summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
 
143
  torch_dtype=torch_dtype,
144
  )
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  if "chromadb" in modules:
147
  print("Initializing ChromaDB")
148
  import chromadb
 
186
  return sorted(output, key=lambda x: x["score"], reverse=True)
187
 
188
 
 
 
 
 
 
 
 
 
 
189
  def summarize_chunks(text: str, params: dict) -> str:
190
  try:
191
  return summarize(text, params)
 
231
  output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
232
  return output
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  @app.before_request
235
  # Request time measuring
236
  def before_request():
 
363
  print("Prompt output:", prompts, sep="\n")
364
  return jsonify({"prompts": prompts})
365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  @app.route("/api/modules", methods=["GET"])
367
  def get_modules():
368
  return jsonify({"modules": modules})
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
  @app.route("/api/chromadb", methods=["POST"])
371
  @require_module("chromadb")
372
  def chromadb_add_messages():
 
458
 
459
  return jsonify(messages)
460
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  app.run(host=host, port=port)