XiaoyiYangRIT committed on
Commit
78b768f
·
1 Parent(s): 10625bf

update some files

Browse files
Files changed (3) hide show
  1. app.py +33 -22
  2. src/prompt.py +0 -7
  3. src/video_utils.py +12 -10
app.py CHANGED
@@ -1,31 +1,42 @@
1
- # app.py(主入口简化版)
2
  import gradio as gr
3
  from src.model_loader import load_model
4
  from src.video_utils import process_video_for_internvl3
 
5
 
6
- # === 初始化模型 ===
7
  tokenizer, model = load_model()
8
 
9
- # === 推理接口 ===
10
- def evaluate_ar(video):
11
- pixel_values, num_patches_list, prompt = process_video_for_internvl3(video)
12
- generation_config = dict(max_new_tokens=512)
13
- output, _ = model.chat(
14
- tokenizer,
15
- pixel_values,
16
- prompt,
17
- generation_config=generation_config,
18
- num_patches_list=num_patches_list,
19
- history=None,
20
- return_history=True
21
- )
22
- return output
23
-
24
- # === Gradio 接口 ===
 
 
 
 
 
 
 
 
 
 
 
25
  gr.Interface(
26
- fn=evaluate_ar,
27
  inputs=gr.Video(label="Upload your AR video"),
28
  outputs="text",
29
- title="InternVL3 AR Evaluation (Single-turn)",
30
- description="Upload a short AR video clip. The model will sample frames and assess occlusion/rendering quality."
31
- ).launch()
 
1
+ # app.py
2
  import gradio as gr
3
  from src.model_loader import load_model
4
  from src.video_utils import process_video_for_internvl3
5
+ from src.ar_prompts import generate_conversation_questions
6
 
 
7
  tokenizer, model = load_model()
8
 
9
def evaluate_ar_multi_turn(video):
    """Run a multi-turn InternVL3 dialogue about an uploaded AR video.

    The frame prefix returned by ``process_video_for_internvl3`` is prepended
    to the first question only; every later question is sent as plain text and
    relies on the chat history threaded through successive ``model.chat``
    calls.

    Args:
        video: Value delivered by the ``gr.Video`` input and forwarded
            unchanged to ``process_video_for_internvl3``. Presumably a file
            path, and ``None`` when nothing was uploaded -- confirm against
            the Gradio version in use.

    Returns:
        The model answers from the third turn onward, joined by blank lines,
        or a short notice when no video was supplied.
    """
    # Without this guard an empty submission is passed straight to the video
    # reader and surfaces as an opaque error in the UI.
    if not video:
        return "Please upload a video before running the evaluation."

    pixel_values, num_patches_list, image_prefix = process_video_for_internvl3(video)
    conversation = generate_conversation_questions(include_descriptions=True)

    history = None
    visible_outputs = []

    for turn, question in enumerate(conversation):
        # Only the opening turn carries the per-frame <image> placeholders.
        prompt = image_prefix + question if turn == 0 else question

        output, history = model.chat(
            tokenizer,
            pixel_values,
            prompt,
            generation_config={"max_new_tokens": 1024},
            num_patches_list=num_patches_list,
            history=history,
            return_history=True,
        )

        # Keep only the evaluation and follow-up answers (third turn onward);
        # the first two turns are not shown to the user.
        if turn >= 2:
            visible_outputs.append(output)

    # Concatenate the retained answers into one text block for display.
    return "\n\n".join(visible_outputs)
35
+
36
  gr.Interface(
37
+ fn=evaluate_ar_multi_turn,
38
  inputs=gr.Video(label="Upload your AR video"),
39
  outputs="text",
40
+ title="InternVL3 AR Evaluation (Multi-turn)",
41
+ description="Upload a short AR video clip. The model will sample frames and conduct a multi-turn dialogue to assess occlusion/rendering/placement/lighting."
42
+ ).launch()
src/prompt.py DELETED
@@ -1,7 +0,0 @@
1
- # src/prompt.py
2
-
3
- def build_video_prompt(num_frames: int) -> str:
4
- """构建适用于 InternVL3 的单轮 AR 视频评估提示语。"""
5
- frame_descriptors = ''.join([f"Frame{i+1}: <image>\n" for i in range(num_frames)])
6
- final_prompt = frame_descriptors + "Evaluate the quality of AR occlusion and rendering in the uploaded video."
7
- return final_prompt
 
 
 
 
 
 
 
 
src/video_utils.py CHANGED
@@ -1,16 +1,15 @@
1
- # src/video_utils.py
2
  import numpy as np
3
  import torch
4
  from PIL import Image
5
  from decord import VideoReader, cpu
6
  import torchvision.transforms as T
7
  from torchvision.transforms.functional import InterpolationMode
8
- from src.prompt import build_video_prompt
9
 
