chenjoya committed on
Commit
dc317e6
·
verified ·
1 Parent(s): bc5cb8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -10
app.py CHANGED
@@ -14,15 +14,21 @@ class GradioBackend:
14
  'Conversation': 'video_qa'
15
  }
16
 
17
- def __init__(self, ):
18
- self.infer = LiveCCDemoInfer(model_path, device='cpu')
19
- self.audio_pipeline = KPipeline(lang_code='a')
 
20
 
21
- def __call__(self, query: str = None, state: dict = {}, mode: str = 'Real-Time Commentary', **kwargs):
22
- return getattr(self.infer, self.mode2api[mode])(query=query, state=state, **kwargs)
 
 
 
23
 
24
- def to(self, device):
25
- self.infer.model.to(device)
 
 
26
 
27
  with gr.Blocks() as demo:
28
  gr.Markdown("## LiveCC Real-Time Commentary and Conversation - Gradio Demo")
@@ -65,10 +71,12 @@ with gr.Blocks() as demo:
65
 
66
  @spaces.GPU
67
  def gr_chatinterface_fn(message, history, state, video_path, mode):
68
- gradio_backend.to('cuda')
 
 
69
  state['video_path'] = video_path
70
- response, state = gradio_backend(query=message, state=state, mode=mode)
71
- return response, state
72
  def gr_chatinterface_chatbot_clear_fn():
73
  return {}, {}, 0, 0
74
  gr_chatinterface = gr.ChatInterface(
 
14
  'Conversation': 'video_qa'
15
  }
16
 
17
def __init__(self):
    """Create an empty backend shell; no heavy objects are built here.

    Model and audio-pipeline construction is deferred to init_model()
    so that it happens inside the GPU worker process, not at import time.
    """
    self.infer = None            # LiveCCDemoInfer, built lazily by init_model()
    self.audio_pipeline = None   # KPipeline, built lazily by init_model()
21
 
22
def init_model(self, device):
    """Lazily construct the inference model and audio pipeline.

    Safe to call on every request: construction runs at most once. Must be
    invoked inside the GPU process (e.g. under @spaces.GPU) so device
    initialization stays in the worker.

    Args:
        device: device string passed to the model, e.g. 'cuda' or 'cpu'.
    """
    if self.infer is None:
        # Build into locals first and assign only on full success. The
        # original assigned self.infer before constructing KPipeline; if
        # KPipeline raised, the `if self.infer is None` guard would block
        # any retry and leave audio_pipeline permanently None.
        infer = LiveCCDemoInfer(model_path, device=device)
        audio_pipeline = KPipeline(lang_code='a')
        self.infer = infer
        self.audio_pipeline = audio_pipeline
27
 
28
def __call__(self, query: str = None, state: dict = None, mode: str = 'Real-Time Commentary', **kwargs):
    """Dispatch one request to the mode-specific inference API.

    Must be called only after init_model() has run (inside the GPU process).

    Args:
        query: user text for this turn, if any.
        state: per-session mutable state. Defaults to None and a fresh dict
            is created per call — the original `state: dict = {}` default is
            a shared mutable object, so state mutated by one session would
            leak into every later call that omitted `state`.
        mode: key into self.mode2api selecting the backend method
            ('Real-Time Commentary' or 'Conversation').
        **kwargs: forwarded verbatim to the selected API.

    Returns:
        (response, state) tuple as produced by the underlying API.
    """
    if state is None:
        state = {}
    api = getattr(self.infer, self.mode2api[mode])
    response, state = api(query=query, state=state, **kwargs)
    return response, state
32
 
33
  with gr.Blocks() as demo:
34
  gr.Markdown("## LiveCC Real-Time Commentary and Conversation - Gradio Demo")
 
71
 
72
@spaces.GPU
def gr_chatinterface_fn(message, history, state, video_path, mode):
    """Gradio chat handler; executes inside the ZeroGPU worker process.

    Lazily initializes the backend model on CUDA, records the active video
    path in the session state, then dispatches to the backend.

    Args:
        message: the user's chat message for this turn.
        history: chat history supplied by gr.ChatInterface (unused here,
            but required by the callback signature).
        state: per-session dict; mutated with the current video path.
        video_path: path of the currently selected/uploaded video.
        mode: 'Real-Time Commentary' or 'Conversation'.

    Returns:
        (response, state) from the backend call.
    """
    # The original declared `global gradio_backend`, but the name is only
    # read and called — never rebound — so the declaration was redundant.
    gradio_backend.init_model(device='cuda')
    state['video_path'] = video_path
    return gradio_backend(query=message, state=state, mode=mode)
79
+
80
def gr_chatinterface_chatbot_clear_fn():
    """Reset handler for the chatbot: fresh state dicts and zeroed counters."""
    fresh_state = {}
    fresh_media_state = {}
    return fresh_state, fresh_media_state, 0, 0
82
  gr_chatinterface = gr.ChatInterface(