chenjoya committed on
Commit
cfa935f
·
1 Parent(s): 096fa24
Files changed (1) hide show
  1. demo/infer.py +3 -2
demo/infer.py CHANGED
@@ -34,9 +34,10 @@ class LiveCCDemoInfer:
34
  def __init__(self, model_path: str = None, device_id: int = 0):
35
  self.model = Qwen2VLForConditionalGeneration.from_pretrained(
36
  model_path, torch_dtype="auto",
37
- device_map=f'cuda:{device_id}',
38
- attn_implementation='flash_attention_2'
39
  )
 
40
  self.processor = AutoProcessor.from_pretrained(model_path, use_fast=False)
41
  self.streaming_eos_token_id = self.processor.tokenizer(' ...').input_ids[-1]
42
  self.model.prepare_inputs_for_generation = functools.partial(prepare_multiturn_multimodal_inputs_for_generation, self.model)
 
34
  def __init__(self, model_path: str = None, device_id: int = 0):
35
  self.model = Qwen2VLForConditionalGeneration.from_pretrained(
36
  model_path, torch_dtype="auto",
37
+ # device_map=f'cuda:{device_id}',
38
+ # attn_implementation='flash_attention_2'
39
  )
40
+ self.model.to('cuda')
41
  self.processor = AutoProcessor.from_pretrained(model_path, use_fast=False)
42
  self.streaming_eos_token_id = self.processor.tokenizer(' ...').input_ids[-1]
43
  self.model.prepare_inputs_for_generation = functools.partial(prepare_multiturn_multimodal_inputs_for_generation, self.model)