Softie / app.py
Pectics's picture
更新
bd5a97e
raw
history blame
2.69 kB
from spaces import GPU
from threading import Thread
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor, TextIteratorStreamer, AutoProcessor, BatchFeature
from qwen_vl_utils import process_vision_info
from gradio import ChatInterface, Slider
# Hub repo id of the Softie vision-language checkpoint (Qwen2-VL based).
model_path = "Pectics/Softie-VL-7B-250123"
# Load the model once at import time; the Spaces @GPU decorator below
# moves actual inference onto the allocated GPU.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype="auto",  # take dtype from the checkpoint config
    device_map="auto",  # let accelerate place weights on available devices
    attn_implementation="flash_attention_2",  # NOTE(review): requires flash-attn package on the host
)
# Pixel budget for the vision tower: Qwen2-VL processors resize each image so
# its pixel count lies in [min_pixels, max_pixels]; 28*28 is one visual patch.
min_pixels = 256 * 28 * 28
max_pixels = 1280 * 28 * 28
processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_path, min_pixels=min_pixels, max_pixels=max_pixels)
@GPU
def infer(
    inputs: tuple,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream a model response for pre-processed chat inputs.

    Args:
        inputs: ``(text, images, videos)`` triple — the chat-template string,
            plus the image/video lists produced by ``process_vision_info``.
        max_tokens: maximum number of new tokens to generate.
        temperature: sampling temperature.
        top_p: nucleus-sampling probability mass.

    Yields:
        The accumulated response text after each streamed token.
    """
    # Tokenize/pack text + vision inputs; keep a distinct name instead of
    # shadowing the `inputs` parameter.
    model_inputs = processor(
        text=[inputs[0]],
        images=inputs[1],
        videos=inputs[2],
        padding=True,
        return_tensors="pt",
    ).to("cuda")
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        **model_inputs,
        streamer=streamer,
        # Without do_sample=True, generate() uses greedy decoding and silently
        # ignores temperature/top_p — the sliders would have no effect.
        do_sample=True,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    # Run generation in a background thread so we can consume the streamer here.
    Thread(target=model.generate, kwargs=kwargs).start()
    response = ""
    for token in streamer:
        response += token
        yield response
def respond(
    message: str | list[object],
    history: list[object],
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Gradio ChatInterface callback: build the chat messages and stream a reply.

    Args:
        message: the new user turn (plain string), or an already-built
            messages list that is forwarded to the chat template as-is.
        history: prior turns in Gradio "messages" format
            (dicts with ``role``/``content`` keys). Not mutated.
        max_tokens / temperature / top_p: generation settings from the sliders.

    Yields:
        The accumulated assistant response text.
    """
    if isinstance(message, str):
        # Work on a copy so the caller-owned history list is not mutated.
        messages = list(history)
        if not messages or messages[0]['role'] != 'system':
            messages.insert(0, {"role": "system", "content": """You are Softie, or 小软 in Chinese.
You are an intelligent assistant developed by the School of Software at Hefei University of Technology.
You like to chat with people and help them solve problems."""})
        messages.append({"role": "user", "content": message})
        message = messages
    text_inputs = processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(message)
    yield from infer((text_inputs, image_inputs, video_inputs), max_tokens, temperature, top_p)
# Wire the streaming callback into a messages-style Gradio chat UI; the three
# sliders are passed to `respond` as its extra positional arguments.
app = ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        Slider(minimum=1, maximum=2048, value=512, step=1, label="最大生成长度"),  # max new tokens
        Slider(minimum=0.01, maximum=4.0, value=0.75, step=0.01, label="温度系数(Temperature)"),  # temperature
        Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.01, label="核取样系数(Top-p)"),  # top-p
    ],
)
if __name__ == "__main__":
    app.launch()