Spaces:
Running
Running
pseudotensor
commited on
Commit
•
8910711
1
Parent(s):
2f10edd
Update with h2oGPT hash 5089a15c88b6f91136ce9c946677b658ffebf13a
Browse files- app.py +571 -280
- client_test.py +22 -50
- finetune.py +2 -2
- utils.py +11 -8
app.py
CHANGED
@@ -31,6 +31,8 @@ is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
|
|
31 |
is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
|
32 |
is_low_mem = is_hf # assumes run on 24GB consumer GPU
|
33 |
admin_pass = os.getenv("ADMIN_PASS")
|
|
|
|
|
34 |
|
35 |
|
36 |
def main(
|
@@ -40,7 +42,7 @@ def main(
|
|
40 |
base_model: str = '',
|
41 |
tokenizer_base_model: str = '',
|
42 |
lora_weights: str = "",
|
43 |
-
|
44 |
|
45 |
prompt_type: Union[int, str] = None,
|
46 |
# input to generation
|
@@ -144,7 +146,8 @@ def main(
|
|
144 |
# override default examples with shareGPT ones for human-level eval purposes only
|
145 |
filename = 'ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json'
|
146 |
if not os.path.isfile(filename):
|
147 |
-
os.system(
|
|
|
148 |
import json
|
149 |
data = json.load(open(filename, 'rt'))
|
150 |
# focus on data that starts with human, else likely chopped from other data
|
@@ -228,10 +231,11 @@ def main(
|
|
228 |
traceback.print_exc()
|
229 |
score = 0.0
|
230 |
clear_torch_cache()
|
231 |
-
except RuntimeError as e:
|
232 |
if 'Expected all tensors to be on the same device' in str(e) or \
|
233 |
'expected scalar type Half but found Float' in str(e) or \
|
234 |
-
'probability tensor contains either' in str(e)
|
|
|
235 |
print("GPU error: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
|
236 |
flush=True)
|
237 |
traceback.print_exc()
|
@@ -250,12 +254,13 @@ def main(
|
|
250 |
else:
|
251 |
used_base_model = str(base_model.split('/')[-1])
|
252 |
used_lora_weights = str(lora_weights.split('/')[-1])
|
253 |
-
df_scores = pd.DataFrame(score_dump,
|
|
|
254 |
filename = "df_scores_%s_%s_%s_%s_%s_%s.parquet" % (num_examples, eval_sharegpt_prompts_only,
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
filename = os.path.join(scoring_path, filename)
|
260 |
df_scores.to_parquet(filename, index=False)
|
261 |
# plot histogram so far
|
@@ -287,7 +292,9 @@ def get_device():
|
|
287 |
return device
|
288 |
|
289 |
|
290 |
-
def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
|
|
|
|
|
291 |
"""
|
292 |
Ensure model gets on correct device
|
293 |
:param base_model:
|
@@ -295,6 +302,8 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
|
|
295 |
:param load_half:
|
296 |
:param model_kwargs:
|
297 |
:param reward_type:
|
|
|
|
|
298 |
:return:
|
299 |
"""
|
300 |
with init_empty_weights():
|
@@ -319,14 +328,14 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
|
|
319 |
device_map.update(device_map_model)
|
320 |
print('device_map: %s' % device_map, flush=True)
|
321 |
|
322 |
-
if
|
323 |
# FIXME: If really distributes model, tend to get things like: ValueError: gpt_neox.embed_in.weight doesn't have any device set.
|
324 |
# So avoid for now, just put on first GPU, unless score_model, put on last
|
325 |
n_gpus = torch.cuda.device_count()
|
326 |
if reward_type:
|
327 |
device_map = {'': n_gpus - 1}
|
328 |
else:
|
329 |
-
device_map = {'':
|
330 |
|
331 |
load_in_8bit = model_kwargs.get('load_in_8bit', False)
|
332 |
model_kwargs['device_map'] = device_map
|
@@ -351,7 +360,7 @@ def get_model(
|
|
351 |
base_model: str = '',
|
352 |
tokenizer_base_model: str = '',
|
353 |
lora_weights: str = "",
|
354 |
-
|
355 |
|
356 |
llama_type: bool = None,
|
357 |
reward_type: bool = None,
|
@@ -371,7 +380,7 @@ def get_model(
|
|
371 |
:param base_model: name/path of base model
|
372 |
:param tokenizer_base_model: name/path of tokenizer
|
373 |
:param lora_weights: name/path
|
374 |
-
:param
|
375 |
:param llama_type: whether LLaMa type model
|
376 |
:param reward_type: reward type model for sequence classification
|
377 |
:param local_files_only: use local files instead of from HF
|
@@ -432,7 +441,7 @@ def get_model(
|
|
432 |
with torch.device("cuda"):
|
433 |
if infer_devices:
|
434 |
model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
|
435 |
-
|
436 |
else:
|
437 |
if load_half and not load_8bit:
|
438 |
model = model_loader.from_pretrained(
|
@@ -511,7 +520,6 @@ def get_score_model(**kwargs):
|
|
511 |
|
512 |
|
513 |
def go_gradio(**kwargs):
|
514 |
-
|
515 |
# get default model
|
516 |
all_kwargs = kwargs.copy()
|
517 |
all_kwargs.update(locals())
|
@@ -526,11 +534,10 @@ def go_gradio(**kwargs):
|
|
526 |
smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
|
527 |
|
528 |
if 'mbart-' in kwargs['model_lower']:
|
529 |
-
|
530 |
else:
|
531 |
-
|
532 |
-
|
533 |
-
instruction_label = "You (Shift-Enter or push Submit to send message)"
|
534 |
|
535 |
title = 'h2oGPT'
|
536 |
if kwargs['verbose']:
|
@@ -542,9 +549,9 @@ def go_gradio(**kwargs):
|
|
542 |
else:
|
543 |
description = "For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).<br>"
|
544 |
if is_public:
|
545 |
-
description += """<p><b> DISCLAIMERS: </b><ul><i><li>The
|
546 |
if kwargs['load_8bit']:
|
547 |
-
description += """<i><li> Model is loaded in 8-bit
|
548 |
description += """<i><li>Conversations may be used to improve h2oGPT. Do not share sensitive information.</i></li>"""
|
549 |
description += """<i><li>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md).</i></li></ul></p>"""
|
550 |
|
@@ -630,6 +637,7 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
|
|
630 |
return chat_message
|
631 |
else:
|
632 |
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
|
|
|
633 |
Chatbot._postprocess_chat_messages = _postprocess_chat_messages
|
634 |
|
635 |
demo = gr.Blocks(theme=gr.themes.Soft(**colors_dict), css=css_code, title="h2oGPT", analytics_enabled=False)
|
@@ -645,14 +653,32 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
|
|
645 |
lora_options = [kwargs['lora_weights'].strip()] + lora_options
|
646 |
# always add in no lora case
|
647 |
# add fake space so doesn't go away in gradio dropdown
|
648 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
649 |
|
650 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
651 |
|
652 |
with demo:
|
653 |
# avoid actual model/tokenizer here or anything that would be bad to deepcopy
|
654 |
# https://github.com/gradio-app/gradio/issues/3558
|
655 |
model_state = gr.State(['model', 'tokenizer', device, kwargs['base_model']])
|
|
|
656 |
model_options_state = gr.State([model_options])
|
657 |
lora_options_state = gr.State([lora_options])
|
658 |
gr.Markdown(
|
@@ -663,57 +689,69 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
|
|
663 |
{task_info_md}
|
664 |
""")
|
665 |
if is_hf:
|
666 |
-
gr.HTML(
|
|
|
667 |
|
668 |
# go button visible if
|
669 |
-
base_wanted =
|
670 |
go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
|
671 |
normal_block = gr.Row(visible=not base_wanted)
|
672 |
with normal_block:
|
673 |
with gr.Tabs():
|
674 |
with gr.Row():
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
683 |
flag_btn = gr.Button("Flag")
|
684 |
if kwargs['score_model']:
|
685 |
-
if not kwargs['auto_score']:
|
686 |
with gr.Column():
|
687 |
-
|
688 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
689 |
else:
|
690 |
score_text = gr.Textbox("Response Score: NA", show_label=False)
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
with gr.Row():
|
695 |
-
with gr.Column(scale=50):
|
696 |
-
instruction = gr.Textbox(
|
697 |
-
lines=4, label=instruction_label,
|
698 |
-
placeholder=kwargs['placeholder_instruction'],
|
699 |
-
)
|
700 |
-
with gr.Row(): # .style(equal_height=False, equal_width=False):
|
701 |
-
submit = gr.Button(value='Submit').style(full_width=False, size='sm')
|
702 |
-
stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
|
703 |
-
with gr.Row():
|
704 |
-
clear = gr.Button("New Conversation")
|
705 |
-
flag_btn = gr.Button("Flag")
|
706 |
-
if kwargs['score_model']:
|
707 |
-
if not kwargs['auto_score']:
|
708 |
-
with gr.Column():
|
709 |
-
score_btn = gr.Button("Score last prompt & response").style(full_width=False, size='sm')
|
710 |
-
score_text = gr.Textbox("Response Score: NA", show_label=False)
|
711 |
-
else:
|
712 |
-
score_text = gr.Textbox("Response Score: NA", show_label=False)
|
713 |
-
retry = gr.Button("Regenerate")
|
714 |
-
undo = gr.Button("Undo")
|
715 |
-
else:
|
716 |
-
text_output = gr.Textbox(lines=5, label=output_label0)
|
717 |
with gr.TabItem("Input/Output"):
|
718 |
with gr.Row():
|
719 |
if 'mbart-' in kwargs['model_lower']:
|
@@ -731,6 +769,11 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
|
|
731 |
prompt_type = gr.Dropdown(prompt_types_strings,
|
732 |
value=kwargs['prompt_type'], label="Prompt Type",
|
733 |
visible=not is_public)
|
|
|
|
|
|
|
|
|
|
|
734 |
temperature = gr.Slider(minimum=0, maximum=3,
|
735 |
value=kwargs['temperature'],
|
736 |
label="Temperature",
|
@@ -770,30 +813,45 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
|
|
770 |
value=kwargs['num_return_sequences'],
|
771 |
label="Number Returns", info="Must be <= num_beams",
|
772 |
visible=not is_public)
|
773 |
-
|
774 |
-
|
775 |
-
|
776 |
-
|
777 |
-
|
778 |
-
|
779 |
-
|
780 |
-
|
781 |
-
info="Ignored in chat mode.",
|
782 |
-
visible=not is_public)
|
783 |
|
784 |
with gr.TabItem("Models"):
|
|
|
|
|
|
|
|
|
|
|
|
|
785 |
with gr.Row():
|
|
|
|
|
786 |
with gr.Column():
|
787 |
with gr.Row(scale=1):
|
788 |
with gr.Column(scale=50):
|
789 |
-
model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model",
|
790 |
-
|
|
|
|
|
791 |
with gr.Column(scale=1):
|
792 |
-
load_msg = "Load Model/LORA" if not is_public \
|
793 |
-
else "LOAD DISABLED FOR HOSTED DEMO"
|
794 |
load_model_button = gr.Button(load_msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
795 |
model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
|
796 |
-
lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
|
|
|
797 |
with gr.Row(scale=1):
|
798 |
with gr.Column(scale=50):
|
799 |
new_model = gr.Textbox(label="New Model HF name/path")
|
@@ -801,6 +859,30 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
|
|
801 |
with gr.Column(scale=1):
|
802 |
add_model_button = gr.Button("Add new model name")
|
803 |
add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
804 |
with gr.TabItem("System"):
|
805 |
system_row = gr.Row(visible=not is_public)
|
806 |
admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
|
@@ -830,6 +912,9 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
|
|
830 |
kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
|
831 |
fun = partial(evaluate,
|
832 |
**kwargs_evaluate)
|
|
|
|
|
|
|
833 |
|
834 |
dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
|
835 |
size="sm",
|
@@ -847,193 +932,315 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
|
|
847 |
}""",
|
848 |
api_name="dark",
|
849 |
)
|
850 |
-
|
851 |
-
|
852 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
853 |
|
854 |
# examples after submit or any other buttons for chat or no chat
|
855 |
if kwargs['examples'] is not None and kwargs['show_examples']:
|
856 |
gr.Examples(examples=kwargs['examples'], inputs=inputs_list)
|
857 |
|
858 |
# Score
|
859 |
-
def score_last_response(*args):
|
860 |
""" Similar to user() """
|
861 |
args_list = list(args)
|
862 |
-
|
863 |
-
if
|
864 |
-
|
865 |
-
|
866 |
-
if
|
867 |
-
|
868 |
-
|
869 |
-
|
870 |
-
|
871 |
-
|
872 |
-
|
873 |
-
|
874 |
-
|
875 |
-
|
876 |
-
|
877 |
-
|
878 |
-
|
879 |
-
|
880 |
-
|
881 |
-
|
882 |
-
|
883 |
-
|
884 |
-
|
885 |
-
|
886 |
-
|
887 |
-
|
888 |
-
|
889 |
-
|
890 |
-
|
891 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
892 |
traceback.print_exc()
|
893 |
clear_torch_cache()
|
894 |
-
return 'Response Score: GPU
|
895 |
-
|
896 |
-
|
897 |
-
|
898 |
-
|
899 |
-
print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
|
900 |
-
traceback.print_exc()
|
901 |
-
clear_torch_cache()
|
902 |
-
return 'Response Score: GPU Error'
|
903 |
-
else:
|
904 |
-
raise
|
905 |
-
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
906 |
-
return 'Response Score: {:.1%}'.format(score)
|
907 |
-
else:
|
908 |
-
return 'Response Score: NA'
|
909 |
|
910 |
if kwargs['score_model']:
|
911 |
score_args = dict(fn=score_last_response,
|
912 |
inputs=inputs_list + [text_output],
|
913 |
outputs=[score_text],
|
914 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
915 |
if not kwargs['auto_score']:
|
916 |
-
score_event = score_btn.click(**score_args, queue=stream_output, api_name='score')
|
917 |
-
|
918 |
-
|
919 |
-
|
920 |
-
|
921 |
-
|
922 |
-
|
923 |
-
|
924 |
-
|
925 |
-
|
926 |
-
|
927 |
-
|
928 |
-
|
929 |
-
|
930 |
-
|
931 |
-
|
932 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
933 |
|
934 |
-
|
935 |
-
|
936 |
-
|
937 |
-
|
938 |
-
|
|
|
|
|
939 |
print("Bad history, fix for now", flush=True)
|
940 |
-
|
941 |
-
|
942 |
-
|
943 |
-
|
944 |
-
|
945 |
-
|
946 |
-
|
947 |
-
|
948 |
-
history
|
949 |
-
|
950 |
-
|
951 |
-
|
952 |
-
|
953 |
-
|
954 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
955 |
context1 = ''
|
956 |
-
|
957 |
-
|
958 |
-
context1 =
|
959 |
-
|
960 |
-
|
961 |
-
context1 +=
|
962 |
-
|
963 |
-
|
964 |
-
|
965 |
-
|
966 |
-
|
967 |
-
|
968 |
-
|
969 |
-
|
970 |
-
|
971 |
-
|
972 |
-
|
973 |
-
|
974 |
-
|
975 |
-
|
976 |
-
|
977 |
-
|
978 |
-
history[-1][1] = bot_message
|
979 |
-
yield history
|
980 |
-
except StopIteration:
|
981 |
yield history
|
982 |
-
|
983 |
-
|
984 |
-
|
985 |
-
|
986 |
-
|
987 |
-
|
988 |
-
except Exception as e:
|
989 |
-
# put error into user input
|
990 |
-
history[-1][0] = "Exception: %s" % str(e)
|
991 |
yield history
|
992 |
-
|
993 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
994 |
|
995 |
-
|
996 |
-
|
997 |
-
|
998 |
-
|
999 |
-
|
1000 |
-
|
1001 |
-
|
1002 |
-
|
1003 |
-
|
1004 |
-
|
1005 |
-
|
1006 |
-
|
1007 |
-
|
1008 |
-
|
1009 |
-
|
1010 |
-
|
1011 |
-
|
1012 |
-
|
1013 |
-
|
1014 |
-
|
1015 |
-
|
1016 |
-
|
1017 |
-
|
1018 |
-
|
1019 |
-
|
1020 |
-
|
1021 |
-
|
1022 |
-
|
1023 |
-
|
1024 |
-
|
1025 |
-
|
1026 |
-
|
1027 |
-
|
1028 |
-
|
1029 |
-
|
1030 |
-
|
1031 |
-
|
1032 |
-
|
1033 |
-
|
1034 |
-
|
1035 |
-
|
1036 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1037 |
# ensure old model removed from GPU memory
|
1038 |
if kwargs['debug']:
|
1039 |
print("Pre-switch pre-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
|
@@ -1058,23 +1265,35 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
|
|
1058 |
clear_torch_cache()
|
1059 |
if kwargs['debug']:
|
1060 |
print("Pre-switch post-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
|
1061 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1062 |
model_lower = model_name.strip().lower()
|
1063 |
if model_lower in inv_prompt_type_to_model_lower:
|
1064 |
prompt_type1 = inv_prompt_type_to_model_lower[model_lower]
|
1065 |
else:
|
1066 |
prompt_type1 = prompt_type_old
|
1067 |
|
1068 |
-
|
1069 |
-
|
|
|
|
|
|
|
|
|
1070 |
clear_torch_cache()
|
1071 |
|
1072 |
if kwargs['debug']:
|
1073 |
print("Post-switch GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
|
1074 |
-
return
|
1075 |
-
model_used: model_name,
|
1076 |
-
lora_used: lora_weights,
|
1077 |
-
prompt_type: prompt_type1}
|
1078 |
|
1079 |
def dropdown_prompt_type_list(x):
|
1080 |
return gr.Dropdown.update(value=x)
|
@@ -1083,54 +1302,90 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
|
|
1083 |
return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]')
|
1084 |
|
1085 |
load_model_args = dict(fn=load_model,
|
1086 |
-
inputs=[model_choice, lora_choice, model_state, prompt_type
|
|
|
1087 |
outputs=[model_state, model_used, lora_used, prompt_type])
|
1088 |
prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
|
1089 |
chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
|
1090 |
if not is_public:
|
1091 |
load_model_event = load_model_button.click(**load_model_args) \
|
1092 |
-
|
1093 |
-
|
1094 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1095 |
|
1096 |
def dropdown_model_list(list0, x):
|
1097 |
new_state = [list0[0] + [x]]
|
1098 |
new_options = [*new_state[0]]
|
1099 |
-
return gr.Dropdown.update(value=x, choices=new_options),
|
|
|
|
|
1100 |
|
1101 |
add_model_event = add_model_button.click(fn=dropdown_model_list,
|
1102 |
inputs=[model_options_state, new_model],
|
1103 |
-
outputs=[model_choice, new_model, model_options_state])
|
1104 |
|
1105 |
-
def dropdown_lora_list(list0, x):
|
1106 |
new_state = [list0[0] + [x]]
|
1107 |
new_options = [*new_state[0]]
|
1108 |
-
|
|
|
|
|
|
|
|
|
|
|
1109 |
|
1110 |
add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
|
1111 |
-
inputs=[lora_options_state, new_lora],
|
1112 |
-
outputs=[lora_choice, new_lora, lora_options_state])
|
1113 |
|
1114 |
go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go") \
|
1115 |
.then(lambda: gr.update(visible=True), None, normal_block) \
|
1116 |
.then(**load_model_args).then(**prompt_update_args)
|
1117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1118 |
# callback for logging flagged input/output
|
1119 |
callback.setup(inputs_list + [text_output], "flagged_data_points")
|
1120 |
flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False,
|
1121 |
api_name='flag')
|
|
|
|
|
1122 |
|
1123 |
def get_system_info():
|
1124 |
return gr.Textbox.update(value=system_info_print())
|
1125 |
|
1126 |
system_event = system_btn.click(get_system_info, outputs=system_text, api_name='system_info')
|
1127 |
|
1128 |
-
|
1129 |
-
|
1130 |
-
|
1131 |
-
|
1132 |
-
|
1133 |
-
queue=False, api_name='stop').then(clear_torch_cache)
|
1134 |
|
1135 |
demo.queue(concurrency_count=1)
|
1136 |
favicon_path = "h2o-logo.svg"
|
@@ -1141,10 +1396,16 @@ body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
|
|
1141 |
|
1142 |
|
1143 |
input_args_list = ['model_state']
|
1144 |
-
inputs_kwargs_list = ['debug', '
|
1145 |
|
1146 |
|
1147 |
def get_inputs_list(inputs_dict, model_lower):
|
|
|
|
|
|
|
|
|
|
|
|
|
1148 |
inputs_list_names = list(inspect.signature(evaluate).parameters)
|
1149 |
inputs_list = []
|
1150 |
for k in inputs_list_names:
|
@@ -1159,9 +1420,6 @@ def get_inputs_list(inputs_dict, model_lower):
|
|
1159 |
return inputs_list
|
1160 |
|
1161 |
|
1162 |
-
# index of prompt_type in evaluate function, after model_state
|
1163 |
-
prompt_type_arg_id = 4
|
1164 |
-
|
1165 |
eval_func_param_names = ['instruction',
|
1166 |
'iinput',
|
1167 |
'context',
|
@@ -1178,6 +1436,9 @@ eval_func_param_names = ['instruction',
|
|
1178 |
'repetition_penalty',
|
1179 |
'num_return_sequences',
|
1180 |
'do_sample',
|
|
|
|
|
|
|
1181 |
]
|
1182 |
|
1183 |
|
@@ -1200,12 +1461,14 @@ def evaluate(
|
|
1200 |
repetition_penalty,
|
1201 |
num_return_sequences,
|
1202 |
do_sample,
|
|
|
|
|
|
|
1203 |
# END NOTE: Examples must have same order of parameters
|
1204 |
src_lang=None,
|
1205 |
tgt_lang=None,
|
1206 |
debug=False,
|
1207 |
save_dir=None,
|
1208 |
-
chat=False,
|
1209 |
hard_stop_list=None,
|
1210 |
sanitize_bot_response=True,
|
1211 |
model_state0=None,
|
@@ -1214,10 +1477,15 @@ def evaluate(
|
|
1214 |
if debug:
|
1215 |
locals_dict = locals().copy()
|
1216 |
locals_dict.pop('model_state', None)
|
|
|
1217 |
print(locals_dict)
|
1218 |
|
1219 |
no_model_msg = "Please choose a base model with --base_model (CLI) or in Models Tab (gradio).\nThen start New Conversation"
|
1220 |
|
|
|
|
|
|
|
|
|
1221 |
if model_state is not None and len(model_state) == 4 and not isinstance(model_state[0], str):
|
1222 |
# try to free-up original model (i.e. list was passed as reference)
|
1223 |
if model_state0 is not None and model_state0[0] is not None:
|
@@ -1234,10 +1502,18 @@ def evaluate(
|
|
1234 |
else:
|
1235 |
raise AssertionError(no_model_msg)
|
1236 |
|
|
|
|
|
|
|
1237 |
assert base_model.strip(), no_model_msg
|
1238 |
assert model, "Model is missing"
|
1239 |
assert tokenizer, "Tokenizer is missing"
|
1240 |
|
|
|
|
|
|
|
|
|
|
|
1241 |
data_point = dict(context=context, instruction=instruction, input=iinput)
|
1242 |
prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
|
1243 |
prompt = prompter.generate_prompt(data_point)
|
@@ -1272,16 +1548,16 @@ def evaluate(
|
|
1272 |
elif prompt_type == 'instruct_vicuna':
|
1273 |
# even below is not enough, generic strings and many ways to encode
|
1274 |
stop_words = [
|
1275 |
-
|
1276 |
-
|
1277 |
### Human:""",
|
1278 |
-
|
1279 |
### Human:
|
1280 |
""",
|
1281 |
-
|
1282 |
-
|
1283 |
### Assistant:""",
|
1284 |
-
|
1285 |
### Assistant:
|
1286 |
""",
|
1287 |
]
|
@@ -1299,7 +1575,7 @@ def evaluate(
|
|
1299 |
if tokenizer.pad_token:
|
1300 |
stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
|
1301 |
# handle fake \n added
|
1302 |
-
stop_words_ids = [x[1:] if y[0] == '\n' else x for x,y in zip(stop_words_ids, stop_words)]
|
1303 |
# build stopper
|
1304 |
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters)])
|
1305 |
else:
|
@@ -1397,15 +1673,18 @@ def evaluate(
|
|
1397 |
traceback.print_exc()
|
1398 |
clear_torch_cache()
|
1399 |
return
|
1400 |
-
except RuntimeError as e:
|
1401 |
if 'Expected all tensors to be on the same device' in str(e) or \
|
1402 |
'expected scalar type Half but found Float' in str(e) or \
|
1403 |
-
'probability tensor contains either' in str(e)
|
|
|
1404 |
print(
|
1405 |
"GPU Error: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
|
1406 |
flush=True)
|
1407 |
traceback.print_exc()
|
1408 |
clear_torch_cache()
|
|
|
|
|
1409 |
return
|
1410 |
else:
|
1411 |
raise
|
@@ -1516,7 +1795,8 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
|
|
1516 |
else:
|
1517 |
prompt_type = ''
|
1518 |
examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
|
1519 |
-
stream_output, prompt_type or 'plain', 0.1, 0.75, 40, 4, 256, 0, False, max_time_defaults, 1.0, 1,
|
|
|
1520 |
task_info = "No task"
|
1521 |
if prompt_type == 'instruct':
|
1522 |
task_info = "Answer question or follow imperative as instruction with optionally input."
|
@@ -1594,6 +1874,17 @@ y = np.random.randint(0, 1, 100)
|
|
1594 |
src_lang = "English"
|
1595 |
tgt_lang = "Russian"
|
1596 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1597 |
return placeholder_instruction, placeholder_input, \
|
1598 |
stream_output, show_examples, \
|
1599 |
prompt_type, temperature, top_p, top_k, num_beams, \
|
|
|
31 |
is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
|
32 |
is_low_mem = is_hf # assumes run on 24GB consumer GPU
|
33 |
admin_pass = os.getenv("ADMIN_PASS")
|
34 |
+
# will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
|
35 |
+
raise_generate_gpu_exceptions = True
|
36 |
|
37 |
|
38 |
def main(
|
|
|
42 |
base_model: str = '',
|
43 |
tokenizer_base_model: str = '',
|
44 |
lora_weights: str = "",
|
45 |
+
gpu_id: int = 0, # if infer_devices = True and gpu_id != -1
|
46 |
|
47 |
prompt_type: Union[int, str] = None,
|
48 |
# input to generation
|
|
|
146 |
# override default examples with shareGPT ones for human-level eval purposes only
|
147 |
filename = 'ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json'
|
148 |
if not os.path.isfile(filename):
|
149 |
+
os.system(
|
150 |
+
'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename)
|
151 |
import json
|
152 |
data = json.load(open(filename, 'rt'))
|
153 |
# focus on data that starts with human, else likely chopped from other data
|
|
|
231 |
traceback.print_exc()
|
232 |
score = 0.0
|
233 |
clear_torch_cache()
|
234 |
+
except (Exception, RuntimeError) as e:
|
235 |
if 'Expected all tensors to be on the same device' in str(e) or \
|
236 |
'expected scalar type Half but found Float' in str(e) or \
|
237 |
+
'probability tensor contains either' in str(e) or \
|
238 |
+
'cublasLt ran into an error!' in str(e):
|
239 |
print("GPU error: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
|
240 |
flush=True)
|
241 |
traceback.print_exc()
|
|
|
254 |
else:
|
255 |
used_base_model = str(base_model.split('/')[-1])
|
256 |
used_lora_weights = str(lora_weights.split('/')[-1])
|
257 |
+
df_scores = pd.DataFrame(score_dump,
|
258 |
+
columns=eval_func_param_names + ['prompt', 'response', 'score'])
|
259 |
filename = "df_scores_%s_%s_%s_%s_%s_%s.parquet" % (num_examples, eval_sharegpt_prompts_only,
|
260 |
+
eval_sharegpt_prompts_only_seed,
|
261 |
+
eval_sharegpt_as_output,
|
262 |
+
used_base_model,
|
263 |
+
used_lora_weights)
|
264 |
filename = os.path.join(scoring_path, filename)
|
265 |
df_scores.to_parquet(filename, index=False)
|
266 |
# plot histogram so far
|
|
|
292 |
return device
|
293 |
|
294 |
|
295 |
+
def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
|
296 |
+
gpu_id=0,
|
297 |
+
use_auth_token=False):
|
298 |
"""
|
299 |
Ensure model gets on correct device
|
300 |
:param base_model:
|
|
|
302 |
:param load_half:
|
303 |
:param model_kwargs:
|
304 |
:param reward_type:
|
305 |
+
:param gpu_id:
|
306 |
+
:param use_auth_token:
|
307 |
:return:
|
308 |
"""
|
309 |
with init_empty_weights():
|
|
|
328 |
device_map.update(device_map_model)
|
329 |
print('device_map: %s' % device_map, flush=True)
|
330 |
|
331 |
+
if gpu_id >= 0:
|
332 |
# FIXME: If really distributes model, tend to get things like: ValueError: gpt_neox.embed_in.weight doesn't have any device set.
|
333 |
# So avoid for now, just put on first GPU, unless score_model, put on last
|
334 |
n_gpus = torch.cuda.device_count()
|
335 |
if reward_type:
|
336 |
device_map = {'': n_gpus - 1}
|
337 |
else:
|
338 |
+
device_map = {'': min(n_gpus - 1, gpu_id)}
|
339 |
|
340 |
load_in_8bit = model_kwargs.get('load_in_8bit', False)
|
341 |
model_kwargs['device_map'] = device_map
|
|
|
360 |
base_model: str = '',
|
361 |
tokenizer_base_model: str = '',
|
362 |
lora_weights: str = "",
|
363 |
+
gpu_id: int = 0,
|
364 |
|
365 |
llama_type: bool = None,
|
366 |
reward_type: bool = None,
|
|
|
380 |
:param base_model: name/path of base model
|
381 |
:param tokenizer_base_model: name/path of tokenizer
|
382 |
:param lora_weights: name/path
|
383 |
+
:param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1)
|
384 |
:param llama_type: whether LLaMa type model
|
385 |
:param reward_type: reward type model for sequence classification
|
386 |
:param local_files_only: use local files instead of from HF
|
|
|
441 |
with torch.device("cuda"):
|
442 |
if infer_devices:
|
443 |
model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
|
444 |
+
gpu_id=gpu_id, use_auth_token=use_auth_token)
|
445 |
else:
|
446 |
if load_half and not load_8bit:
|
447 |
model = model_loader.from_pretrained(
|
|
|
520 |
|
521 |
|
522 |
def go_gradio(**kwargs):
|
|
|
523 |
# get default model
|
524 |
all_kwargs = kwargs.copy()
|
525 |
all_kwargs.update(locals())
|
|
|
534 |
smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
|
535 |
|
536 |
if 'mbart-' in kwargs['model_lower']:
|
537 |
+
instruction_label_nochat = "Text to translate"
|
538 |
else:
|
539 |
+
instruction_label_nochat = "Instruction"
|
540 |
+
instruction_label = "You (Shift-Enter or push Submit to send message)"
|
|
|
541 |
|
542 |
title = 'h2oGPT'
|
543 |
if kwargs['verbose']:
|
|
|
549 |
else:
|
550 |
description = "For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).<br>"
|
551 |
if is_public:
|
552 |
+
description += """<p><b> DISCLAIMERS: </b><ul><i><li>The model was trained on The Pile and other data, which may contain objectionable content. Use at own risk.</i></li>"""
|
553 |
if kwargs['load_8bit']:
|
554 |
+
description += """<i><li> Model is loaded in 8-bit and has other restrictions on this host. UX can be worse than non-hosted version.</i></li>"""
|
555 |
description += """<i><li>Conversations may be used to improve h2oGPT. Do not share sensitive information.</i></li>"""
|
556 |
description += """<i><li>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md).</i></li></ul></p>"""
|
557 |
|
|
|
637 |
return chat_message
|
638 |
else:
|
639 |
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
|
640 |
+
|
641 |
Chatbot._postprocess_chat_messages = _postprocess_chat_messages
|
642 |
|
643 |
demo = gr.Blocks(theme=gr.themes.Soft(**colors_dict), css=css_code, title="h2oGPT", analytics_enabled=False)
|
|
|
653 |
lora_options = [kwargs['lora_weights'].strip()] + lora_options
|
654 |
# always add in no lora case
|
655 |
# add fake space so doesn't go away in gradio dropdown
|
656 |
+
no_lora_str = no_model_str = '[None/Remove]'
|
657 |
+
lora_options = [no_lora_str] + kwargs['extra_lora_options'] # FIXME: why double?
|
658 |
+
# always add in no model case so can free memory
|
659 |
+
# add fake space so doesn't go away in gradio dropdown
|
660 |
+
model_options = [no_model_str] + model_options
|
661 |
+
|
662 |
+
# transcribe, will be detranscribed before use by evaluate()
|
663 |
+
if not kwargs['lora_weights'].strip():
|
664 |
+
kwargs['lora_weights'] = no_lora_str
|
665 |
|
666 |
+
if not kwargs['base_model'].strip():
|
667 |
+
kwargs['base_model'] = no_model_str
|
668 |
+
|
669 |
+
# transcribe for gradio
|
670 |
+
kwargs['gpu_id'] = str(kwargs['gpu_id'])
|
671 |
+
|
672 |
+
no_model_msg = 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
|
673 |
+
output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get(
|
674 |
+
'base_model') else no_model_msg
|
675 |
+
output_label0_model2 = no_model_msg
|
676 |
|
677 |
with demo:
|
678 |
# avoid actual model/tokenizer here or anything that would be bad to deepcopy
|
679 |
# https://github.com/gradio-app/gradio/issues/3558
|
680 |
model_state = gr.State(['model', 'tokenizer', device, kwargs['base_model']])
|
681 |
+
model_state2 = gr.State([None, None, None, None])
|
682 |
model_options_state = gr.State([model_options])
|
683 |
lora_options_state = gr.State([lora_options])
|
684 |
gr.Markdown(
|
|
|
689 |
{task_info_md}
|
690 |
""")
|
691 |
if is_hf:
|
692 |
+
gr.HTML(
|
693 |
+
'''<center><a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate this Space to skip the queue and run in a private space</center>''')
|
694 |
|
695 |
# go button visible if
|
696 |
+
base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
|
697 |
go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
|
698 |
normal_block = gr.Row(visible=not base_wanted)
|
699 |
with normal_block:
|
700 |
with gr.Tabs():
|
701 |
with gr.Row():
|
702 |
+
col_nochat = gr.Column(visible=not kwargs['chat'])
|
703 |
+
with col_nochat: # FIXME: for model comparison, and check rest
|
704 |
+
text_output_nochat = gr.Textbox(lines=5, label=output_label0)
|
705 |
+
instruction_nochat = gr.Textbox(
|
706 |
+
lines=4, label=instruction_label_nochat,
|
707 |
+
placeholder=kwargs['placeholder_instruction'],
|
708 |
+
)
|
709 |
+
iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
|
710 |
+
placeholder=kwargs['placeholder_input'])
|
711 |
+
submit_nochat = gr.Button("Submit")
|
712 |
+
flag_btn_nochat = gr.Button("Flag")
|
713 |
+
if kwargs['score_model']:
|
714 |
+
if not kwargs['auto_score']:
|
715 |
+
with gr.Column():
|
716 |
+
score_btn_nochat = gr.Button("Score last prompt & response")
|
717 |
+
score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
|
718 |
+
else:
|
719 |
+
score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
|
720 |
+
col_chat = gr.Column(visible=kwargs['chat'])
|
721 |
+
with col_chat:
|
722 |
+
with gr.Row():
|
723 |
+
text_output = gr.Chatbot(label=output_label0).style(height=kwargs['height'] or 400)
|
724 |
+
text_output2 = gr.Chatbot(label=output_label0_model2, visible=False).style(
|
725 |
+
height=kwargs['height'] or 400)
|
726 |
+
with gr.Row():
|
727 |
+
with gr.Column(scale=50):
|
728 |
+
instruction = gr.Textbox(
|
729 |
+
lines=4, label=instruction_label,
|
730 |
+
placeholder=kwargs['placeholder_instruction'],
|
731 |
+
)
|
732 |
+
with gr.Row(): # .style(equal_height=False, equal_width=False):
|
733 |
+
submit = gr.Button(value='Submit').style(full_width=False, size='sm')
|
734 |
+
stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
|
735 |
+
with gr.Row():
|
736 |
+
clear = gr.Button("New Conversation")
|
737 |
flag_btn = gr.Button("Flag")
|
738 |
if kwargs['score_model']:
|
739 |
+
if not kwargs['auto_score']: # FIXME: For checkbox model2
|
740 |
with gr.Column():
|
741 |
+
with gr.Row():
|
742 |
+
score_btn = gr.Button("Score last prompt & response").style(
|
743 |
+
full_width=False, size='sm')
|
744 |
+
score_text = gr.Textbox("Response Score: NA", show_label=False)
|
745 |
+
score_res2 = gr.Row(visible=False)
|
746 |
+
with score_res2:
|
747 |
+
score_btn2 = gr.Button("Score last prompt & response 2").style(
|
748 |
+
full_width=False, size='sm')
|
749 |
+
score_text2 = gr.Textbox("Response Score2: NA", show_label=False)
|
750 |
else:
|
751 |
score_text = gr.Textbox("Response Score: NA", show_label=False)
|
752 |
+
score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
|
753 |
+
retry = gr.Button("Regenerate")
|
754 |
+
undo = gr.Button("Undo")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
755 |
with gr.TabItem("Input/Output"):
|
756 |
with gr.Row():
|
757 |
if 'mbart-' in kwargs['model_lower']:
|
|
|
769 |
prompt_type = gr.Dropdown(prompt_types_strings,
|
770 |
value=kwargs['prompt_type'], label="Prompt Type",
|
771 |
visible=not is_public)
|
772 |
+
prompt_type2 = gr.Dropdown(prompt_types_strings,
|
773 |
+
value=kwargs['prompt_type'], label="Prompt Type Model 2",
|
774 |
+
visible=not is_public and False)
|
775 |
+
do_sample = gr.Checkbox(label="Sample", info="Enable sampler, required for use of temperature, top_p, top_k",
|
776 |
+
value=kwargs['do_sample'])
|
777 |
temperature = gr.Slider(minimum=0, maximum=3,
|
778 |
value=kwargs['temperature'],
|
779 |
label="Temperature",
|
|
|
813 |
value=kwargs['num_return_sequences'],
|
814 |
label="Number Returns", info="Must be <= num_beams",
|
815 |
visible=not is_public)
|
816 |
+
iinput = gr.Textbox(lines=4, label="Input",
|
817 |
+
placeholder=kwargs['placeholder_input'],
|
818 |
+
visible=not is_public)
|
819 |
+
context = gr.Textbox(lines=3, label="System Pre-Context",
|
820 |
+
info="Directly pre-appended without prompt processing",
|
821 |
+
visible=not is_public and not kwargs['chat'])
|
822 |
+
chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
|
823 |
+
visible=not is_public)
|
|
|
|
|
824 |
|
825 |
with gr.TabItem("Models"):
|
826 |
+
load_msg = "Load-Unload Model/LORA" if not is_public \
|
827 |
+
else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
|
828 |
+
load_msg2 = "Load-Unload Model/LORA 2" if not is_public \
|
829 |
+
else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2"
|
830 |
+
compare_checkbox = gr.components.Checkbox(label="Compare Mode",
|
831 |
+
value=False, visible=not is_public)
|
832 |
with gr.Row():
|
833 |
+
n_gpus = torch.cuda.device_count()
|
834 |
+
n_gpus_list = [str(x) for x in list(range(-1, n_gpus))]
|
835 |
with gr.Column():
|
836 |
with gr.Row(scale=1):
|
837 |
with gr.Column(scale=50):
|
838 |
+
model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model",
|
839 |
+
value=kwargs['base_model'])
|
840 |
+
lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
|
841 |
+
value=kwargs['lora_weights'], visible=kwargs['show_lora'])
|
842 |
with gr.Column(scale=1):
|
|
|
|
|
843 |
load_model_button = gr.Button(load_msg)
|
844 |
+
model_load8bit_checkbox = gr.components.Checkbox(
|
845 |
+
label="Load 8-bit [Not all models support]",
|
846 |
+
value=kwargs['load_8bit'])
|
847 |
+
model_infer_devices_checkbox = gr.components.Checkbox(
|
848 |
+
label="Infer Devices [If GPU ID=-1 or not Checked, then will spread model over GPUs]",
|
849 |
+
value=kwargs['infer_devices'])
|
850 |
+
model_gpu = gr.Dropdown(n_gpus_list, label="GPU ID [-1 = all GPUs]",
|
851 |
+
value=kwargs['gpu_id'])
|
852 |
model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
|
853 |
+
lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
|
854 |
+
visible=kwargs['show_lora'])
|
855 |
with gr.Row(scale=1):
|
856 |
with gr.Column(scale=50):
|
857 |
new_model = gr.Textbox(label="New Model HF name/path")
|
|
|
859 |
with gr.Column(scale=1):
|
860 |
add_model_button = gr.Button("Add new model name")
|
861 |
add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
|
862 |
+
col_model2 = gr.Column(visible=False)
|
863 |
+
with col_model2:
|
864 |
+
with gr.Row(scale=1):
|
865 |
+
with gr.Column(scale=50):
|
866 |
+
model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2",
|
867 |
+
value=no_model_str)
|
868 |
+
lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2",
|
869 |
+
value=no_lora_str,
|
870 |
+
visible=kwargs['show_lora'])
|
871 |
+
with gr.Column(scale=1):
|
872 |
+
load_model_button2 = gr.Button(load_msg2)
|
873 |
+
model_load8bit_checkbox2 = gr.components.Checkbox(
|
874 |
+
label="Load 8-bit 2 [Not all models support]",
|
875 |
+
value=kwargs['load_8bit'])
|
876 |
+
model_infer_devices_checkbox2 = gr.components.Checkbox(
|
877 |
+
label="Infer Devices 2 [If GPU ID=-1 or not Checked, then will spread model over GPUs]",
|
878 |
+
value=kwargs[
|
879 |
+
'infer_devices'])
|
880 |
+
model_gpu2 = gr.Dropdown(n_gpus_list, label="GPU ID [-1 = all GPUs]",
|
881 |
+
value=kwargs['gpu_id'])
|
882 |
+
# no model/lora loaded ever in model2 by default
|
883 |
+
model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str)
|
884 |
+
lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str,
|
885 |
+
visible=kwargs['show_lora'])
|
886 |
with gr.TabItem("System"):
|
887 |
system_row = gr.Row(visible=not is_public)
|
888 |
admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
|
|
|
912 |
kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
|
913 |
fun = partial(evaluate,
|
914 |
**kwargs_evaluate)
|
915 |
+
fun2 = partial(evaluate,
|
916 |
+
model_state2,
|
917 |
+
**kwargs_evaluate)
|
918 |
|
919 |
dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
|
920 |
size="sm",
|
|
|
932 |
}""",
|
933 |
api_name="dark",
|
934 |
)
|
935 |
+
|
936 |
+
# Control chat and non-chat blocks, which can be independently used by chat checkbox swap
|
937 |
+
def col_nochat_fun(x):
|
938 |
+
return gr.Column.update(visible=not x)
|
939 |
+
|
940 |
+
def col_chat_fun(x):
|
941 |
+
return gr.Column.update(visible=x)
|
942 |
+
|
943 |
+
def context_fun(x):
|
944 |
+
return gr.Textbox.update(visible=not x)
|
945 |
+
|
946 |
+
chat.select(col_nochat_fun, chat, col_nochat, api_name="chat_checkbox") \
|
947 |
+
.then(col_chat_fun, chat, col_chat) \
|
948 |
+
.then(context_fun, chat, context)
|
949 |
|
950 |
# examples after submit or any other buttons for chat or no chat
|
951 |
if kwargs['examples'] is not None and kwargs['show_examples']:
|
952 |
gr.Examples(examples=kwargs['examples'], inputs=inputs_list)
|
953 |
|
954 |
# Score
|
955 |
+
def score_last_response(*args, nochat=False, model2=False):
|
956 |
""" Similar to user() """
|
957 |
args_list = list(args)
|
958 |
+
|
959 |
+
max_length_tokenize = 512 if is_low_mem else 2048
|
960 |
+
cutoff_len = max_length_tokenize * 4 # restrict deberta related to max for LLM
|
961 |
+
|
962 |
+
if not nochat:
|
963 |
+
history = args_list[-1]
|
964 |
+
if history is None:
|
965 |
+
if not model2:
|
966 |
+
# maybe only doing first model, no need to complain
|
967 |
+
print("Bad history in scoring last response, fix for now", flush=True)
|
968 |
+
history = []
|
969 |
+
if smodel is not None and \
|
970 |
+
stokenizer is not None and \
|
971 |
+
sdevice is not None and \
|
972 |
+
history is not None and len(history) > 0 and \
|
973 |
+
history[-1] is not None and \
|
974 |
+
len(history[-1]) >= 2:
|
975 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
976 |
+
|
977 |
+
question = history[-1][0]
|
978 |
+
|
979 |
+
answer = history[-1][1]
|
980 |
+
else:
|
981 |
+
return 'Response Score: NA'
|
982 |
+
else:
|
983 |
+
answer = args_list[-1]
|
984 |
+
instruction_nochat_arg_id = eval_func_param_names.index('instruction_nochat')
|
985 |
+
question = args_list[instruction_nochat_arg_id]
|
986 |
+
|
987 |
+
question = question[-cutoff_len:]
|
988 |
+
answer = answer[-cutoff_len:]
|
989 |
+
|
990 |
+
inputs = stokenizer(question, answer,
|
991 |
+
return_tensors="pt",
|
992 |
+
truncation=True,
|
993 |
+
max_length=max_length_tokenize).to(smodel.device)
|
994 |
+
try:
|
995 |
+
score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
|
996 |
+
except torch.cuda.OutOfMemoryError as e:
|
997 |
+
print("GPU OOM: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
|
998 |
+
del inputs
|
999 |
+
traceback.print_exc()
|
1000 |
+
clear_torch_cache()
|
1001 |
+
return 'Response Score: GPU OOM'
|
1002 |
+
except (Exception, RuntimeError) as e:
|
1003 |
+
if 'Expected all tensors to be on the same device' in str(e) or \
|
1004 |
+
'expected scalar type Half but found Float' in str(e) or \
|
1005 |
+
'probability tensor contains either' in str(e) or \
|
1006 |
+
'cublasLt ran into an error!' in str(e):
|
1007 |
+
print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)),
|
1008 |
+
flush=True)
|
1009 |
traceback.print_exc()
|
1010 |
clear_torch_cache()
|
1011 |
+
return 'Response Score: GPU Error'
|
1012 |
+
else:
|
1013 |
+
raise
|
1014 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
1015 |
+
return 'Response Score: {:.1%}'.format(score)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1016 |
|
1017 |
if kwargs['score_model']:
|
1018 |
score_args = dict(fn=score_last_response,
|
1019 |
inputs=inputs_list + [text_output],
|
1020 |
outputs=[score_text],
|
1021 |
)
|
1022 |
+
score_args2 = dict(fn=partial(score_last_response, model2=True),
|
1023 |
+
inputs=inputs_list + [text_output2],
|
1024 |
+
outputs=[score_text2],
|
1025 |
+
)
|
1026 |
+
|
1027 |
+
score_args_nochat = dict(fn=partial(score_last_response, nochat=True),
|
1028 |
+
inputs=inputs_list + [text_output_nochat],
|
1029 |
+
outputs=[score_text_nochat],
|
1030 |
+
)
|
1031 |
if not kwargs['auto_score']:
|
1032 |
+
score_event = score_btn.click(**score_args, queue=stream_output, api_name='score') \
|
1033 |
+
.then(**score_args2, queue=stream_output, api_name='score2')
|
1034 |
+
score_event_nochat = score_btn_nochat.click(**score_args_nochat, queue=stream_output,
|
1035 |
+
api_name='score_nochat')
|
1036 |
+
|
1037 |
+
def user(*args, undo=False, sanitize_user_prompt=True, model2=False):
|
1038 |
+
"""
|
1039 |
+
User that fills history for bot
|
1040 |
+
:param args:
|
1041 |
+
:param undo:
|
1042 |
+
:param sanitize_user_prompt:
|
1043 |
+
:param model2:
|
1044 |
+
:return:
|
1045 |
+
"""
|
1046 |
+
args_list = list(args)
|
1047 |
+
user_message = args_list[0]
|
1048 |
+
input1 = args_list[1]
|
1049 |
+
context1 = args_list[2]
|
1050 |
+
if input1 and not user_message.endswith(':'):
|
1051 |
+
user_message1 = user_message + ":" + input1
|
1052 |
+
elif input1:
|
1053 |
+
user_message1 = user_message + input1
|
1054 |
+
else:
|
1055 |
+
user_message1 = user_message
|
1056 |
+
if sanitize_user_prompt:
|
1057 |
+
from better_profanity import profanity
|
1058 |
+
user_message1 = profanity.censor(user_message1)
|
1059 |
|
1060 |
+
history = args_list[-1]
|
1061 |
+
if undo and history:
|
1062 |
+
history.pop()
|
1063 |
+
args_list = args_list[:-1] # FYI, even if unused currently
|
1064 |
+
if history is None:
|
1065 |
+
if not model2:
|
1066 |
+
# no need to complain so often unless model1
|
1067 |
print("Bad history, fix for now", flush=True)
|
1068 |
+
history = []
|
1069 |
+
# ensure elements not mixed across models as output,
|
1070 |
+
# even if input is currently same source
|
1071 |
+
history = history.copy()
|
1072 |
+
if undo:
|
1073 |
+
return history
|
1074 |
+
else:
|
1075 |
+
# FIXME: compare, same history for now
|
1076 |
+
return history + [[user_message1, None]]
|
1077 |
+
|
1078 |
+
def bot(*args, retry=False):
|
1079 |
+
"""
|
1080 |
+
bot that consumes history for user input
|
1081 |
+
instruction (from input_list) itself is not consumed by bot
|
1082 |
+
:param args:
|
1083 |
+
:param retry:
|
1084 |
+
:return:
|
1085 |
+
"""
|
1086 |
+
args_list = list(args).copy()
|
1087 |
+
history = args_list[-1] # model_state is -2
|
1088 |
+
if retry and history:
|
1089 |
+
history.pop()
|
1090 |
+
if not history:
|
1091 |
+
print("No history", flush=True)
|
1092 |
+
return
|
1093 |
+
# ensure output will be unique to models
|
1094 |
+
history = history.copy()
|
1095 |
+
instruction1 = history[-1][0]
|
1096 |
+
context1 = ''
|
1097 |
+
if kwargs['chat_history'] > 0:
|
1098 |
+
prompt_type_arg_id = eval_func_param_names.index('prompt_type')
|
1099 |
+
prompt_type1 = args_list[prompt_type_arg_id]
|
1100 |
+
chat_arg_id = eval_func_param_names.index('chat')
|
1101 |
+
chat1 = args_list[chat_arg_id]
|
1102 |
context1 = ''
|
1103 |
+
for histi in range(len(history) - 1):
|
1104 |
+
data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
|
1105 |
+
context1 += generate_prompt(data_point, prompt_type1, chat1, reduced=True)[0].replace(
|
1106 |
+
'<br>', '\n')
|
1107 |
+
if not context1.endswith('\n'):
|
1108 |
+
context1 += '\n'
|
1109 |
+
if context1 and not context1.endswith('\n'):
|
1110 |
+
context1 += '\n' # ensure if terminates abruptly, then human continues on next line
|
1111 |
+
args_list[0] = instruction1 # override original instruction with history from user
|
1112 |
+
# only include desired chat history
|
1113 |
+
args_list[2] = context1[-kwargs['chat_history']:]
|
1114 |
+
model_state1 = args_list[-2]
|
1115 |
+
if model_state1[0] is None or model_state1[0] == no_model_str:
|
1116 |
+
return
|
1117 |
+
args_list = args_list[:-2]
|
1118 |
+
fun1 = partial(evaluate,
|
1119 |
+
model_state1,
|
1120 |
+
**kwargs_evaluate)
|
1121 |
+
try:
|
1122 |
+
for output in fun1(*tuple(args_list)):
|
1123 |
+
bot_message = output
|
1124 |
+
history[-1][1] = bot_message
|
|
|
|
|
|
|
1125 |
yield history
|
1126 |
+
except StopIteration:
|
1127 |
+
yield history
|
1128 |
+
except RuntimeError as e:
|
1129 |
+
if "generator raised StopIteration" in str(e):
|
1130 |
+
# assume last entry was bad, undo
|
1131 |
+
history.pop()
|
|
|
|
|
|
|
1132 |
yield history
|
1133 |
+
raise
|
1134 |
+
except Exception as e:
|
1135 |
+
# put error into user input
|
1136 |
+
history[-1][0] = "Exception: %s" % str(e)
|
1137 |
+
yield history
|
1138 |
+
raise
|
1139 |
+
return
|
1140 |
+
|
1141 |
+
# NORMAL MODEL
|
1142 |
+
user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
|
1143 |
+
inputs=inputs_list + [text_output],
|
1144 |
+
outputs=text_output,
|
1145 |
+
)
|
1146 |
+
bot_args = dict(fn=bot,
|
1147 |
+
inputs=inputs_list + [model_state] + [text_output],
|
1148 |
+
outputs=text_output,
|
1149 |
+
)
|
1150 |
+
retry_bot_args = dict(fn=functools.partial(bot, retry=True),
|
1151 |
+
inputs=inputs_list + [model_state] + [text_output],
|
1152 |
+
outputs=text_output,
|
1153 |
+
)
|
1154 |
+
undo_user_args = dict(fn=functools.partial(user, undo=True),
|
1155 |
+
inputs=inputs_list + [text_output],
|
1156 |
+
outputs=text_output,
|
1157 |
+
)
|
1158 |
|
1159 |
+
# MODEL2
|
1160 |
+
user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt'], model2=True),
|
1161 |
+
inputs=inputs_list + [text_output2],
|
1162 |
+
outputs=text_output2,
|
1163 |
+
)
|
1164 |
+
bot_args2 = dict(fn=bot,
|
1165 |
+
inputs=inputs_list + [model_state2] + [text_output2],
|
1166 |
+
outputs=text_output2,
|
1167 |
+
)
|
1168 |
+
retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
|
1169 |
+
inputs=inputs_list + [model_state2] + [text_output2],
|
1170 |
+
outputs=text_output2,
|
1171 |
+
)
|
1172 |
+
undo_user_args2 = dict(fn=functools.partial(user, undo=True),
|
1173 |
+
inputs=inputs_list + [text_output2],
|
1174 |
+
outputs=text_output2,
|
1175 |
+
)
|
1176 |
+
|
1177 |
+
def clear_instruct():
|
1178 |
+
return gr.Textbox.update(value='')
|
1179 |
+
|
1180 |
+
if kwargs['auto_score']:
|
1181 |
+
# in case 2nd model, consume instruction first, so can clear quickly
|
1182 |
+
# bot doesn't consume instruction itself, just history from user, so why works
|
1183 |
+
submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction') \
|
1184 |
+
.then(**user_args2, queue=stream_output, api_name='instruction2') \
|
1185 |
+
.then(clear_instruct, None, instruction) \
|
1186 |
+
.then(**bot_args, api_name='instruction_bot') \
|
1187 |
+
.then(**score_args, api_name='instruction_bot_score') \
|
1188 |
+
.then(**bot_args2, api_name='instruction_bot2') \
|
1189 |
+
.then(**score_args2, api_name='instruction_bot_score2') \
|
1190 |
+
.then(clear_torch_cache)
|
1191 |
+
submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit') \
|
1192 |
+
.then(**user_args2, queue=stream_output, api_name='submit2') \
|
1193 |
+
.then(**bot_args, api_name='submit_bot') \
|
1194 |
+
.then(clear_instruct, None, instruction) \
|
1195 |
+
.then(**score_args, api_name='submit_bot_score') \
|
1196 |
+
.then(**bot_args2, api_name='submit_bot2') \
|
1197 |
+
.then(**score_args2, api_name='submit_bot_score2') \
|
1198 |
+
.then(clear_torch_cache)
|
1199 |
+
submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry') \
|
1200 |
+
.then(**user_args2, queue=stream_output, api_name='retry2') \
|
1201 |
+
.then(clear_instruct, None, instruction) \
|
1202 |
+
.then(**retry_bot_args, api_name='retry_bot') \
|
1203 |
+
.then(**score_args, api_name='retry_bot_score') \
|
1204 |
+
.then(**retry_bot_args2, api_name='retry_bot2') \
|
1205 |
+
.then(**score_args2, api_name='retry_bot_score2') \
|
1206 |
+
.then(clear_torch_cache)
|
1207 |
+
submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo') \
|
1208 |
+
.then(**score_args, api_name='undo_score') \
|
1209 |
+
.then(**undo_user_args2, queue=stream_output, api_name='undo2') \
|
1210 |
+
.then(**score_args2, api_name='undo_score2') \
|
1211 |
+
.then(clear_instruct, None, instruction)
|
1212 |
+
else:
|
1213 |
+
submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction') \
|
1214 |
+
.then(**user_args2, queue=stream_output, api_name='instruction2') \
|
1215 |
+
.then(clear_instruct, None, instruction) \
|
1216 |
+
.then(**bot_args, api_name='instruction_bot') \
|
1217 |
+
.then(**bot_args2, api_name='instruction_bot2') \
|
1218 |
+
.then(clear_torch_cache)
|
1219 |
+
submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit') \
|
1220 |
+
.then(**user_args2, queue=stream_output, api_name='submit2') \
|
1221 |
+
.then(clear_instruct, None, instruction) \
|
1222 |
+
.then(**bot_args, api_name='submit_bot') \
|
1223 |
+
.then(**bot_args2, api_name='submit_bot2') \
|
1224 |
+
.then(clear_torch_cache)
|
1225 |
+
submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry') \
|
1226 |
+
.then(**user_args2, queue=stream_output, api_name='retry2') \
|
1227 |
+
.then(clear_instruct, None, instruction) \
|
1228 |
+
.then(**retry_bot_args, api_name='retry_bot') \
|
1229 |
+
.then(**retry_bot_args2, api_name='retry_bot2') \
|
1230 |
+
.then(clear_torch_cache)
|
1231 |
+
submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo') \
|
1232 |
+
.then(**undo_user_args2, queue=stream_output, api_name='undo2')
|
1233 |
+
|
1234 |
+
# does both models
|
1235 |
+
clear.click(lambda: None, None, text_output, queue=False, api_name='clear') \
|
1236 |
+
.then(lambda: None, None, text_output2, queue=False, api_name='clear2')
|
1237 |
+
# FIXME: compare
|
1238 |
+
submit_event_nochat = submit_nochat.click(fun, inputs=[model_state] + inputs_list,
|
1239 |
+
outputs=text_output_nochat, api_name='submit_nochat') \
|
1240 |
+
.then(**score_args_nochat, api_name='instruction_bot_score_nochat') \
|
1241 |
+
.then(clear_torch_cache)
|
1242 |
+
|
1243 |
+
def load_model(model_name, lora_weights, model_state_old, prompt_type_old, load_8bit, infer_devices, gpu_id):
|
1244 |
# ensure old model removed from GPU memory
|
1245 |
if kwargs['debug']:
|
1246 |
print("Pre-switch pre-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
|
|
|
1265 |
clear_torch_cache()
|
1266 |
if kwargs['debug']:
|
1267 |
print("Pre-switch post-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
|
1268 |
+
|
1269 |
+
if model_name is None or model_name == no_model_str:
|
1270 |
+
# no-op if no model, just free memory
|
1271 |
+
# no detranscribe needed for model, never go into evaluate
|
1272 |
+
lora_weights = no_lora_str
|
1273 |
+
return [None, None, None, model_name], model_name, lora_weights, prompt_type_old
|
1274 |
+
|
1275 |
+
all_kwargs1 = all_kwargs.copy()
|
1276 |
+
all_kwargs1['base_model'] = model_name.strip()
|
1277 |
+
all_kwargs1['load_8bit'] = load_8bit
|
1278 |
+
all_kwargs1['infer_devices'] = infer_devices
|
1279 |
+
all_kwargs1['gpu_id'] = int(gpu_id) # detranscribe
|
1280 |
model_lower = model_name.strip().lower()
|
1281 |
if model_lower in inv_prompt_type_to_model_lower:
|
1282 |
prompt_type1 = inv_prompt_type_to_model_lower[model_lower]
|
1283 |
else:
|
1284 |
prompt_type1 = prompt_type_old
|
1285 |
|
1286 |
+
# detranscribe
|
1287 |
+
if lora_weights == no_lora_str:
|
1288 |
+
lora_weights = ''
|
1289 |
+
|
1290 |
+
all_kwargs1['lora_weights'] = lora_weights.strip()
|
1291 |
+
model1, tokenizer1, device1 = get_model(**all_kwargs1)
|
1292 |
clear_torch_cache()
|
1293 |
|
1294 |
if kwargs['debug']:
|
1295 |
print("Post-switch GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
|
1296 |
+
return [model1, tokenizer1, device1, model_name], model_name, lora_weights, prompt_type1
|
|
|
|
|
|
|
1297 |
|
1298 |
def dropdown_prompt_type_list(x):
|
1299 |
return gr.Dropdown.update(value=x)
|
|
|
1302 |
return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]')
|
1303 |
|
1304 |
load_model_args = dict(fn=load_model,
|
1305 |
+
inputs=[model_choice, lora_choice, model_state, prompt_type,
|
1306 |
+
model_load8bit_checkbox, model_infer_devices_checkbox, model_gpu],
|
1307 |
outputs=[model_state, model_used, lora_used, prompt_type])
|
1308 |
prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
|
1309 |
chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
|
1310 |
if not is_public:
|
1311 |
load_model_event = load_model_button.click(**load_model_args) \
|
1312 |
+
.then(**prompt_update_args) \
|
1313 |
+
.then(**chatbot_update_args) \
|
1314 |
+
.then(clear_torch_cache)
|
1315 |
+
|
1316 |
+
load_model_args2 = dict(fn=load_model,
|
1317 |
+
inputs=[model_choice2, lora_choice2, model_state2, prompt_type2,
|
1318 |
+
model_load8bit_checkbox2, model_infer_devices_checkbox2, model_gpu2],
|
1319 |
+
outputs=[model_state2, model_used2, lora_used2, prompt_type2])
|
1320 |
+
prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
|
1321 |
+
chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
|
1322 |
+
if not is_public:
|
1323 |
+
load_model_event2 = load_model_button2.click(**load_model_args2) \
|
1324 |
+
.then(**prompt_update_args2) \
|
1325 |
+
.then(**chatbot_update_args2) \
|
1326 |
+
.then(clear_torch_cache)
|
1327 |
|
1328 |
def dropdown_model_list(list0, x):
|
1329 |
new_state = [list0[0] + [x]]
|
1330 |
new_options = [*new_state[0]]
|
1331 |
+
return gr.Dropdown.update(value=x, choices=new_options), \
|
1332 |
+
gr.Dropdown.update(value=x, choices=new_options), \
|
1333 |
+
'', new_state
|
1334 |
|
1335 |
add_model_event = add_model_button.click(fn=dropdown_model_list,
|
1336 |
inputs=[model_options_state, new_model],
|
1337 |
+
outputs=[model_choice, model_choice2, new_model, model_options_state])
|
1338 |
|
1339 |
+
def dropdown_lora_list(list0, x, model_used1, lora_used1, model_used2, lora_used2):
|
1340 |
new_state = [list0[0] + [x]]
|
1341 |
new_options = [*new_state[0]]
|
1342 |
+
# don't switch drop-down to added lora if already have model loaded
|
1343 |
+
x1 = x if model_used1 == no_model_str else lora_used1
|
1344 |
+
x2 = x if model_used2 == no_model_str else lora_used2
|
1345 |
+
return gr.Dropdown.update(value=x1, choices=new_options), \
|
1346 |
+
gr.Dropdown.update(value=x2, choices=new_options), \
|
1347 |
+
'', new_state
|
1348 |
|
1349 |
add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
|
1350 |
+
inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2, lora_used2],
|
1351 |
+
outputs=[lora_choice, lora_choice2, new_lora, lora_options_state])
|
1352 |
|
1353 |
go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go") \
|
1354 |
.then(lambda: gr.update(visible=True), None, normal_block) \
|
1355 |
.then(**load_model_args).then(**prompt_update_args)
|
1356 |
|
1357 |
+
def compare_textbox_fun(x):
|
1358 |
+
return gr.Textbox.update(visible=x)
|
1359 |
+
|
1360 |
+
def compare_column_fun(x):
|
1361 |
+
return gr.Column.update(visible=x)
|
1362 |
+
|
1363 |
+
def compare_prompt_fun(x):
|
1364 |
+
return gr.Dropdown.update(visible=x)
|
1365 |
+
|
1366 |
+
compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2, api_name="compare_checkbox") \
|
1367 |
+
.then(compare_column_fun, compare_checkbox, col_model2) \
|
1368 |
+
.then(compare_prompt_fun, compare_checkbox, prompt_type2) \
|
1369 |
+
.then(compare_textbox_fun, compare_checkbox, score_text2)
|
1370 |
+
# FIXME: add score_res2 in condition, but do better
|
1371 |
+
|
1372 |
# callback for logging flagged input/output
|
1373 |
callback.setup(inputs_list + [text_output], "flagged_data_points")
|
1374 |
flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False,
|
1375 |
api_name='flag')
|
1376 |
+
flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False,
|
1377 |
+
api_name='flag_nochat')
|
1378 |
|
1379 |
def get_system_info():
|
1380 |
return gr.Textbox.update(value=system_info_print())
|
1381 |
|
1382 |
system_event = system_btn.click(get_system_info, outputs=system_text, api_name='system_info')
|
1383 |
|
1384 |
+
# don't pass text_output, don't want to clear output, just stop it
|
1385 |
+
# FIXME: have to click once to stop output and second time to stop GPUs going
|
1386 |
+
stop_btn.click(lambda: None, None, None,
|
1387 |
+
cancels=[submit_event_nochat, submit_event, submit_event2, submit_event3],
|
1388 |
+
queue=False, api_name='stop').then(clear_torch_cache)
|
|
|
1389 |
|
1390 |
demo.queue(concurrency_count=1)
|
1391 |
favicon_path = "h2o-logo.svg"
|
|
|
1396 |
|
1397 |
|
1398 |
input_args_list = ['model_state']
|
1399 |
+
inputs_kwargs_list = ['debug', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']
|
1400 |
|
1401 |
|
1402 |
def get_inputs_list(inputs_dict, model_lower):
|
1403 |
+
"""
|
1404 |
+
map gradio objects in locals() to inputs for evaluate().
|
1405 |
+
:param inputs_dict:
|
1406 |
+
:param model_lower:
|
1407 |
+
:return:
|
1408 |
+
"""
|
1409 |
inputs_list_names = list(inspect.signature(evaluate).parameters)
|
1410 |
inputs_list = []
|
1411 |
for k in inputs_list_names:
|
|
|
1420 |
return inputs_list
|
1421 |
|
1422 |
|
|
|
|
|
|
|
1423 |
eval_func_param_names = ['instruction',
|
1424 |
'iinput',
|
1425 |
'context',
|
|
|
1436 |
'repetition_penalty',
|
1437 |
'num_return_sequences',
|
1438 |
'do_sample',
|
1439 |
+
'chat',
|
1440 |
+
'instruction_nochat',
|
1441 |
+
'iinput_nochat',
|
1442 |
]
|
1443 |
|
1444 |
|
|
|
1461 |
repetition_penalty,
|
1462 |
num_return_sequences,
|
1463 |
do_sample,
|
1464 |
+
chat,
|
1465 |
+
instruction_nochat,
|
1466 |
+
iinput_nochat,
|
1467 |
# END NOTE: Examples must have same order of parameters
|
1468 |
src_lang=None,
|
1469 |
tgt_lang=None,
|
1470 |
debug=False,
|
1471 |
save_dir=None,
|
|
|
1472 |
hard_stop_list=None,
|
1473 |
sanitize_bot_response=True,
|
1474 |
model_state0=None,
|
|
|
1477 |
if debug:
|
1478 |
locals_dict = locals().copy()
|
1479 |
locals_dict.pop('model_state', None)
|
1480 |
+
locals_dict.pop('model_state0', None)
|
1481 |
print(locals_dict)
|
1482 |
|
1483 |
no_model_msg = "Please choose a base model with --base_model (CLI) or in Models Tab (gradio).\nThen start New Conversation"
|
1484 |
|
1485 |
+
if model_state0 is None:
|
1486 |
+
# e.g. for no gradio case, set dummy value, else should be set
|
1487 |
+
model_state0 = [None, None, None, None]
|
1488 |
+
|
1489 |
if model_state is not None and len(model_state) == 4 and not isinstance(model_state[0], str):
|
1490 |
# try to free-up original model (i.e. list was passed as reference)
|
1491 |
if model_state0 is not None and model_state0[0] is not None:
|
|
|
1502 |
else:
|
1503 |
raise AssertionError(no_model_msg)
|
1504 |
|
1505 |
+
if base_model is None:
|
1506 |
+
raise AssertionError(no_model_msg)
|
1507 |
+
|
1508 |
assert base_model.strip(), no_model_msg
|
1509 |
assert model, "Model is missing"
|
1510 |
assert tokenizer, "Tokenizer is missing"
|
1511 |
|
1512 |
+
# choose chat or non-chat mode
|
1513 |
+
if not chat:
|
1514 |
+
instruction = instruction_nochat
|
1515 |
+
iinput = iinput_nochat
|
1516 |
+
|
1517 |
data_point = dict(context=context, instruction=instruction, input=iinput)
|
1518 |
prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
|
1519 |
prompt = prompter.generate_prompt(data_point)
|
|
|
1548 |
elif prompt_type == 'instruct_vicuna':
|
1549 |
# even below is not enough, generic strings and many ways to encode
|
1550 |
stop_words = [
|
1551 |
+
'### Human:',
|
1552 |
+
"""
|
1553 |
### Human:""",
|
1554 |
+
"""
|
1555 |
### Human:
|
1556 |
""",
|
1557 |
+
'### Assistant:',
|
1558 |
+
"""
|
1559 |
### Assistant:""",
|
1560 |
+
"""
|
1561 |
### Assistant:
|
1562 |
""",
|
1563 |
]
|
|
|
1575 |
if tokenizer.pad_token:
|
1576 |
stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
|
1577 |
# handle fake \n added
|
1578 |
+
stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
|
1579 |
# build stopper
|
1580 |
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters)])
|
1581 |
else:
|
|
|
1673 |
traceback.print_exc()
|
1674 |
clear_torch_cache()
|
1675 |
return
|
1676 |
+
except (Exception, RuntimeError) as e:
|
1677 |
if 'Expected all tensors to be on the same device' in str(e) or \
|
1678 |
'expected scalar type Half but found Float' in str(e) or \
|
1679 |
+
'probability tensor contains either' in str(e) or \
|
1680 |
+
'cublasLt ran into an error!' in str(e):
|
1681 |
print(
|
1682 |
"GPU Error: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
|
1683 |
flush=True)
|
1684 |
traceback.print_exc()
|
1685 |
clear_torch_cache()
|
1686 |
+
if raise_generate_gpu_exceptions:
|
1687 |
+
raise
|
1688 |
return
|
1689 |
else:
|
1690 |
raise
|
|
|
1795 |
else:
|
1796 |
prompt_type = ''
|
1797 |
examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
|
1798 |
+
stream_output, prompt_type or 'plain', 0.1, 0.75, 40, 4, 256, 0, False, max_time_defaults, 1.0, 1,
|
1799 |
+
False]]
|
1800 |
task_info = "No task"
|
1801 |
if prompt_type == 'instruct':
|
1802 |
task_info = "Answer question or follow imperative as instruction with optionally input."
|
|
|
1874 |
src_lang = "English"
|
1875 |
tgt_lang = "Russian"
|
1876 |
|
1877 |
+
# adjust examples if non-chat mode
|
1878 |
+
if not chat:
|
1879 |
+
# move to correct position
|
1880 |
+
for example in examples:
|
1881 |
+
example[eval_func_param_names.index('instruction_nochat')] = example[
|
1882 |
+
eval_func_param_names.index('instruction')]
|
1883 |
+
example[eval_func_param_names.index('instruction')] = ''
|
1884 |
+
|
1885 |
+
example[eval_func_param_names.index('iinput_nochat')] = example[eval_func_param_names.index('iinput')]
|
1886 |
+
example[eval_func_param_names.index('iinput')] = ''
|
1887 |
+
|
1888 |
return placeholder_instruction, placeholder_input, \
|
1889 |
stream_output, show_examples, \
|
1890 |
prompt_type, temperature, top_p, top_k, num_beams, \
|
client_test.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
"""
|
2 |
-
Client test.
|
3 |
|
4 |
-
Run server
|
5 |
|
6 |
-
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-256-6.9b
|
7 |
|
8 |
NOTE: For private models, add --use-auth_token=True
|
9 |
|
@@ -17,7 +17,6 @@ python client_test.py
|
|
17 |
|
18 |
debug = False
|
19 |
|
20 |
-
import time
|
21 |
import os
|
22 |
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
23 |
from gradio_client import Client
|
@@ -26,8 +25,8 @@ client = Client("http://localhost:7860")
|
|
26 |
if debug:
|
27 |
print(client.view_api(all_endpoints=True))
|
28 |
|
29 |
-
instruction =
|
30 |
-
iinput = ''
|
31 |
context = ''
|
32 |
# streaming output is supported, loops over and outputs each generation in streaming mode
|
33 |
# but leave stream_output=False for simple input/output mode
|
@@ -37,19 +36,17 @@ temperature = 0.1
|
|
37 |
top_p = 0.75
|
38 |
top_k = 40
|
39 |
num_beams = 1
|
40 |
-
max_new_tokens =
|
41 |
min_new_tokens = 0
|
42 |
early_stopping = False
|
43 |
-
max_time =
|
44 |
repetition_penalty = 1.0
|
45 |
num_return_sequences = 1
|
46 |
do_sample = True
|
47 |
-
|
48 |
-
# CHOOSE: must match server
|
49 |
-
# NOTE chat mode works through files on gradio
|
50 |
-
# and client currently would have to work through those files
|
51 |
-
# in tmp, so not best for client. So default to False
|
52 |
chat = False
|
|
|
|
|
53 |
|
54 |
|
55 |
def test_client_basic():
|
@@ -68,43 +65,18 @@ def test_client_basic():
|
|
68 |
max_time,
|
69 |
repetition_penalty,
|
70 |
num_return_sequences,
|
71 |
-
do_sample
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
import json
|
84 |
-
foofile = '/tmp/foo.json'
|
85 |
-
with open(foofile, 'wt') as f:
|
86 |
-
json.dump([['', None]], f)
|
87 |
-
args += [foofile]
|
88 |
-
if not stream_output:
|
89 |
-
for res in client.predict(
|
90 |
-
*tuple(args),
|
91 |
-
api_name=api_name,
|
92 |
-
):
|
93 |
-
print(res)
|
94 |
-
res_file = client.predict(*tuple(args), api_name='/instruction_bot')
|
95 |
-
res = json.load(open(res_file, "rt"))[-1][-1]
|
96 |
-
print(md_to_text(res))
|
97 |
-
else:
|
98 |
-
print("streaming instruction_bot", flush=True)
|
99 |
-
job = client.submit(*tuple(args), api_name='/instruction_bot')
|
100 |
-
while not job.done():
|
101 |
-
outputs_list = job.communicator.job.outputs
|
102 |
-
if outputs_list:
|
103 |
-
res_file = job.communicator.job.outputs[-1]
|
104 |
-
res = json.load(open(res_file, "rt"))[-1][-1]
|
105 |
-
print(md_to_text(res))
|
106 |
-
time.sleep(0.1)
|
107 |
-
print(job.outputs())
|
108 |
|
109 |
|
110 |
import markdown # pip install markdown
|
|
|
1 |
"""
|
2 |
+
Client test.
|
3 |
|
4 |
+
Run server:
|
5 |
|
6 |
+
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-256-6.9b
|
7 |
|
8 |
NOTE: For private models, add --use-auth_token=True
|
9 |
|
|
|
17 |
|
18 |
debug = False
|
19 |
|
|
|
20 |
import os
|
21 |
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
22 |
from gradio_client import Client
|
|
|
25 |
if debug:
|
26 |
print(client.view_api(all_endpoints=True))
|
27 |
|
28 |
+
instruction = '' # only for chat=True
|
29 |
+
iinput = '' # only for chat=True
|
30 |
context = ''
|
31 |
# streaming output is supported, loops over and outputs each generation in streaming mode
|
32 |
# but leave stream_output=False for simple input/output mode
|
|
|
36 |
top_p = 0.75
|
37 |
top_k = 40
|
38 |
num_beams = 1
|
39 |
+
max_new_tokens = 50
|
40 |
min_new_tokens = 0
|
41 |
early_stopping = False
|
42 |
+
max_time = 20
|
43 |
repetition_penalty = 1.0
|
44 |
num_return_sequences = 1
|
45 |
do_sample = True
|
46 |
+
# only these 2 below used if pass chat=False
|
|
|
|
|
|
|
|
|
47 |
chat = False
|
48 |
+
instruction_nochat = "Who are you?"
|
49 |
+
iinput_nochat = ''
|
50 |
|
51 |
|
52 |
def test_client_basic():
|
|
|
65 |
max_time,
|
66 |
repetition_penalty,
|
67 |
num_return_sequences,
|
68 |
+
do_sample,
|
69 |
+
chat,
|
70 |
+
instruction_nochat,
|
71 |
+
iinput_nochat,
|
72 |
+
]
|
73 |
+
api_name = '/submit_nochat'
|
74 |
+
res = client.predict(
|
75 |
+
*tuple(args),
|
76 |
+
api_name=api_name,
|
77 |
+
)
|
78 |
+
res_dict = dict(instruction_nochat=instruction_nochat, iinput_nochat=iinput_nochat, response=md_to_text(res))
|
79 |
+
print(res_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
|
82 |
import markdown # pip install markdown
|
finetune.py
CHANGED
@@ -121,7 +121,7 @@ def train(
|
|
121 |
save_code: bool = False,
|
122 |
run_id: int = None,
|
123 |
|
124 |
-
base_model: str = 'h2oai/h2ogpt-oig-oasst1-
|
125 |
# base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
|
126 |
# base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
|
127 |
# base_model: str = 'EleutherAI/gpt-neox-20b',
|
@@ -810,7 +810,7 @@ Current Time: {}
|
|
810 |
|
811 |
|
812 |
def generate_prompt(data_point, prompt_type, chat, reduced):
|
813 |
-
context = data_point.get('context')
|
814 |
if context is None:
|
815 |
context = ''
|
816 |
instruction = data_point.get('instruction')
|
|
|
121 |
save_code: bool = False,
|
122 |
run_id: int = None,
|
123 |
|
124 |
+
base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6.9b',
|
125 |
# base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
|
126 |
# base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
|
127 |
# base_model: str = 'EleutherAI/gpt-neox-20b',
|
|
|
810 |
|
811 |
|
812 |
def generate_prompt(data_point, prompt_type, chat, reduced):
|
813 |
+
context = data_point.get('context')
|
814 |
if context is None:
|
815 |
context = ''
|
816 |
instruction = data_point.get('instruction')
|
utils.py
CHANGED
@@ -1,12 +1,10 @@
|
|
1 |
-
import contextlib
|
2 |
import os
|
3 |
import gc
|
4 |
import random
|
5 |
-
import shutil
|
6 |
import time
|
7 |
import traceback
|
8 |
import zipfile
|
9 |
-
|
10 |
import filelock
|
11 |
import numpy as np
|
12 |
import pandas as pd
|
@@ -95,17 +93,22 @@ def system_info_print():
|
|
95 |
return "Error: %s" % str(e)
|
96 |
|
97 |
|
98 |
-
def zip_data(root_dirs=None,
|
99 |
try:
|
100 |
-
return _zip_data(
|
101 |
except Exception as e:
|
102 |
traceback.print_exc()
|
103 |
print('Exception in zipping: %s' % str(e))
|
104 |
|
105 |
|
106 |
-
def _zip_data(root_dirs=None,
|
|
|
|
|
|
|
|
|
107 |
assert root_dirs is not None
|
108 |
-
|
|
|
109 |
for root_dir in root_dirs:
|
110 |
if root_dir is None:
|
111 |
continue
|
@@ -115,7 +118,7 @@ def _zip_data(root_dirs=None, zip_path='data.zip', base_dir='./'):
|
|
115 |
assert os.path.exists(file_to_archive)
|
116 |
path_to_archive = os.path.relpath(file_to_archive, base_dir)
|
117 |
expt_zip.write(filename=file_to_archive, arcname=path_to_archive)
|
118 |
-
return
|
119 |
|
120 |
|
121 |
def save_generate_output(output=None, base_model=None, save_dir=None):
|
|
|
|
|
1 |
import os
|
2 |
import gc
|
3 |
import random
|
|
|
4 |
import time
|
5 |
import traceback
|
6 |
import zipfile
|
7 |
+
from datetime import datetime
|
8 |
import filelock
|
9 |
import numpy as np
|
10 |
import pandas as pd
|
|
|
93 |
return "Error: %s" % str(e)
|
94 |
|
95 |
|
96 |
+
def zip_data(root_dirs=None, zip_file=None, base_dir='./'):
|
97 |
try:
|
98 |
+
return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs)
|
99 |
except Exception as e:
|
100 |
traceback.print_exc()
|
101 |
print('Exception in zipping: %s' % str(e))
|
102 |
|
103 |
|
104 |
+
def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
|
105 |
+
if zip_file is None:
|
106 |
+
datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
|
107 |
+
host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
|
108 |
+
zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
|
109 |
assert root_dirs is not None
|
110 |
+
|
111 |
+
with zipfile.ZipFile(zip_file, "w") as expt_zip:
|
112 |
for root_dir in root_dirs:
|
113 |
if root_dir is None:
|
114 |
continue
|
|
|
118 |
assert os.path.exists(file_to_archive)
|
119 |
path_to_archive = os.path.relpath(file_to_archive, base_dir)
|
120 |
expt_zip.write(filename=file_to_archive, arcname=path_to_archive)
|
121 |
+
return zip_file
|
122 |
|
123 |
|
124 |
def save_generate_output(output=None, base_model=None, save_dir=None):
|