BK-Lee committed
Commit 2acb8d8
1 Parent(s): 7e1e0aa
Files changed (1)
app.py +11 -11
app.py CHANGED
@@ -25,16 +25,7 @@ freeze_model(meteor)
 previous_length = 0
 
 @spaces.GPU
-def threading_function(inputs, image_token_number, streamer):
-
-    # device
-    device = torch.cuda.current_device()
-
-    # param
-    for param in mmamba.parameters():
-        param.data = param.to(device)
-    for param in meteor.parameters():
-        param.data = param.to(device)
+def threading_function(inputs, image_token_number, streamer, device):
 
     # Meteor Mamba
     mmamba_inputs = mmamba.eval_process(inputs=inputs, tokenizer=tok_meteor, device=device, img_token_number=image_token_number)
@@ -68,6 +59,15 @@ def add_message(history, message):
 @spaces.GPU
 def bot_streaming(message, history):
 
+    # device
+    device = torch.cuda.current_device()
+
+    # param
+    for param in mmamba.parameters():
+        param.data = param.to(device)
+    for param in meteor.parameters():
+        param.data = param.to(device)
+
     # prompt type -> input prompt
     image_token_number = int((490/14)**2)
     if len(message['files']) != 0:
@@ -83,7 +83,7 @@ def bot_streaming(message, history):
     streamer = TextIteratorStreamer(tok_meteor, skip_special_tokens=True)
 
     # Threading generation
-    thread = Thread(target=threading_function, kwargs=dict(inputs=inputs, image_token_number=image_token_number, streamer=streamer))
+    thread = Thread(target=threading_function, kwargs=dict(inputs=inputs, image_token_number=image_token_number, streamer=streamer, device=device))
     thread.start()
 
     # generated text
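
For context, a minimal sketch of the pattern this commit settles on: the @spaces.GPU-decorated handler resolves the CUDA device and moves the model weights there, and the background generation thread receives the device as an argument instead of looking it up itself. The model, tokenizer, and checkpoint below are placeholders (not Meteor's mmamba/meteor modules), so treat this as an illustration of the Thread + TextIteratorStreamer flow rather than the app's exact code.

from threading import Thread

import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model/tokenizer; the real app uses the Meteor mmamba/meteor modules.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")


def generate_in_thread(inputs, streamer, device):
    # Runs in the worker thread; the device is passed in (mirroring this commit)
    # rather than being resolved inside the thread.
    model.generate(**inputs.to(device), streamer=streamer, max_new_tokens=64)


@spaces.GPU
def stream_reply(prompt):
    # Device handling stays in the GPU-decorated entry point.
    device = torch.cuda.current_device()
    model.to(device)

    inputs = tok(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(tok, skip_special_tokens=True)

    # Generation runs in a background thread while the main thread drains
    # the streamer and yields partial text to the UI.
    thread = Thread(target=generate_in_thread,
                    kwargs=dict(inputs=inputs, streamer=streamer, device=device))
    thread.start()

    text = ""
    for chunk in streamer:
        text += chunk
        yield text
    thread.join()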