|
|
|
import argparse |
|
import os |
|
import os.path as osp |
|
import re |
|
import sys |
|
|
|
import torch |
|
from huggingface_hub import snapshot_download |
|
from peft import PeftModel |
|
from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer, |
|
BitsAndBytesConfig, SiglipImageProcessor, |
|
SiglipVisionModel, GenerationConfig) |
|
from transformers.generation.streamers import TextStreamer |
|
|
|
from xtuner.dataset.utils import expand2square, load_image |
|
from xtuner.model.utils import prepare_inputs_labels_for_multimodal |
|
from xtuner.tools.utils import get_stop_criteria |
|
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, |
|
PROMPT_TEMPLATE, SYSTEM_TEMPLATE) |
|
|
|
# Lookup table from the `--torch-dtype` command-line choice to the value that
# is forwarded as `torch_dtype` when models are loaded. 'auto' is passed
# through as the literal string.
TORCH_DTYPE_MAP = {
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
    'fp32': torch.float32,
    'auto': 'auto',
}
|
|
|
|
|
def remove_prefix(state_dict, prefix):
    """Build a new dict whose keys have a leading ``prefix`` stripped.

    Args:
        state_dict: Mapping with string keys. It is only read, never
            modified.
        prefix: String to strip from the front of every key that starts
            with it.

    Returns:
        dict: Same values, in the same iteration order. Keys that start
        with ``prefix`` are shortened; every other key is kept verbatim.
        If two keys end up identical, the later entry wins.
    """
    cut = len(prefix)
    return {
        (name[cut:] if name.startswith(prefix) else name): item
        for name, item in state_dict.items()
    }
|
|
|
|
|
def parse_args():
    """Parse the command-line options of the chat script.

    ``--adapter`` / ``--llava`` and ``--system`` / ``--system-template`` are
    registered as mutually exclusive pairs, so argparse rejects a command
    line that sets both members of a pair.

    Returns:
        argparse.Namespace: The parsed options (read from ``sys.argv``).
    """
    parser = argparse.ArgumentParser(description='Chat with a HF model')
    parser.add_argument(
        'model_name_or_path', help='Hugging Face model name or path')
    adapter_group = parser.add_mutually_exclusive_group()
    adapter_group.add_argument(
        '--adapter', default=None, help='adapter name or path')
    adapter_group.add_argument(
        '--llava', default=None, help='llava name or path')
    parser.add_argument(
        '--visual-encoder', default=None, help='visual encoder name or path')
    # `type=int` is required: the value is later used as a sequence index,
    # and without a converter a user-supplied value would stay a string.
    parser.add_argument(
        '--visual-select-layer',
        type=int,
        default=-2,
        help='visual select layer')
    parser.add_argument('--image', default=None, help='image')
    parser.add_argument(
        '--torch-dtype',
        default='fp16',
        choices=TORCH_DTYPE_MAP.keys(),
        help='Override the default `torch.dtype` and load the model under '
        'a specific `dtype`.')
    parser.add_argument(
        '--prompt-template',
        choices=PROMPT_TEMPLATE.keys(),
        default=None,
        help='Specify a prompt template')
    system_group = parser.add_mutually_exclusive_group()
    system_group.add_argument(
        '--system', default=None, help='Specify the system text')
    system_group.add_argument(
        '--system-template',
        choices=SYSTEM_TEMPLATE.keys(),
        default=None,
        help='Specify a system template')
    parser.add_argument(
        '--bits',
        type=int,
        choices=[4, 8, None],
        default=None,
        help='LLM bits')
    parser.add_argument(
        '--bot-name', type=str, default='BOT', help='Name for Bot')
    parser.add_argument(
        '--with-plugins',
        nargs='+',
        choices=['calculate', 'solve', 'search'],
        help='Specify plugins to use')
    parser.add_argument(
        '--no-streamer', action='store_true', help='Whether to with streamer')
    parser.add_argument(
        '--lagent', action='store_true', help='Whether to use lagent')
    parser.add_argument(
        '--stop-words', nargs='+', type=str, default=[], help='Stop words')
    parser.add_argument(
        '--offload-folder',
        default=None,
        help='The folder in which to offload the model weights (or where the '
        'model weights are already offloaded).')
    parser.add_argument(
        '--max-new-tokens',
        type=int,
        default=2048,
        help='Maximum number of new tokens allowed in generated text')
    parser.add_argument(
        '--temperature',
        type=float,
        default=0.1,
        help='The value used to modulate the next token probabilities.')
    parser.add_argument(
        '--top-k',
        type=int,
        default=40,
        help='The number of highest probability vocabulary tokens to '
        'keep for top-k-filtering.')
    parser.add_argument(
        '--top-p',
        type=float,
        default=0.75,
        help='If set to float < 1, only the smallest set of most probable '
        'tokens with probabilities that add up to top_p or higher are '
        'kept for generation.')
    parser.add_argument(
        '--repetition-penalty',
        type=float,
        default=1.0,
        help='The parameter for repetition penalty. 1.0 means no penalty.')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Random seed for reproducible text generation')
    args = parser.parse_args()
    return args
