LN1996 committed
Commit 6fd64e4 · verified · 1 Parent(s): 0aeb624

Create app.py

Files changed (1)
  1. app.py +191 -0
app.py ADDED
@@ -0,0 +1,191 @@
import pandas as pd
import torch
import gradio as gr  # required for the gr.Interface UI built at the bottom of this file
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig

# Load the base Phi-2 language model on CPU.
model_name = "microsoft/phi-2"
phi2_model_pretrained = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map='cpu'
)

phi2_model_pretrained.config.use_cache = False

# Tokenizer setup: reuse the EOS token as both PAD and BOS.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.bos_token = tokenizer.eos_token
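# Phi-2 uses a 2560-dimensional embedding space, which is why the image projection defined
# further down maps CLIP's 768-dim patch features to 2560. The tokenizer has no dedicated
# pad token, hence the EOS token is reused above.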
def convert_text_input_embeds(text):
    # Tokenize the text and map the token ids to Phi-2 input embeddings.
    in_tokens = tokenizer(text, return_tensors="pt", return_attention_mask=False)
    in_embeds = phi2_model_pretrained.get_input_embeddings()(in_tokens.input_ids)
    return in_embeds
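# Illustrative usage (not executed here): for a short prompt the helper above returns a
# float tensor of shape [1, num_tokens, 2560], e.g.
#   convert_text_input_embeds("a quick test")  # -> torch.Size([1, 4, 2560]) or similar
# i.e. one embedding row per token, in the same space the image projection targets.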
import whisperx

# WhisperX ASR model used to transcribe the audio context.
whisper_model = whisperx.load_model('small', device='cpu', compute_type='float32')

def convert_audio_file_text_embeds(fname):
    # Note: despite its name, this helper returns the transcribed text (not embeddings);
    # the caller passes the transcript through convert_text_input_embeds() afterwards.
    result = whisper_model.transcribe(fname)
    full_text = ''
    for seg in result['segments']:
        full_text = full_text + seg['text']
    return full_text.strip()
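# For reference, whisper_model.transcribe() returns a dict roughly of the form
#   {'segments': [{'text': ' hello there', 'start': 0.0, 'end': 1.2}, ...], 'language': 'en'}
# which is why the loop above simply concatenates the per-segment 'text' fields.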
from transformers import CLIPVisionModel, CLIPImageProcessor

# CLIP vision tower used as a frozen image encoder; its patch features have
# shape torch.Size([1, 49, 768]) for this checkpoint.
vision_tower_name = 'openai/clip-vit-base-patch32'
image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
vision_tower = CLIPVisionModel.from_pretrained(vision_tower_name)

def feature_select(image_forward_outs):
    # Take the last hidden layer and drop the CLS token, keeping only the patch tokens.
    image_features = image_forward_outs.hidden_states[-1]
    image_features = image_features[:, 1:, :]
    return image_features  # [1, 49, 768]

def image_CLIP_embed(image):
    # CLIP stays frozen; it is only used as a feature extractor.
    _ = vision_tower.requires_grad_(False)
    image = image_processor(images=image, return_tensors="pt")
    image_forward_out = vision_tower(image['pixel_values'].to(device=vision_tower.device), output_hidden_states=True)
    image_feature = feature_select(image_forward_out)
    return image_feature
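# Shape sketch for the image path (values specific to clip-vit-base-patch32 at 224x224 input):
#   image_CLIP_embed(img) -> [1, 49, 768]   # 7x7 patch grid, CLS token dropped by feature_select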
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomGELU(nn.Module):
    def forward(self, x):
        return F.gelu(x.clone())

class SimpleResBlock(nn.Module):
    # Residual MLP block applied after the linear image projection.
    def __init__(self, input_size):
        super().__init__()
        self.pre_norm = nn.LayerNorm(input_size)
        self.proj = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.GELU(),
            nn.Linear(input_size, input_size)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return x + self.proj(x)

class CLIPembed_projection(nn.Module):
    # Projects CLIP patch features (768-d) into the Phi-2 embedding space (2560-d).
    def __init__(self, input_dim_CLIP=768, input_dim_phi2=2560):
        super(CLIPembed_projection, self).__init__()
        self.input_dim_CLIP = input_dim_CLIP
        self.input_dim_phi2 = input_dim_phi2
        self.projection_img = nn.Linear(self.input_dim_CLIP, self.input_dim_phi2,
                                        bias=False)
        self.resblock = SimpleResBlock(self.input_dim_phi2)

    def forward(self, x):
        x = self.projection_img(x)
        x = self.resblock(x)
        return x
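# Quick sanity check for the projection stack above (illustrative only, with random weights):
#   probe = torch.randn(1, 49, 768)          # stand-in for CLIP patch features
#   CLIPembed_projection()(probe).shape      # -> torch.Size([1, 49, 2560])
# The output width matches Phi-2's token-embedding width, so the projected image tokens can be
# concatenated with text-token embeddings along the sequence dimension.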
Image_projection_layer = CLIPembed_projection()

location_projection_img_p1 = './weights/stage_2/run2_projection_img.pth'
location_projection_img_p2 = './weights/stage_2/run2_resblock.pth'

# Load projection_img and resblock weights trained in stage 2.
Image_projection_layer.projection_img.load_state_dict(torch.load(location_projection_img_p1, map_location='cpu'))
Image_projection_layer.resblock.load_state_dict(torch.load(location_projection_img_p2, map_location='cpu'))

def img_input_embed(image):
    # CLIP features -> projection into the Phi-2 embedding space.
    clip_embed = image_CLIP_embed(image)
    post_projection = Image_projection_layer(clip_embed)
    return post_projection

device = 'cpu'

user = "LN1996"  # put your user name here
model_name = "peft-qlora-run2"
model_id = f"{user}/{model_name}"

import peft

# Attach the QLoRA adapters fine-tuned on top of Phi-2.
phi2_model_pretrained_peft = peft.PeftModel.from_pretrained(phi2_model_pretrained, model_id)
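# PeftModel wraps the base Phi-2 model, so get_input_embeddings() on the wrapped model below
# resolves to the same token-embedding table used in convert_text_input_embeds() (assuming the
# LoRA adapters in this checkpoint do not retarget the embedding layer).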
def input_multimodel(image=None, audio=None, text=None, query=None):
    # Treat empty strings coming from the Gradio textboxes as missing inputs.
    if text is not None and len(text) == 0:
        text = None

    if query is not None and len(query) == 0:
        query = None

    if query is None:
        print('Please ask a query')
        return None

    if image is None and audio is None and text is None:
        print('Please provide context in the form of an image, audio, or text')
        return None

    # Start the prompt with "Context: " in Phi-2's embedding space.
    bos = tokenizer("Context: ", return_tensors="pt", return_attention_mask=False)
    input_embeds_stage_2 = phi2_model_pretrained_peft.get_input_embeddings()(bos.input_ids)

    # Append the projected CLIP embeddings for the image context.
    if image is not None:
        image_embeds = img_input_embed(image)
        input_embeds_stage_2 = torch.cat((input_embeds_stage_2, image_embeds), dim=1)

    # Transcribe the audio context and append its text embeddings.
    if audio is not None:
        audio_transcribed = convert_audio_file_text_embeds(audio)
        audio_embeds = convert_text_input_embeds(audio_transcribed)
        input_embeds_stage_2 = torch.cat((input_embeds_stage_2, audio_embeds), dim=1)

    # Append the text context embeddings.
    if text is not None:
        text_embeds = convert_text_input_embeds(text)
        input_embeds_stage_2 = torch.cat((input_embeds_stage_2, text_embeds), dim=1)

    # Append the question and the "Answer:" cue.
    qus = tokenizer(" Question: " + query, return_tensors="pt",
                    return_attention_mask=False)
    qus_embeds = phi2_model_pretrained_peft.get_input_embeddings()(qus.input_ids)
    input_embeds_stage_2 = torch.cat((input_embeds_stage_2, qus_embeds), dim=1)

    ans = tokenizer(" Answer: ", return_tensors="pt", return_attention_mask=False)
    ans_embeds = phi2_model_pretrained_peft.get_input_embeddings()(ans.input_ids)
    input_embeds_stage_2 = torch.cat((input_embeds_stage_2, ans_embeds), dim=1)

    # Generate directly from the assembled input embeddings.
    result = phi2_model_pretrained_peft.generate(inputs_embeds=input_embeds_stage_2,
                                                 bos_token_id=tokenizer.bos_token_id)

    # Decode and return the first non-empty chunk around the EOS token.
    process = tokenizer.batch_decode(result)[0]
    process = process.split(tokenizer.eos_token)

    if process[0] == '':
        return process[1]
    else:
        return process[0]
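# Prompt layout assembled above, entirely in embedding space:
#   "Context: " [image patches] [audio transcript] [text] " Question: <query>" " Answer: "
# Generation continues from the " Answer: " cue, and the first non-empty EOS-delimited chunk
# of the decoded output is returned to the UI.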
title = "Multi-Modal Phi-2"
description = "A simple Gradio interface for a multi-modal (image, text, audio) version of Microsoft Phi-2"

demo = gr.Interface(input_multimodel,
                    inputs=[gr.Image(label="Input context Image"),
                            gr.Audio(label="Input context Audio", source="microphone", type="filepath"),
                            gr.Textbox(label="Input context Text"),
                            gr.Textbox(label="Input Query"),
                            ],
                    outputs=[
                        gr.Textbox(label='Answer'),
                    ],
                    title=title,
                    description=description,
                    )
demo.launch(share=True)
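# Note: share=True additionally requests a temporary public gradio.live link; it is not needed
# when the app runs on a hosted Space, where the app URL is already public.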