10
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
11
  IMAGENET_STD = (0.229, 0.224, 0.225)
12
 
13
- # === 构建标准图像预处理 transform ===
14
  def build_transform(input_size=448):
15
  return T.Compose([
16
  T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
@@ -19,12 +18,16 @@ def build_transform(input_size=448):
19
  T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
20
  ])
21
 
22
- # === InternVL3 视频帧采样策略 ===
23
  def get_frame_indices(num_frames, total_frames):
24
  indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
25
  return indices
26
 
27
- # === 从视频中提取图像帧并预处理成 patch tensor ===
 
 
 
 
28
  def process_video_for_internvl3(video_path, num_segments=8, max_patch_per_frame=1, input_size=448):
29
  vr = VideoReader(video_path, ctx=cpu(0))
30
  total_frames = len(vr)
@@ -41,16 +44,15 @@ def process_video_for_internvl3(video_path, num_segments=8, max_patch_per_frame=
41
  num_patches_list.append(patch_tensor.shape[0])
42
 
43
  pixel_values = torch.cat(pixel_values_list, dim=0).to(torch.bfloat16).cuda()
44
- prompt = build_video_prompt(len(num_patches_list))
45
 
46
- return pixel_values, num_patches_list, prompt
47
 
48
- # === 图像切片为 patch 区块 ===
49
  def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=True):
50
  orig_width, orig_height = image.size
51
  aspect_ratio = orig_width / orig_height
52
 
53
- # 构造备选分块比率
54
  target_ratios = set(
55
  (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
56
  if i * j <= max_num and i * j >= min_num
@@ -80,7 +82,7 @@ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbna
80
 
81
  return processed_images
82
 
83
- # === 找出最接近原图比例的块切方案 ===
84
  def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
85
  best_ratio_diff = float('inf')
86
  best_ratio = (1, 1)
 
1
+ # src/video_utils.py (returns the <image> prefix to support multi-turn dialogue)
2
  import numpy as np
3
  import torch
4
  from PIL import Image
5
  from decord import VideoReader, cpu
6
  import torchvision.transforms as T
7
  from torchvision.transforms.functional import InterpolationMode
 
8
 
9
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
10
  IMAGENET_STD = (0.229, 0.224, 0.225)
11
 
12
+ # Image preprocessing transform
13
  def build_transform(input_size=448):
14
  return T.Compose([
15
  T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
 
18
  T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
19
  ])
20
 
21
+ # Video frame sampling strategy
22
def get_frame_indices(num_frames, total_frames):
    """Pick ``num_frames`` evenly spaced frame indices from a video.

    Args:
        num_frames: How many indices to return.
        total_frames: Number of frames available; must be at least 1.

    Returns:
        Integer NumPy array of length ``num_frames`` with values in
        ``[0, total_frames - 1]``. Both endpoints are included when
        ``num_frames >= 2``; indices repeat when ``num_frames`` exceeds
        ``total_frames``.

    Raises:
        ValueError: If ``total_frames`` is not positive.
    """
    # linspace(0, -1, ...) would otherwise hand back a -1 index for an empty
    # video instead of failing with a clear message.
    if total_frames <= 0:
        raise ValueError(
            f"total_frames must be positive, got {total_frames}"
        )
    return np.linspace(0, total_frames - 1, num_frames, dtype=int)
25
 
26
+ # Build the prefix made of <image> tokens
27
def build_image_prefix(num_frames: int) -> str:
    """Return one ``FrameN: <image>`` line per frame, numbered from 1.

    Each line ends with a newline; a non-positive ``num_frames`` yields an
    empty string.
    """
    frame_numbers = range(1, num_frames + 1)
    return "".join(f"Frame{number}: <image>\n" for number in frame_numbers)
29
+
30
+ # Process the video into a patch tensor and return the <image> prefix
31
  def process_video_for_internvl3(video_path, num_segments=8, max_patch_per_frame=1, input_size=448):
32
  vr = VideoReader(video_path, ctx=cpu(0))
33
  total_frames = len(vr)
 
44
  num_patches_list.append(patch_tensor.shape[0])
45
 
46
  pixel_values = torch.cat(pixel_values_list, dim=0).to(torch.bfloat16).cuda()
47
+ image_prefix = build_image_prefix(len(num_patches_list))
48
 
49
+ return pixel_values, num_patches_list, image_prefix
50
 
51
+ # Split an image into tiles
52
  def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=True):
53
  orig_width, orig_height = image.size
54
  aspect_ratio = orig_width / orig_height
55
 
 
56
  target_ratios = set(
57
  (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
58
  if i * j <= max_num and i * j >= min_num
 
82
 
83
  return processed_images
84
 
85
+ # Find the tiling scheme whose aspect ratio is closest to the original image
86
  def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
87
  best_ratio_diff = float('inf')
88
  best_ratio = (1, 1)