|
|
|
|
|
def get_input():
    """Read one multi-line message from standard input.

    Lines are collected until an empty line is entered and are joined with
    ``'\\n'``. If the input cannot be decoded, a notice is printed, the lines
    read so far are discarded and the user is prompted again.

    End-of-file on stdin (Ctrl-D, or piped input that is exhausted) no longer
    raises: the lines collected so far are returned, and if nothing was
    collected the string ``'EXIT'`` is returned, which is the exit command
    advertised in the prompt.

    Returns:
        str: The entered text without the terminating empty line.
    """
    sentinel = ''
    while True:
        print(('\ndouble enter to end input (EXIT: exit chat, '
               'RESET: reset history) >>> '),
              end='')
        lines = []
        try:
            while True:
                line = input()
                if line == sentinel:
                    break
                lines.append(line)
        except UnicodeDecodeError:
            print('Invalid characters detected. Please enter again.')
            continue
        except EOFError:
            # stdin is closed: hand back what was typed, or ask the caller
            # to quit instead of crashing with a traceback.
            if not lines:
                return 'EXIT'
        return '\n'.join(lines)
|
|
|
|
|
def _print_response(text):
    """Print ``text``, appending a newline only if it does not end with one.

    Uses ``str.endswith`` so that an empty string is handled as well.
    """
    print(text, end='' if text.endswith('\n') else '\n')


