"""Gradio streaming demo for FlashSloth multimodal models.

Loads two FlashSloth checkpoints once at import time and serves an
image-description UI whose generated text is streamed token by token.

NOTE(review): this module was recovered from a whitespace-mangled source;
statement order and all runtime strings are preserved exactly.
"""
import os  # kept from original (currently unused in the visible chunk)
import torch
from flashsloth.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    LEARNABLE_TOKEN,
    LEARNABLE_TOKEN_INDEX,
)
from flashsloth.conversation import conv_templates, SeparatorStyle
from flashsloth.model.builder import load_pretrained_model
from flashsloth.utils import disable_torch_init
from flashsloth.mm_utils import (
    tokenizer_image_token,
    process_images,
    process_images_hd_inference,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)
from PIL import Image
import gradio as gr
from transformers import TextIteratorStreamer
from threading import Thread

disable_torch_init()

MODEL_PATH_HD = "Tongbo/FlashSloth_HD-3.2B"
MODEL_PATH_NEW = "Tongbo/FlashSloth-3.2B"

model_name_hd = get_model_name_from_path(MODEL_PATH_HD)
model_name_new = get_model_name_from_path(MODEL_PATH_NEW)

# Load both checkpoints up front. Each value is the 4-tuple returned by
# load_pretrained_model: (tokenizer, model, image_processor, context_len).
models = {
    "FlashSloth HD": load_pretrained_model(MODEL_PATH_HD, None, model_name_hd),
    "FlashSloth": load_pretrained_model(MODEL_PATH_NEW, None, model_name_new),
}

# Move every loaded model to the GPU and switch it to eval mode.
for key in models:
    tokenizer, model, image_processor, context_len = models[key]
    model.to('cuda')
    model.eval()


def generate_description(image, prompt_text, temperature, top_p, max_tokens, selected_model):
    """Stream a description of *image* for *prompt_text* using the chosen model.

    Yields the accumulated generated text after each new token so Gradio can
    render a live, incrementally-growing response.

    Args:
        image: PIL image supplied by the Gradio image widget.
        prompt_text: User's question/instruction about the image.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        max_tokens: Cap on newly generated tokens (cast to int before use).
        selected_model: Key into the module-level ``models`` dict
            ("FlashSloth HD" or "FlashSloth").

    Yields:
        str: The full text generated so far (not just the delta).
    """
    # Empty-string keyword: KeywordsStoppingCriteria is constructed with it,
    # matching the original behavior — TODO confirm this is intentional.
    keywords = ['']
    tokenizer, model, image_processor, context_len = models[selected_model]

    # Prepend the image placeholder token and append the learnable token
    # around the user prompt, as the model's chat format expects.
    text = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text
    text = text + LEARNABLE_TOKEN

    image = image.convert('RGB')
    # HD checkpoints use the high-resolution preprocessing path.
    if model.config.image_hd:
        image_tensor = process_images_hd_inference([image], image_processor, model.config)[0]
    else:
        image_tensor = process_images([image], image_processor, model.config)[0]
    image_tensor = image_tensor.unsqueeze(0).to(dtype=torch.float16, device='cuda', non_blocking=True)

    # Build the prompt through the "phi2" conversation template.
    conv = conv_templates["phi2"].copy()
    conv.append_message(conv.roles[0], text)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX,
                                      return_tensors='pt')
    input_ids = input_ids.unsqueeze(0).to(device='cuda', non_blocking=True)

    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    # skip_prompt=True so the streamer emits only newly generated text.
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )

    generation_kwargs = dict(
        inputs=input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=int(max_tokens),
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=[stopping_criteria],
        streamer=streamer
    )

    def _generate():
        # Run generation without autograd bookkeeping; the streamer receives
        # tokens as they are produced.
        with torch.inference_mode():
            model.generate(**generation_kwargs)

    # model.generate blocks, so it runs on a worker thread while this
    # generator drains the streamer and yields partial results to Gradio.
    generation_thread = Thread(target=_generate)
    generation_thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    generation_thread.join()


custom_css = """
"""

with gr.Blocks(css=custom_css) as demo:
    # NOTE(review): passing the CSS both to Blocks(css=...) and into gr.HTML
    # mirrors the original source; the duplication looks suspicious — verify.
    gr.HTML(custom_css)
    # NOTE(review): the source chunk is truncated here mid-statement — the
    # original continues with `gr.HTML("...` whose content is not visible in
    # this chunk and is intentionally not guessed at.