openfree committed on
Commit
9affa6d
·
verified ·
1 Parent(s): c68a920

Update app.py

Files changed (1)
  1. app.py +96 -37
app.py CHANGED
@@ -30,6 +30,7 @@ import subprocess
 import pytesseract
 from pdf2image import convert_from_path
 import queue  # Added: to handle the queue.Empty exception
+import time   # Added: for streaming timing
 
 # -------------------- Added: imports for PDF-to-Markdown conversion --------------------
 try:
@@ -545,10 +546,15 @@ def clear_cuda_memory():
 @spaces.GPU
 def load_model():
     try:
+        # Clear memory first
+        clear_cuda_memory()
+
         loaded_model = AutoModelForCausalLM.from_pretrained(
             MODEL_ID,
             torch_dtype=torch.bfloat16,
             device_map="auto",
+            # Added setting to reduce memory usage
+            low_cpu_mem_usage=True,
         )
         return loaded_model
     except Exception as e:
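`clear_cuda_memory()` is defined earlier in app.py (see the hunk header above) but its body is not part of this diff. A minimal sketch of what such a helper typically contains, assuming it only wraps Python garbage collection and PyTorch's cache release:

```python
import gc
import torch

def clear_cuda_memory():
    """Hypothetical sketch, not the version in app.py."""
    gc.collect()                      # free unreachable Python objects first
    if torch.cuda.is_available():
        torch.cuda.empty_cache()      # return cached blocks to the CUDA allocator
        torch.cuda.ipc_collect()      # release inter-process CUDA handles
```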
@@ -628,19 +634,22 @@ def stream_chat(
         if len(history) > max_history_length:
             history = history[-max_history_length:]
 
+        # Look up Wikipedia context
+        wiki_context = ""
         try:
             relevant_contexts = find_relevant_context(message)
-            wiki_context = "\n\nRelated Wikipedia information:\n"
-            for ctx in relevant_contexts:
-                wiki_context += (
-                    f"Q: {ctx['question']}\n"
-                    f"A: {ctx['answer']}\n"
-                    f"Similarity: {ctx['similarity']:.3f}\n\n"
-                )
+            if relevant_contexts:  # only add when there are results
+                wiki_context = "\n\nRelated Wikipedia information:\n"
+                for ctx in relevant_contexts:
+                    wiki_context += (
+                        f"Q: {ctx['question']}\n"
+                        f"A: {ctx['answer']}\n"
+                        f"Similarity: {ctx['similarity']:.3f}\n\n"
+                    )
         except Exception as e:
             print(f"Context search error: {str(e)}")
-            wiki_context = ""
 
+        # Build the conversation history
         conversation = []
         for prompt, answer in history:
             conversation.extend([
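`find_relevant_context()` is defined elsewhere in app.py and is not touched by this diff; the new `if relevant_contexts:` guard only assumes it returns a possibly empty list of dicts with `question`, `answer`, and `similarity` keys. A hypothetical sketch of such a lookup over pre-computed embeddings, purely to illustrate that shape:

```python
import numpy as np

def find_relevant_context(message, wiki_qa, embed, top_k=3, min_sim=0.3):
    """Hypothetical sketch: wiki_qa is a list of {"question", "answer", "embedding"}
    entries and embed(text) returns a unit-normalized query vector."""
    query = embed(message)
    scored = []
    for item in wiki_qa:
        sim = float(np.dot(query, item["embedding"]))  # cosine similarity of unit vectors
        if sim >= min_sim:
            scored.append({
                "question": item["question"],
                "answer": item["answer"],
                "similarity": sim,
            })
    scored.sort(key=lambda x: x["similarity"], reverse=True)
    return scored[:top_k]  # empty list when nothing clears the threshold
```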
@@ -648,43 +657,61 @@ def stream_chat(
                 {"role": "assistant", "content": answer}
             ])
 
-        final_message = file_context + wiki_context + "\nCurrent question: " + message
+        # Build the final message
+        final_message = message
+        if file_context:
+            final_message = file_context + "\nCurrent question: " + message
+        if wiki_context:
+            final_message = wiki_context + "\nCurrent question: " + message
+        if file_context and wiki_context:
+            final_message = file_context + wiki_context + "\nCurrent question: " + message
+
         conversation.append({"role": "user", "content": final_message})
 
+        # Build the prompt and tokenize
         input_ids_str = build_prompt(conversation)
-        # First truncate to within 6000 tokens
-        input_ids_str = _truncate_tokens_for_context(input_ids_str, 6000)
-
-        inputs = tokenizer(input_ids_str, return_tensors="pt").to("cuda")
+
+        # Check and limit the context length first
         max_context = 8192
-        input_length = inputs["input_ids"].shape[1]
-        remaining = max_context - input_length
-
-        min_generation = 128
-        # If the remaining token budget is smaller than min_generation, truncate the input further.
-        if remaining < min_generation:
+        tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
+        input_length = tokenized_input["input_ids"].shape[1]
+
+        # Truncate if the context is too long
+        if input_length > max_context - max_new_tokens:
+            print(f"Input is too long: {input_length} tokens. Truncating...")
+            # Reserve a minimum number of generation tokens
+            min_generation = min(256, max_new_tokens)
             new_desired_input_length = max_context - min_generation
-            if new_desired_input_length < 1:
-                new_desired_input_length = 1
-            print(f"[Warning] Input too long; readjusting input_length={input_length} -> {new_desired_input_length}")
-            input_ids_str = _truncate_tokens_for_context(input_ids_str, new_desired_input_length)
-            inputs = tokenizer(input_ids_str, return_tensors="pt").to("cuda")
-            input_length = inputs["input_ids"].shape[1]
-            remaining = max_context - input_length
-
-        # Make sure max_new_tokens cannot go negative
-        if remaining < 1:
-            remaining = 1
+
+            # Truncate the input text at the token level
+            tokens = tokenizer.encode(input_ids_str)
+            if len(tokens) > new_desired_input_length:
+                tokens = tokens[-new_desired_input_length:]
+                input_ids_str = tokenizer.decode(tokens)
+
+            # Re-tokenize
+            tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
+            input_length = tokenized_input["input_ids"].shape[1]
+
+        print(f"Final input length: {input_length} tokens")
+
+        # Move the inputs to CUDA
+        inputs = tokenized_input.to("cuda")
+
+        # Compute the remaining token budget and adjust max_new_tokens
+        remaining = max_context - input_length
         if remaining < max_new_tokens:
-            print(f"[Warning] Input uses many tokens; adjusting max_new_tokens={max_new_tokens} -> {remaining}.")
+            print(f"Adjusting max_new_tokens: {max_new_tokens} -> {remaining}")
             max_new_tokens = remaining
 
         print(f"CUDA memory after creating input tensors: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
 
+        # Set up the streamer
         streamer = TextIteratorStreamer(
-            tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True
+            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
         )
 
+        # Generation parameters
         generate_kwargs = dict(
             **inputs,
             streamer=streamer,
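The change above replaces the old fixed 6000-token cut (via `_truncate_tokens_for_context`) with an inline budget check against `max_context` and `max_new_tokens`. The same logic as a small reusable helper, a sketch only; the name `fit_prompt_to_context` is not part of app.py:

```python
def fit_prompt_to_context(prompt, tokenizer, max_context=8192,
                          max_new_tokens=512, min_generation=256):
    """Hypothetical helper mirroring the inlined truncation above: keep only the
    most recent tokens so the prompt plus a generation budget fits max_context."""
    budget = max_context - min(min_generation, max_new_tokens)
    tokens = tokenizer.encode(prompt)
    if len(tokens) <= budget:
        return prompt
    # Keep the tail, which carries the current question and latest context.
    return tokenizer.decode(tokens[-budget:])
```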
@@ -694,23 +721,51 @@
             max_new_tokens=max_new_tokens,
             do_sample=True,
             temperature=temperature,
-            eos_token_id=255001,
+            eos_token_id=tokenizer.eos_token_id,  # set the EOS token explicitly
         )
 
+        # Clear memory
         clear_cuda_memory()
 
+        # Run generation in a separate thread
         thread = Thread(target=model.generate, kwargs=generate_kwargs)
         thread.start()
 
+        # Stream the response
         buffer = ""
+        partial_message = ""
+        last_yield_time = time.time()
+
         try:
             for new_text in streamer:
                 buffer += new_text
+                partial_message += new_text
+
+                # Update the result at a fixed interval or once enough text has accumulated
+                current_time = time.time()
+                if current_time - last_yield_time > 0.1 or len(partial_message) > 20:
+                    yield "", history + [[message, buffer]]
+                    partial_message = ""
+                    last_yield_time = current_time
+
+            # Yield the final response
+            if buffer:
                 yield "", history + [[message, buffer]]
-        except queue.Empty:
-            print("Streamer timed out. Returning the final response.")
+
+            # Save to the conversation log
+            chat_history.add_conversation(message, buffer)
+
+        except Exception as e:
+            print(f"Error during streaming: {str(e)}")
+            if not buffer:  # if the buffer is empty, show an error message
+                buffer = f"An error occurred while generating the response: {str(e)}"
             yield "", history + [[message, buffer]]
-
+
+        # If the thread is still running, wait for it to finish
+        if thread.is_alive():
+            thread.join(timeout=5.0)
+
+        # Clear memory
         clear_cuda_memory()
 
     except Exception as e:
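The new streaming loop flushes partial output at most every 0.1 s or roughly every 20 characters instead of on every token, and the `queue.Empty` timeout handler gives way to a broad exception handler plus a bounded `thread.join`. A minimal, self-contained sketch of that throttling rule (the `fake_token_stream` generator is purely illustrative):

```python
import time

def fake_token_stream():
    # Stand-in for TextIteratorStreamer: emits small text chunks.
    for chunk in ["Hel", "lo", ", ", "wor", "ld", "!"]:
        time.sleep(0.03)
        yield chunk

def throttled_updates(stream, min_interval=0.1, min_chars=20):
    """Yield the accumulated buffer only when enough time or text has passed."""
    buffer, pending, last = "", "", time.time()
    for chunk in stream:
        buffer += chunk
        pending += chunk
        now = time.time()
        if now - last > min_interval or len(pending) > min_chars:
            yield buffer
            pending, last = "", now
    if buffer:  # always flush the final state
        yield buffer

for update in throttled_updates(fake_token_stream()):
    print(update)
```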
@@ -825,6 +880,10 @@ def create_demo():
     )
 
     file_upload.change(
+        fn=lambda: ("Processing...", [["System", "Analyzing the file. Please wait a moment..."]]),
+        outputs=[msg, chatbot],
+        queue=False
+    ).then(
         fn=init_msg,
         outputs=msg,
         queue=False
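The upload handler becomes an event chain: a fast placeholder update followed by the real work via Gradio's `.then()`. A minimal sketch of the same pattern in isolation, assuming a recent Gradio API; the component names and handlers here are illustrative, not the ones in app.py:

```python
import gradio as gr

def show_placeholder():
    # Returned immediately so the UI reacts before the slow step runs.
    return "Processing...", [["System", "Analyzing the file, please wait..."]]

def analyze_file(file):
    # Stand-in for the real analysis step.
    return "Finished analyzing the uploaded file." if file else "No file uploaded."

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    file_upload = gr.File()

    # The first event returns right away; .then() chains the slower handler.
    file_upload.change(
        fn=show_placeholder, outputs=[msg, chatbot], queue=False
    ).then(
        fn=analyze_file, inputs=file_upload, outputs=msg
    )

if __name__ == "__main__":
    demo.launch()
```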
@@ -846,4 +905,4 @@ def create_demo():
 
 if __name__ == "__main__":
     demo = create_demo()
-    demo.launch()
+    demo.launch()
 