BK-Lee committed on
Commit 15b745f · 1 Parent(s): 734b102
Files changed (2):
  1. app.py +11 -12
  2. meteor/arch/modeling_internlm2.py +0 -3
app.py CHANGED
@@ -14,19 +14,9 @@ from transformers import TextIteratorStreamer
 from torchvision.transforms.functional import pil_to_tensor
 
 # loading meteor model
-mmamba = load_mmamba('BK-Lee/Meteor-Mamba').to('cuda')
+mmamba = load_mmamba('BK-Lee/Meteor-Mamba').cuda()
 meteor, tok_meteor = load_meteor('BK-Lee/Meteor-MLM', bits=16)
 
-# param
-for param in mmamba.parameters():
-    param = param.to('cuda')
-for param in meteor.parameters():
-    param = param.to('cuda')
-
-# device
-device = torch.cuda.current_device()
-
-
 # freeze model
 freeze_model(mmamba)
 freeze_model(meteor)
@@ -36,7 +26,16 @@ previous_length = 0
 
 @spaces.GPU
 def threading_function(inputs, image_token_number, streamer):
-    print(f'----------------------------Device: {device}----------------------------')
+
+    # device
+    device = torch.cuda.current_device()
+
+    # param
+    for param in mmamba.parameters():
+        param.data = param.to(device)
+    for param in meteor.parameters():
+        param.data = param.to(device)
+
     # Meteor Mamba
     mmamba_inputs = mmamba.eval_process(inputs=inputs, tokenizer=tok_meteor, device=device, img_token_number=image_token_number)
     if 'image' in mmamba_inputs.keys():
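
For context on the relocated device setup, a minimal sketch (a stand-in nn.Linear in place of the actual models, assuming a CUDA device is available) of why the moved loop assigns to param.data: rebinding the loop variable with param = param.to('cuda') only changes a local name and leaves the model on CPU, while assigning to param.data swaps the parameter's underlying storage in place. Moving this work inside threading_function also matches how @spaces.GPU behaves on ZeroGPU Spaces, where CUDA is typically only usable inside the decorated function.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # stand-in for mmamba / meteor

# No-op: 'param' is just a local name; rebinding it does not touch the model.
for param in model.parameters():
    param = param.to('cuda')
print(next(model.parameters()).device)  # still cpu

# In-place move: assigning to .data replaces the underlying tensor storage.
device = torch.cuda.current_device()
for param in model.parameters():
    param.data = param.to(device)
print(next(model.parameters()).device)  # cuda:0

One caveat: a loop over parameters() does not move registered buffers, which is something model.to(device) would handle.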
meteor/arch/modeling_internlm2.py CHANGED
@@ -277,9 +277,6 @@ def rotate_half(x):
 # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors."""
-    print(f'------------------------------cos: {cos.device}------------------------------')
-    print(f'------------------------------position_ids: {position_ids.device}------------------------------')
-    print(f'------------------------------unsqueeze_dim: {unsqueeze_dim.device}------------------------------')
     cos = cos[position_ids].unsqueeze(unsqueeze_dim)
     sin = sin[position_ids].unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
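
For reference, a minimal self-contained sketch of how the rotary-embedding helper shown above is typically called. The rotate_half body and the cos/sin cache construction below are illustrative Llama-style reconstructions, not code from this file. (Dropping the debug prints is safe; the third one would in fact raise, since the default unsqueeze_dim is a plain int with no .device attribute.)

import torch

def rotate_half(x):
    # Rotates half the hidden dims of the input.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors."""
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# Toy dimensions: [batch, heads, seq, head_dim]
batch, heads, seq, head_dim = 1, 2, 5, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)

# Illustrative cos/sin cache of shape [seq, head_dim]
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq).float()
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

position_ids = torch.arange(seq).unsqueeze(0)  # [1, seq]
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)  # both torch.Size([1, 2, 5, 8])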