pseudotensor commited on
Commit
ac63b1e
·
1 Parent(s): 7a7ff47

Update with h2oGPT hash cf3886c550581e34d9f05d69d2e3438b2a46d7b2

Browse files
Files changed (1) hide show
  1. generate.py +46 -38
generate.py CHANGED
@@ -5,6 +5,8 @@ import traceback
5
  import typing
6
  from threading import Thread
7
 
 
 
8
  from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial
9
 
10
  SEED = 1236
@@ -809,46 +811,52 @@ def evaluate(
809
  )
810
 
811
  with torch.no_grad():
812
- # decoded tokenized prompt can deviate from prompt due to special characters
813
- inputs_decoded = decoder(input_ids[0])
814
- inputs_decoded_raw = decoder_raw(input_ids[0])
815
- if inputs_decoded == prompt:
816
- # normal
817
- pass
818
- elif inputs_decoded.lstrip() == prompt.lstrip():
819
- # sometimes extra space in front, make prompt same for prompt removal
820
- prompt = inputs_decoded
821
- elif inputs_decoded_raw == prompt:
822
- # some models specify special tokens that are part of normal prompt, so can't skip them
823
- inputs_decoded_raw = inputs_decoded
824
- decoder = decoder_raw
825
- else:
826
- print("WARNING: Special characters in prompt", flush=True)
827
- decoded_output = None
828
- if stream_output:
829
- skip_prompt = False
830
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
831
- gen_kwargs.update(dict(streamer=streamer))
832
- target_func = generate_with_exceptions
833
- target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
834
- raise_generate_gpu_exceptions, **gen_kwargs)
835
- thread = Thread(target=target)
836
- thread.start()
837
- outputs = ""
838
- for new_text in streamer:
839
- outputs += new_text
 
 
 
 
 
 
 
 
 
 
 
 
840
  yield prompter.get_response(outputs, prompt=inputs_decoded,
841
  sanitize_bot_response=sanitize_bot_response)
842
- decoded_output = outputs
843
- else:
844
- outputs = model.generate(**gen_kwargs)
845
- outputs = [decoder(s) for s in outputs.sequences]
846
- yield prompter.get_response(outputs, prompt=inputs_decoded,
847
- sanitize_bot_response=sanitize_bot_response)
848
- if outputs and len(outputs) >= 1:
849
- decoded_output = prompt + outputs[0]
850
- if save_dir and decoded_output:
851
- save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
852
 
853
 
854
  def generate_with_exceptions(func, prompt, inputs_decoded, raise_generate_gpu_exceptions, **kwargs):
 
5
  import typing
6
  from threading import Thread
7
 
8
+ import filelock
9
+
10
  from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial
11
 
12
  SEED = 1236
 
811
  )
812
 
813
  with torch.no_grad():
814
+ # protection for gradio not keeping track of closed users,
815
+ # else hit bitsandbytes lack of thread safety:
816
+ # https://github.com/h2oai/h2ogpt/issues/104
817
+ # but only makes sense if concurrency_count == 1
818
+ context_class = NullContext if concurrency_count > 1 else filelock.FileLock
819
+ with context_class("generate.lock"):
820
+ # decoded tokenized prompt can deviate from prompt due to special characters
821
+ inputs_decoded = decoder(input_ids[0])
822
+ inputs_decoded_raw = decoder_raw(input_ids[0])
823
+ if inputs_decoded == prompt:
824
+ # normal
825
+ pass
826
+ elif inputs_decoded.lstrip() == prompt.lstrip():
827
+ # sometimes extra space in front, make prompt same for prompt removal
828
+ prompt = inputs_decoded
829
+ elif inputs_decoded_raw == prompt:
830
+ # some models specify special tokens that are part of normal prompt, so can't skip them
831
+ inputs_decoded_raw = inputs_decoded
832
+ decoder = decoder_raw
833
+ else:
834
+ print("WARNING: Special characters in prompt", flush=True)
835
+ decoded_output = None
836
+ if stream_output:
837
+ skip_prompt = False
838
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
839
+ gen_kwargs.update(dict(streamer=streamer))
840
+ target_func = generate_with_exceptions
841
+ target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
842
+ raise_generate_gpu_exceptions, **gen_kwargs)
843
+ thread = Thread(target=target)
844
+ thread.start()
845
+ outputs = ""
846
+ for new_text in streamer:
847
+ outputs += new_text
848
+ yield prompter.get_response(outputs, prompt=inputs_decoded,
849
+ sanitize_bot_response=sanitize_bot_response)
850
+ decoded_output = outputs
851
+ else:
852
+ outputs = model.generate(**gen_kwargs)
853
+ outputs = [decoder(s) for s in outputs.sequences]
854
  yield prompter.get_response(outputs, prompt=inputs_decoded,
855
  sanitize_bot_response=sanitize_bot_response)
856
+ if outputs and len(outputs) >= 1:
857
+ decoded_output = prompt + outputs[0]
858
+ if save_dir and decoded_output:
859
+ save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
 
 
 
 
 
 
860
 
861
 
862
  def generate_with_exceptions(func, prompt, inputs_decoded, raise_generate_gpu_exceptions, **kwargs):