chenjoya committed on
Commit
177b037
·
1 Parent(s): 5d35e25

fix flashattn

Browse files
Files changed (2) hide show
  1. app.py +0 -2
  2. demo/infer.py +6 -0
app.py CHANGED
@@ -1,5 +1,3 @@
1
- import os
2
- os.system('pip install flash-attn --no-build-isolation')
3
  import gradio as gr
4
 
5
  from demo.infer import LiveCCDemoInfer
 
 
 
1
  import gradio as gr
2
 
3
  from demo.infer import LiveCCDemoInfer
demo/infer.py CHANGED
@@ -32,6 +32,12 @@ class LiveCCDemoInfer:
32
  streaming_time_interval = streaming_fps_frames / fps
33
  frame_time_interval = 1 / fps
34
  def __init__(self, model_path: str = None, device_id: int = 0):
 
 
 
 
 
 
35
  self.model = Qwen2VLForConditionalGeneration.from_pretrained(
36
  model_path, torch_dtype="auto",
37
  device_map=f'cuda:{device_id}',
 
32
  streaming_time_interval = streaming_fps_frames / fps
33
  frame_time_interval = 1 / fps
34
  def __init__(self, model_path: str = None, device_id: int = 0):
35
+ self.model = Qwen2VLForConditionalGeneration.from_pretrained(
36
+ model_path, torch_dtype="auto",
37
+ device_map=f'cuda:{device_id}',
38
+ )
39
+ import os
40
+ os.system('pip install flash-attn --no-build-isolation')
41
  self.model = Qwen2VLForConditionalGeneration.from_pretrained(
42
  model_path, torch_dtype="auto",
43
  device_map=f'cuda:{device_id}',