hiko1999 commited on
Commit
146e1eb
·
1 Parent(s): 094254f

Add Gradio application

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
3
+ from qwen_vl_utils import process_vision_info
4
+ import gradio as gr
5
+ from PIL import Image
6
+
7
+ # Hugging Face 模型仓库路径
8
+ model_path = "hiko1999/Qwen2-Wildfire-VL-2B-Instruct" # 替换为你的模型路径
9
+
10
+ # 加载 Hugging Face 上的模型和 processor
11
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
12
+ model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
13
+ processor = AutoProcessor.from_pretrained(model_path)
14
+
15
+ # 定义预测函数
16
+ def predict(image):
17
+ # 将上传的图片处理为模型需要的格式
18
+ messages = [{"role": "user",
19
+ "content": [{"type": "image", "image": image}, {"type": "text", "text": "Describe this image."}]}]
20
+
21
+ # 处理图片输入
22
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
23
+ image_inputs, video_inputs = process_vision_info(messages)
24
+ inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
25
+ inputs = inputs.to("cuda") # 转移到GPU
26
+
27
+ # 生成模型输出
28
+ generated_ids = model.generate(**inputs, max_new_tokens=128)
29
+ generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
30
+ output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True,
31
+ clean_up_tokenization_spaces=False)
32
+
33
+ return output_text[0] # 返回生成的文本
34
+
35
+ # Gradio界面
36
+ def gradio_interface(image):
37
+ result = predict(image)
38
+ return f"预测结果:{result}"
39
+
40
+ # 创建Gradio接口
41
+ interface = gr.Interface(fn=gradio_interface,
42
+ inputs=gr.Image(type="pil"), # 输入的图像
43
+ outputs="text", # 输出结果
44
+ title="火灾场景多模态模型预测",
45
+ description="上传图片进行火灾预测。")
46
+
47
+ # 启动接口
48
+ interface.launch()