Spaces: Running on Zero

Commit: Update demo/infer.py
Browse files — demo/infer.py (+6 −6)

demo/infer.py — CHANGED
|
(Reconstructed unified diff — the extracted side-by-side rendering was garbled; indentation approximated, extraction lost original whitespace.)

@@ -56,7 +56,7 @@ class LiveCCDemoInfer:
         self,
         message: str,
         state: dict,
-        max_pixels: int = [old default value lost in extraction],
+        max_pixels: int = 384 * 28 * 28,
         default_query: str = 'Please describe the video.',
         do_sample: bool = False,
         repetition_penalty: float = 1.05,
(Reconstructed unified diff — the extracted side-by-side rendering was garbled; indentation approximated, extraction lost original whitespace.)

@@ -122,20 +122,20 @@ class LiveCCDemoInfer:
         # 5. make conversation and send to model
         for clip, timestamps in zip(interleave_clips, interleave_timestamps):
             start_timestamp, stop_timestamp = timestamps[0].item(), timestamps[-1].item() + self.frame_time_interval
-            [old line 125 lost in extraction]
+            conversation = [{
                 "role": "user",
                 "content": [
                     {"type": "text", "text": f'Time={start_timestamp:.1f}-{stop_timestamp:.1f}s'},
                     {"type": "video", "video": clip}
                 ]
-            }
+            }]
         if not message and not state.get('message', None):
             message = default_query
             logger.warning(f'No query provided, use default_query={default_query}')
         if message and state.get('message', None) != message:
-            [old line 136 lost in extraction]
+            conversation[0]['content'].append({"type": "text", "text": message})
             state['message'] = message
-        texts = self.processor.apply_chat_template( [old call arguments truncated in extraction]
+        texts = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, return_tensors='pt')
         past_ids = state.get('past_ids', None)
         if past_ids is not None:
             texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]
(Reconstructed unified diff — the extracted side-by-side rendering was garbled; indentation approximated, extraction lost original whitespace.)

@@ -146,7 +146,6 @@ class LiveCCDemoInfer:
             return_tensors="pt",
             return_attention_mask=False
         )
-        print(texts)
         inputs.to(self.model.device)
         if past_ids is not None:
             inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)
(Reconstructed unified diff — the extracted side-by-side rendering was garbled; indentation approximated, extraction lost original whitespace.)

@@ -159,6 +158,7 @@ class LiveCCDemoInfer:
                 return_dict_in_generate=True, do_sample=do_sample,
                 repetition_penalty=repetition_penalty,
                 logits_processor=logits_processor,
+                max_new_tokens=16,
             )
         state['past_key_values'] = outputs.past_key_values
         state['past_ids'] = outputs.sequences[:, :-1]
After the change, the affected regions of demo/infer.py read as follows (lines introduced by this commit marked with +; indentation approximated — extraction lost original whitespace):

  56          self,
  57          message: str,
  58          state: dict,
  59 +        max_pixels: int = 384 * 28 * 28,
  60          default_query: str = 'Please describe the video.',
  61          do_sample: bool = False,
  62          repetition_penalty: float = 1.05,

 122          # 5. make conversation and send to model
 123          for clip, timestamps in zip(interleave_clips, interleave_timestamps):
 124              start_timestamp, stop_timestamp = timestamps[0].item(), timestamps[-1].item() + self.frame_time_interval
 125 +            conversation = [{
 126                  "role": "user",
 127                  "content": [
 128                      {"type": "text", "text": f'Time={start_timestamp:.1f}-{stop_timestamp:.1f}s'},
 129                      {"type": "video", "video": clip}
 130                  ]
 131 +            }]
 132          if not message and not state.get('message', None):
 133              message = default_query
 134              logger.warning(f'No query provided, use default_query={default_query}')
 135          if message and state.get('message', None) != message:
 136 +            conversation[0]['content'].append({"type": "text", "text": message})
 137              state['message'] = message
 138 +        texts = self.processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True, return_tensors='pt')
 139          past_ids = state.get('past_ids', None)
 140          if past_ids is not None:
 141              texts = '<|im_end|>\n' + texts[self.system_prompt_offset:]

 146              return_tensors="pt",
 147              return_attention_mask=False
 148          )
 149          inputs.to(self.model.device)
 150          if past_ids is not None:
 151              inputs['input_ids'] = torch.cat([past_ids, inputs.input_ids], dim=1)

 158              return_dict_in_generate=True, do_sample=do_sample,
 159              repetition_penalty=repetition_penalty,
 160              logits_processor=logits_processor,
 161 +            max_new_tokens=16,
 162          )
 163          state['past_key_values'] = outputs.past_key_values
 164          state['past_ids'] = outputs.sequences[:, :-1]