momergul committed on
Commit
ab71bac
1 Parent(s): d1a5104

Reverted to old form

Browse files
Files changed (1) hide show
  1. app.py +7 -41
app.py CHANGED
@@ -8,45 +8,7 @@ from typing import List, Tuple
8
 
9
  from config_generator import generate_complete_game
10
  from dataset import get_processor, joint_speaker_input, joint_listener_input, get_index_to_token
11
-
12
- import torch
13
- import transformers
14
- from transformers import Idefics2ForConditionalGeneration
15
- from peft import LoraConfig, get_peft_model
16
- from joint_inference import IdeficsJointInferenceModel
17
-
18
- # Initialize the model globally
19
- repo = 'lil-lab/cogen'
20
- checkpoint = "HuggingFaceM4/idefics2-8b"
21
- model = Idefics2ForConditionalGeneration.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
22
-
23
- target_modules=r'(.*(vision_model|modality_projection|perceiver_resampler).*(out_proj|fc1|fc2|down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)|(.*(k_proj|q_proj|v_proj).*$)'
24
- lora_config = LoraConfig(
25
- r=16, lora_alpha=8,
26
- lora_dropout=0.1,
27
- target_modules=target_modules,
28
- init_lora_weights="gaussian"
29
- )
30
- model = get_peft_model(model, lora_config, adapter_name="initial")
31
- model.load_adapter(repo, "initial", revision="r0_full")
32
-
33
- # Add other adapter
34
- new_targets = set()
35
- for n, p in model.named_parameters():
36
- if 'lora' in n:
37
- new_targets.add(n[17:n.find('lora')-1])
38
- new_targets = list(new_targets)
39
-
40
- lora_config = LoraConfig(
41
- r=16, lora_alpha=8,
42
- lora_dropout=0.1,
43
- target_modules=new_targets,
44
- init_lora_weights="gaussian"
45
- )
46
- model.add_adapter('final', lora_config)
47
- model.load_adapter(repo, "final", revision="r3_full")
48
- model = IdeficsJointInferenceModel(0.5, 0, model=model).cuda()
49
- model.eval()
50
 
51
  css="""
52
  .radio-group .wrap {
@@ -110,6 +72,7 @@ def get_model_response(
110
  def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths, processor, img_dir, index_to_token, adapter_name):
111
  if model.model.active_adapter != adapter_name:
112
  model.model.set_adapter(adapter_name)
 
113
  with torch.no_grad():
114
  captions, _, _, _, _ = model.generate(
115
  images.cuda(), input_tokens.cuda(), attn_mask.cuda(), image_attn_mask.cuda(), label.cuda(),
@@ -124,6 +87,7 @@ def get_listener_response(model, images, l_input_tokens, l_attn_mask, l_image_at
124
  s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths, adapter_name):
125
  if model.model.active_adapter != adapter_name:
126
  model.model.set_adapter(adapter_name)
 
127
  with torch.no_grad():
128
  _, _, joint_log_probs = model.comprehension_side([
129
  images.cuda(), l_input_tokens.cuda(), l_attn_mask.cuda(), l_image_attn_mask.cuda(), index_to_token,
@@ -155,7 +119,7 @@ def initialize_interaction(model_iteration):
155
 
156
  return new_history
157
 
158
- def progress_game(user_message, processor, index_to_token, current_state):
159
  # First get the game state
160
  turn = current_state['turn']
161
  image_role_pairs = current_state['image_role_pairs']
@@ -293,6 +257,7 @@ def create_app():
293
  )
294
 
295
  send_btn = gr.Button("Send", interactive=False)
 
296
  processor = get_processor()
297
  index_to_token = get_index_to_token()
298
 
@@ -316,6 +281,7 @@ def create_app():
316
  gr.update(interactive=not human_listener), gr.update(interactive=human_listener), gr.update(interactive=True), gr.update(interactive=False), current_history
317
 
318
  def send_message(message, radio_choice, current_state):
 
319
  nonlocal processor
320
  nonlocal index_to_token
321
 
@@ -326,7 +292,7 @@ def create_app():
326
 
327
  # Regular game progress
328
  user_output = message if radio_choice is None else radio_choice
329
- images, conversation, role, turn, acc_message, current_state = progress_game(user_output, processor, index_to_token, current_state)
330
  human_listener = role == "Listener"
331
  return [(f"tangram_pngs/{img}", f"Image {i+1}") for i, img in enumerate(images)], "\n".join(conversation), role, turn, \
332
  acc_message, gr.update(interactive=not human_listener, value=""), gr.update(interactive=human_listener, value=None), \
 
8
 
9
  from config_generator import generate_complete_game
10
  from dataset import get_processor, joint_speaker_input, joint_listener_input, get_index_to_token
11
+ from models import get_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  css="""
14
  .radio-group .wrap {
 
72
  def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths, processor, img_dir, index_to_token, adapter_name):
73
  if model.model.active_adapter != adapter_name:
74
  model.model.set_adapter(adapter_name)
75
+ model = model.cuda()
76
  with torch.no_grad():
77
  captions, _, _, _, _ = model.generate(
78
  images.cuda(), input_tokens.cuda(), attn_mask.cuda(), image_attn_mask.cuda(), label.cuda(),
 
87
  s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths, adapter_name):
88
  if model.model.active_adapter != adapter_name:
89
  model.model.set_adapter(adapter_name)
90
+ model = model.cuda()
91
  with torch.no_grad():
92
  _, _, joint_log_probs = model.comprehension_side([
93
  images.cuda(), l_input_tokens.cuda(), l_attn_mask.cuda(), l_image_attn_mask.cuda(), index_to_token,
 
119
 
120
  return new_history
121
 
122
+ def progress_game(user_message, model, processor, index_to_token, current_state):
123
  # First get the game state
124
  turn = current_state['turn']
125
  image_role_pairs = current_state['image_role_pairs']
 
257
  )
258
 
259
  send_btn = gr.Button("Send", interactive=False)
260
+ model = get_model()
261
  processor = get_processor()
262
  index_to_token = get_index_to_token()
263
 
 
281
  gr.update(interactive=not human_listener), gr.update(interactive=human_listener), gr.update(interactive=True), gr.update(interactive=False), current_history
282
 
283
  def send_message(message, radio_choice, current_state):
284
+ nonlocal model
285
  nonlocal processor
286
  nonlocal index_to_token
287
 
 
292
 
293
  # Regular game progress
294
  user_output = message if radio_choice is None else radio_choice
295
+ images, conversation, role, turn, acc_message, current_state = progress_game(user_output, model, processor, index_to_token, current_state)
296
  human_listener = role == "Listener"
297
  return [(f"tangram_pngs/{img}", f"Image {i+1}") for i, img in enumerate(images)], "\n".join(conversation), role, turn, \
298
  acc_message, gr.update(interactive=not human_listener, value=""), gr.update(interactive=human_listener, value=None), \