def main():
    """Load the requested model(s) and run an interactive terminal chat.

    Two modes are selected by the command line:

    * ``--lagent``: the LLM is wrapped in a lagent ``ReAct`` agent with a
      ``GoogleSearch`` action; this requires the ``SERPER_API_KEY``
      environment variable.
    * otherwise: the LLM is loaded with ``AutoModelForCausalLM`` and may be
      combined with a PEFT adapter (``--adapter``), a LLaVA bundle made of
      visual encoder, projector and optional adapters (``--llava``), a single
      image (``--image``) and ``moss_sft`` plugins (``--with-plugins``).

    The chat loop reads messages with ``get_input``. ``RESET`` clears the
    history and ``EXIT`` terminates the process through ``sys.exit(0)``.

    Raises:
        ValueError: If ``--image`` is given without ``--llava``; the visual
            encoder, image processor and projector are only created while
            loading the LLaVA bundle.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    # Quantization settings: 4-bit goes through a BitsAndBytesConfig, 8-bit
    # through the `load_in_8bit` flag.
    quantization_config = None
    load_in_8bit = False
    if args.bits == 4:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')
    elif args.bits == 8:
        load_in_8bit = True
    model_kwargs = {
        'quantization_config': quantization_config,
        'load_in_8bit': load_in_8bit,
        'device_map': 'auto',
        'offload_folder': args.offload_folder,
        'trust_remote_code': True,
        'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
    }
    if args.lagent:
        from lagent.actions import ActionExecutor, GoogleSearch
        from lagent.agents import (CALL_PROTOCOL_CN, FORCE_STOP_PROMPT_CN,
                                   ReAct, ReActProtocol)
        from lagent.llms import HFTransformerCasualLM

        try:
            SERPER_API_KEY = os.environ['SERPER_API_KEY']
        except KeyError:
            print('Please obtain the `SERPER_API_KEY` from https://serper.dev '
                  'and set it using `export SERPER_API_KEY=xxx`.')
            sys.exit(1)

        # `trust_remote_code` is not forwarded to the lagent wrapper.
        # NOTE(review): presumably the wrapper sets it itself - confirm.
        model_kwargs.pop('trust_remote_code')
        llm = HFTransformerCasualLM(
            args.model_name_or_path, model_kwargs=model_kwargs)
        if args.adapter is not None:
            print(f'Loading adapter from {args.adapter}...')
            llm.model = PeftModel.from_pretrained(
                llm.model,
                args.adapter,
                offload_folder=args.offload_folder,
                trust_remote_code=True)
        search_tool = GoogleSearch(api_key=SERPER_API_KEY)
        chatbot = ReAct(
            llm=llm,
            action_executor=ActionExecutor(actions=[search_tool]),
            protocol=ReActProtocol(
                call_protocol=CALL_PROTOCOL_CN,
                force_stop=FORCE_STOP_PROMPT_CN))
        while True:
            text = get_input()
            while text.strip() == 'RESET':
                print('Log: History responses have been removed!')
                chatbot._session_history = []
                text = get_input()
            if text.strip() == 'EXIT':
                print('Log: Exit!')
                sys.exit(0)
            response = chatbot.chat(text)
            print(response.response)
    else:
        if args.with_plugins is None:
            inner_thoughts_open = False
            calculate_open = False
            solve_open = False
            search_open = False
        else:
            assert args.prompt_template == args.system_template == 'moss_sft'
            from plugins import plugins_api
            inner_thoughts_open = True
            calculate_open = 'calculate' in args.with_plugins
            solve_open = 'solve' in args.with_plugins
            search_open = 'search' in args.with_plugins

            # The selected plugin modules are imported up front; the names
            # themselves are not used below.
            if calculate_open:
                from plugins import calculate  # noqa: F401
            if solve_open:
                from plugins import solve  # noqa: F401
            if search_open:
                from plugins import search  # noqa: F401

        # Build the LLM and its tokenizer.
        llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
                                                   **model_kwargs)
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            trust_remote_code=True,
            encode_special_tokens=True)
        print(f'Load LLM from {args.model_name_or_path}')
        if args.adapter is not None:
            llm = PeftModel.from_pretrained(
                llm,
                args.adapter,
                offload_folder=args.offload_folder,
                trust_remote_code=True)
            print(f'Load adapter from {args.adapter}')
        if args.llava is not None:
            # A local directory is used as is; anything else is treated as a
            # hub repo id and downloaded.
            llava_path = snapshot_download(
                repo_id=args.llava) if not osp.isdir(
                    args.llava) else args.llava

            # Build the visual encoder, either from the bundle or from
            # `--visual-encoder` (exactly one of the two must provide it).
            if 'visual_encoder' in os.listdir(llava_path):
                assert args.visual_encoder is None, (
                    "Please don't specify the `--visual-encoder` since passed "
                    '`--llava` contains a visual encoder!')
                visual_encoder_path = osp.join(llava_path, 'visual_encoder')
            else:
                assert args.visual_encoder is not None, (
                    'Please specify the `--visual-encoder`!')
                visual_encoder_path = args.visual_encoder
            visual_encoder = SiglipVisionModel.from_pretrained(
                visual_encoder_path,
                torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
            image_processor = SiglipImageProcessor.from_pretrained(
                visual_encoder_path)
            print(f'Load visual_encoder from {visual_encoder_path}')

            # Load the optional adapters shipped inside the bundle.
            if 'llm_adapter' in os.listdir(llava_path):
                adapter_path = osp.join(llava_path, 'llm_adapter')
                llm = PeftModel.from_pretrained(
                    llm,
                    adapter_path,
                    offload_folder=args.offload_folder,
                    trust_remote_code=True)
                print(f'Load LLM adapter from {args.llava}')
            if 'visual_encoder_adapter' in os.listdir(llava_path):
                adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
                visual_encoder = PeftModel.from_pretrained(
                    visual_encoder,
                    adapter_path,
                    offload_folder=args.offload_folder)
                print(f'Load visual_encoder adapter from {args.llava}')

            # Build the projector.
            projector_path = osp.join(llava_path, 'projector')
            projector = AutoModel.from_pretrained(
                projector_path,
                torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype],
                trust_remote_code=True)
            print(f'Load projector from {args.llava}')

            projector.cuda()
            projector.eval()
            visual_encoder.cuda()
            visual_encoder.eval()

        llm.eval()

        if args.image is not None:
            if args.llava is None:
                # Without the bundle none of the visual components exist.
                raise ValueError(
                    '`--image` requires `--llava`, which provides the visual '
                    'encoder and the projector.')
            image = load_image(args.image)
            # Pad to a square using the processor's mean colour scaled to
            # 0-255.
            image = expand2square(
                image, tuple(int(x * 255) for x in image_processor.image_mean))
            image = image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            # Cast to the encoder's dtype so that e.g. fp16 weights do not
            # receive a float32 input.
            image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
            visual_outputs = visual_encoder(image, output_hidden_states=True)
            # The first position of the selected hidden state is dropped
            # before projecting. NOTE(review): presumably a class-token slot;
            # confirm this is intended for the encoder in use.
            pixel_values = projector(
                visual_outputs.hidden_states[args.visual_select_layer][:, 1:])

        # Copy so that the parsed arguments are not modified in place.
        stop_words = list(args.stop_words)
        sep = ''
        if args.prompt_template:
            template = PROMPT_TEMPLATE[args.prompt_template]
            stop_words.extend(template.get('STOP_WORDS', []))
            sep = template.get('SEP', '')
        stop_criteria = get_stop_criteria(
            tokenizer=tokenizer, stop_words=stop_words)

        if args.no_streamer:
            streamer = None
        else:
            streamer = TextStreamer(tokenizer, skip_prompt=True)

        gen_config = GenerationConfig(
            max_new_tokens=args.max_new_tokens,
            do_sample=args.temperature > 0,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            # Fall back to EOS when the tokenizer defines no pad token.
            pad_token_id=tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
        )

        n_turn = 0
        inputs = ''
        while True:
            text = get_input()
            while text.strip() == 'RESET':
                print('Log: History responses have been removed!')
                n_turn = 0
                inputs = ''
                text = get_input()
            if text.strip() == 'EXIT':
                print('Log: Exit!')
                sys.exit(0)

            # The image placeholder is inserted only in the first turn.
            if args.image is not None and n_turn == 0:
                text = DEFAULT_IMAGE_TOKEN + '\n' + text

            if args.prompt_template:
                prompt_text = ''
                template = PROMPT_TEMPLATE[args.prompt_template]
                if 'SYSTEM' in template and n_turn == 0:
                    system_text = None
                    if args.system_template is not None:
                        system_text = SYSTEM_TEMPLATE[
                            args.system_template].format(
                                round=n_turn + 1, bot_name=args.bot_name)
                    elif args.system is not None:
                        system_text = args.system
                    if system_text is not None:
                        prompt_text += template['SYSTEM'].format(
                            system=system_text,
                            round=n_turn + 1,
                            bot_name=args.bot_name)
                prompt_text += template['INSTRUCTION'].format(
                    input=text, round=n_turn + 1, bot_name=args.bot_name)
                if args.prompt_template == args.system_template == 'moss_sft':
                    # `str.replace` returns a new string, so the result has
                    # to be assigned back for the switch to take effect.
                    if not inner_thoughts_open:
                        prompt_text = prompt_text.replace(
                            '- Inner thoughts: enabled.',
                            '- Inner thoughts: disabled.')
                    if not calculate_open:
                        prompt_text = prompt_text.replace(
                            ('- Calculator: enabled. API: '
                             'Calculate(expression)'),
                            '- Calculator: disabled.')
                    if not solve_open:
                        prompt_text = prompt_text.replace(
                            '- Equation solver: enabled. API: Solve(equation)',
                            '- Equation solver: disabled.')
                    if not search_open:
                        prompt_text = prompt_text.replace(
                            '- Web search: enabled. API: Search(query)',
                            '- Web search: disabled.')
            else:
                prompt_text = text
            inputs += prompt_text
            if args.image is None:
                # Special tokens are added only when encoding the first turn.
                if n_turn == 0:
                    ids = tokenizer.encode(inputs, return_tensors='pt')
                else:
                    ids = tokenizer.encode(
                        inputs, return_tensors='pt', add_special_tokens=False)

                if args.with_plugins is not None:
                    # First pass: let the model emit plugin commands.
                    generate_output = llm.generate(
                        inputs=ids.cuda(),
                        generation_config=gen_config,
                        streamer=streamer,
                        stopping_criteria=stop_criteria).cpu()
                    generate_output_text = tokenizer.decode(
                        generate_output[0][len(ids[0]):])
                    if streamer is None:
                        _print_response(generate_output_text)
                    pattern = r'<\|Commands\|>:(.*?)<eoc>'
                    command_text = ', '.join(
                        re.findall(pattern, generate_output_text))
                    extent_text = plugins_api(
                        command_text,
                        calculate_open=calculate_open,
                        solve_open=solve_open,
                        search_open=search_open)
                    _print_response(extent_text)
                    extent_text_ids = tokenizer.encode(
                        extent_text,
                        return_tensors='pt',
                        add_special_tokens=False)
                    new_ids = torch.cat((generate_output, extent_text_ids),
                                        dim=1)

                    # Second pass: continue after the plugin results.
                    generate_output = llm.generate(
                        inputs=new_ids.cuda(),
                        generation_config=gen_config,
                        streamer=streamer,
                        stopping_criteria=stop_criteria)
                    if streamer is None:
                        output_text = tokenizer.decode(
                            generate_output[0][len(new_ids[0]):])
                        _print_response(output_text)
                else:
                    generate_output = llm.generate(
                        inputs=ids.cuda(),
                        generation_config=gen_config,
                        streamer=streamer,
                        stopping_criteria=stop_criteria)
                    if streamer is None:
                        output_text = tokenizer.decode(
                            generate_output[0][len(ids[0]):])
                        _print_response(output_text)
                # The history becomes the decoded full output sequence.
                inputs = tokenizer.decode(generate_output[0])
            else:
                # Encode the text on both sides of the image placeholder
                # separately and join the pieces with IMAGE_TOKEN_INDEX.
                chunk_encode = []
                for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
                    if idx == 0 and n_turn == 0:
                        cur_encode = tokenizer.encode(chunk)
                    else:
                        cur_encode = tokenizer.encode(
                            chunk, add_special_tokens=False)
                    chunk_encode.append(cur_encode)
                assert len(chunk_encode) == 2
                ids = []
                for idx, cur_chunk_encode in enumerate(chunk_encode):
                    ids.extend(cur_chunk_encode)
                    if idx != len(chunk_encode) - 1:
                        ids.append(IMAGE_TOKEN_INDEX)
                ids = torch.tensor(ids).cuda().unsqueeze(0)
                mm_inputs = prepare_inputs_labels_for_multimodal(
                    llm=llm, input_ids=ids, pixel_values=pixel_values)

                generate_output = llm.generate(
                    **mm_inputs,
                    generation_config=gen_config,
                    streamer=streamer,
                    bos_token_id=tokenizer.bos_token_id,
                    stopping_criteria=stop_criteria)
                if streamer is None:
                    output_text = tokenizer.decode(generate_output[0])
                    _print_response(output_text)
                inputs += tokenizer.decode(generate_output[0])
            n_turn += 1
            inputs += sep
            # Start over once the returned sequence reaches the token budget.
            if len(generate_output[0]) >= args.max_new_tokens:
                print(
                    'Remove the memory of history responses, since '
                    f'it exceeds the length limitation {args.max_new_tokens}.')
                n_turn = 0
                inputs = ''
|
|
|
|
|
# Script entry point: start the chat only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()
|
|