Spaces:
Runtime error
Runtime error
pseudotensor
commited on
Commit
·
ac63b1e
1
Parent(s):
7a7ff47
Update with h2oGPT hash cf3886c550581e34d9f05d69d2e3438b2a46d7b2
Browse files- generate.py +46 -38
generate.py
CHANGED
@@ -5,6 +5,8 @@ import traceback
|
|
5 |
import typing
|
6 |
from threading import Thread
|
7 |
|
|
|
|
|
8 |
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial
|
9 |
|
10 |
SEED = 1236
|
@@ -809,46 +811,52 @@ def evaluate(
|
|
809 |
)
|
810 |
|
811 |
with torch.no_grad():
|
812 |
-
#
|
813 |
-
|
814 |
-
|
815 |
-
if
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
-
|
821 |
-
|
822 |
-
|
823 |
-
|
824 |
-
|
825 |
-
|
826 |
-
|
827 |
-
|
828 |
-
|
829 |
-
|
830 |
-
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
|
835 |
-
|
836 |
-
|
837 |
-
|
838 |
-
|
839 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
840 |
yield prompter.get_response(outputs, prompt=inputs_decoded,
|
841 |
sanitize_bot_response=sanitize_bot_response)
|
842 |
-
|
843 |
-
|
844 |
-
|
845 |
-
|
846 |
-
yield prompter.get_response(outputs, prompt=inputs_decoded,
|
847 |
-
sanitize_bot_response=sanitize_bot_response)
|
848 |
-
if outputs and len(outputs) >= 1:
|
849 |
-
decoded_output = prompt + outputs[0]
|
850 |
-
if save_dir and decoded_output:
|
851 |
-
save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
|
852 |
|
853 |
|
854 |
def generate_with_exceptions(func, prompt, inputs_decoded, raise_generate_gpu_exceptions, **kwargs):
|
|
|
5 |
import typing
|
6 |
from threading import Thread
|
7 |
|
8 |
+
import filelock
|
9 |
+
|
10 |
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial
|
11 |
|
12 |
SEED = 1236
|
|
|
811 |
)
|
812 |
|
813 |
with torch.no_grad():
|
814 |
+
# protection for gradio not keeping track of closed users,
|
815 |
+
# else hit bitsandbytes lack of thread safety:
|
816 |
+
# https://github.com/h2oai/h2ogpt/issues/104
|
817 |
+
# but only makes sense if concurrency_count == 1
|
818 |
+
context_class = NullContext if concurrency_count > 1 else filelock.FileLock
|
819 |
+
with context_class("generate.lock"):
|
820 |
+
# decoded tokenized prompt can deviate from prompt due to special characters
|
821 |
+
inputs_decoded = decoder(input_ids[0])
|
822 |
+
inputs_decoded_raw = decoder_raw(input_ids[0])
|
823 |
+
if inputs_decoded == prompt:
|
824 |
+
# normal
|
825 |
+
pass
|
826 |
+
elif inputs_decoded.lstrip() == prompt.lstrip():
|
827 |
+
# sometimes extra space in front, make prompt same for prompt removal
|
828 |
+
prompt = inputs_decoded
|
829 |
+
elif inputs_decoded_raw == prompt:
|
830 |
+
# some models specify special tokens that are part of normal prompt, so can't skip them
|
831 |
+
inputs_decoded_raw = inputs_decoded
|
832 |
+
decoder = decoder_raw
|
833 |
+
else:
|
834 |
+
print("WARNING: Special characters in prompt", flush=True)
|
835 |
+
decoded_output = None
|
836 |
+
if stream_output:
|
837 |
+
skip_prompt = False
|
838 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
|
839 |
+
gen_kwargs.update(dict(streamer=streamer))
|
840 |
+
target_func = generate_with_exceptions
|
841 |
+
target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
|
842 |
+
raise_generate_gpu_exceptions, **gen_kwargs)
|
843 |
+
thread = Thread(target=target)
|
844 |
+
thread.start()
|
845 |
+
outputs = ""
|
846 |
+
for new_text in streamer:
|
847 |
+
outputs += new_text
|
848 |
+
yield prompter.get_response(outputs, prompt=inputs_decoded,
|
849 |
+
sanitize_bot_response=sanitize_bot_response)
|
850 |
+
decoded_output = outputs
|
851 |
+
else:
|
852 |
+
outputs = model.generate(**gen_kwargs)
|
853 |
+
outputs = [decoder(s) for s in outputs.sequences]
|
854 |
yield prompter.get_response(outputs, prompt=inputs_decoded,
|
855 |
sanitize_bot_response=sanitize_bot_response)
|
856 |
+
if outputs and len(outputs) >= 1:
|
857 |
+
decoded_output = prompt + outputs[0]
|
858 |
+
if save_dir and decoded_output:
|
859 |
+
save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
860 |
|
861 |
|
862 |
def generate_with_exceptions(func, prompt, inputs_decoded, raise_generate_gpu_exceptions, **kwargs):
|