sagar007 committed on
Commit 820ac3d · verified · 1 Parent(s): 1626444

Update app.py

Files changed (1)
  1. app.py +113 -264
app.py CHANGED
@@ -1,279 +1,128 @@
 
  import torch
- import librosa
- from transformers import pipeline, WhisperProcessor, WhisperForConditionalGeneration, AutoModelForCausalLM, AutoProcessor
- from gtts import gTTS
  import gradio as gr
  from PIL import Image
- import logging
- import os
-
- # Set up logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)

- # Check for GPU availability
- device = "cuda" if torch.cuda.is_available() else "cpu"
- logger.info(f"Using device: {device}")
-
- # Function to safely load pipeline
- def load_pipeline(model_name, **kwargs):
-     try:
-         return pipeline(model=model_name, device=device, **kwargs)
-     except Exception as e:
-         logger.error(f"Error loading {model_name} pipeline: {e}")
-         return None

- # Load Whisper model for speech recognition
- def load_whisper():
-     try:
-         processor = WhisperProcessor.from_pretrained("openai/whisper-small")
-         model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small").to(device)
-         return processor, model
-     except Exception as e:
-         logger.error(f"Error loading Whisper model: {e}")
-         return None, None

- # Load sarvam-2b for text generation
- def load_sarvam():
-     return load_pipeline('sarvamai/sarvam-2b-v0.5')

- # Load vision model
- def load_vision_model():
-     try:
-         model_id = "microsoft/Phi-3.5-vision-instruct"
-         model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype="auto", attn_implementation="flash_attention_2").to(device).eval()
-         processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
-         return model, processor
-     except Exception as e:
-         logger.error(f"Error loading vision model: {e}")
-         return None, None

- # Process audio input
- def process_audio_input(audio, whisper_processor, whisper_model):
-     if whisper_processor is None or whisper_model is None:
-         return "Error: Speech recognition model is not available. Please type your message instead."
-
-     try:
-         audio, sr = librosa.load(audio, sr=16000)
-         input_features = whisper_processor(audio, sampling_rate=sr, return_tensors="pt").input_features.to(device)
-         predicted_ids = whisper_model.generate(input_features)
-         transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-         return transcription
-     except Exception as e:
-         logger.error(f"Error processing audio: {e}")
-         return f"Error processing audio. Please type your message instead."

- # Generate response
- def text_to_speech(text, lang='hi'):
-     try:
-         # Use a better TTS engine for Indic languages
-         if lang in ['hi', 'bn', 'gu', 'kn', 'ml', 'mr', 'or', 'pa', 'ta', 'te']:
-             tts = gTTS(text=text, lang=lang, tld='co.in') # Use Indian TLD
-         else:
-             tts = gTTS(text=text, lang=lang)
-
-         output_path = "/tmp/response.mp3"
-         tts.save(output_path)
-         return output_path
-     except Exception as e:
-         logger.error(f"Error in text-to-speech: {e}")
-         return None

- # Detect language (placeholder function, replace with actual implementation)
- def detect_language(text):
-     # Implement language detection logic here
-     return 'en' # Default to English for now

- def generate_response(transcription, sarvam_pipe):
-     if sarvam_pipe is None:
-         return "Error: Text generation model is not available."

-     try:
-         # Generate response using the sarvam-2b model
-         response = sarvam_pipe(transcription, max_length=100, num_return_sequences=1)[0]['generated_text']
-         return response
-     except Exception as e:
-         logger.error(f"Error generating response: {e}")
-         return f"Error generating response. Please try again."
-
- def process_image(image, text_input, vision_model, vision_processor):
-     if vision_model is None or vision_processor is None:
-         return "Error: Vision model is not available."

-     try:
-         prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"
-         image = Image.fromarray(image).convert("RGB")
-         inputs = vision_processor(prompt, image, return_tensors="pt").to(device)
-         generate_ids = vision_model.generate(**inputs, max_new_tokens=1000, eos_token_id=vision_processor.tokenizer.eos_token_id)
-         generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
-         response = vision_processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-         return response
-     except Exception as e:
-         logger.error(f"Error processing image: {e}")
-         return f"Error processing image. Please try again."
-
- def multimodal_assistant(input_type, audio_input, text_input, image_input):
-     try:
-         # Load models
-         whisper_processor, whisper_model = load_whisper()
-         sarvam_pipe = load_sarvam()
-         vision_model, vision_processor = load_vision_model()
-
-         if input_type == "audio" and audio_input is not None:
-             transcription = process_audio_input(audio_input, whisper_processor, whisper_model)
-         elif input_type == "text" and text_input:
-             transcription = text_input
-         elif input_type == "image" and image_input is not None:
-             return process_image(image_input, text_input, vision_model, vision_processor), None
-         else:
-             return "Please provide either audio, text, or image input.", None
-
-         response = generate_response(transcription, sarvam_pipe)
-         lang = detect_language(response)
-         audio_response = text_to_speech(response, lang)

-         return response, audio_response
-     except Exception as e:
-         logger.error(f"An error occurred in multimodal_assistant: {e}")
-         return f"An error occurred. Please try again.", None
-
- # Custom CSS (you can keep your existing custom CSS here)
- custom_css = """
- body {
-     background-color: #0b0f19;
-     color: #e2e8f0;
-     font-family: 'Arial', sans-serif;
- }
- #custom-header {
-     text-align: center;
-     padding: 20px 0;
-     background-color: #1a202c;
-     margin-bottom: 20px;
-     border-radius: 10px;
- }
- #custom-header h1 {
-     font-size: 2.5rem;
-     margin-bottom: 0.5rem;
- }
- #custom-header h1 .blue {
-     color: #60a5fa;
- }
- #custom-header h1 .pink {
-     color: #f472b6;
- }
- #custom-header h2 {
-     font-size: 1.5rem;
-     color: #94a3b8;
- }
- .suggestions {
-     display: flex;
-     justify-content: center;
-     flex-wrap: wrap;
-     gap: 1rem;
-     margin: 20px 0;
- }
- .suggestion {
-     background-color: #1e293b;
-     border-radius: 0.5rem;
-     padding: 1rem;
-     display: flex;
-     align-items: center;
-     transition: transform 0.3s ease;
-     width: 200px;
- }
- .suggestion:hover {
-     transform: translateY(-5px);
- }
- .suggestion-icon {
-     font-size: 1.5rem;
-     margin-right: 1rem;
-     background-color: #2d3748;
-     padding: 0.5rem;
-     border-radius: 50%;
- }
- .gradio-container {
-     max-width: 100% !important;
- }
- #component-0, #component-1, #component-2 {
-     max-width: 100% !important;
- }
- footer {
-     text-align: center;
-     margin-top: 2rem;
-     color: #64748b;
- }
- """
-
- # Custom HTML for the header (you can keep your existing custom header here)
- custom_header = """
- <div id="custom-header">
-     <h1>
-         <span class="blue">Multimodal</span>
-         <span class="pink">Indic Assistant</span>
-     </h1>
-     <h2>How can I help you today?</h2>
- </div>
- """
-
- # Custom HTML for suggestions
- custom_suggestions = """
- <div class="suggestions">
-     <div class="suggestion">
-         <span class="suggestion-icon">🎤</span>
-         <p>Speak in any Indic language</p>
-     </div>
-     <div class="suggestion">
-         <span class="suggestion-icon">⌨️</span>
-         <p>Type in any Indic language</p>
-     </div>
-     <div class="suggestion">
-         <span class="suggestion-icon">📷</span>
-         <p>Upload an image for analysis</p>
-     </div>
-     <div class="suggestion">
-         <span class="suggestion-icon">🤖</span>
-         <p>Get AI-generated responses</p>
-     </div>
-     <div class="suggestion">
-         <span class="suggestion-icon">🔊</span>
-         <p>Listen to audio responses</p>
-     </div>
- </div>
- """
-
- # Create Gradio interface
- with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
-     body_background_fill="#0b0f19",
-     body_text_color="#e2e8f0",
-     button_primary_background_fill="#3b82f6",
-     button_primary_background_fill_hover="#2563eb",
-     button_primary_text_color="white",
-     block_title_text_color="#94a3b8",
-     block_label_text_color="#94a3b8",
- )) as iface:
-     gr.HTML(custom_header)
-     gr.HTML(custom_suggestions)
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             gr.Markdown("### Multimodal Indic Assistant")
-
-             input_type = gr.Radio(["audio", "text", "image"], label="Input Type", value="audio")
-             audio_input = gr.Audio(type="filepath", label="Speak (if audio input selected)")
-             text_input = gr.Textbox(label="Type your message or image question")
-             image_input = gr.Image(label="Upload an image (if image input selected)")
-
-             submit_btn = gr.Button("Submit")
-
-             output_response = gr.Textbox(label="Generated Response")
-             output_audio = gr.Audio(label="Audio Response")
-
-             submit_btn.click(
-                 fn=multimodal_assistant,
-                 inputs=[input_type, audio_input, text_input, image_input],
-                 outputs=[output_response, output_audio]
-             )
-     gr.HTML("<footer>Powered by Multimodal Indic Language AI</footer>")

- # Launch the app
- iface.launch()
 
+ import os
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, TextIteratorStreamer, BitsAndBytesConfig
  import gradio as gr
+ from threading import Thread
  from PIL import Image
+ import subprocess

+ # Install flash-attention
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

+ # Constants
+ TITLE = "<h1><center>Phi 3.5 Multimodal (Text + Vision)</center></h1>"
+ DESCRIPTION = "# Phi-3.5 Multimodal Demo (Text + Vision)"

+ # Model configurations
+ TEXT_MODEL_ID = "microsoft/Phi-3.5-mini-instruct"
+ VISION_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"

+ device = "cuda" if torch.cuda.is_available() else "cpu"

+ # Quantization config for text model
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_compute_dtype=torch.bfloat16,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4"
+ )
+
+ # Load models and tokenizers
+ text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID)
+ text_model = AutoModelForCausalLM.from_pretrained(
+     TEXT_MODEL_ID,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     quantization_config=quantization_config
+ )
+
+ vision_model = AutoModelForCausalLM.from_pretrained(
+     VISION_MODEL_ID,
+     trust_remote_code=True,
+     torch_dtype="auto",
+     attn_implementation="flash_attention_2"
+ ).to(device).eval()
+
+ vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
+
+ # Helper functions
+ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
+     conversation = [{"role": "system", "content": system_prompt}]
+     for prompt, answer in history:
+         conversation.extend([
+             {"role": "user", "content": prompt},
+             {"role": "assistant", "content": answer},
+         ])
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = text_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(text_model.device)
+     streamer = TextIteratorStreamer(text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         max_new_tokens=max_new_tokens,
+         do_sample=temperature > 0,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         eos_token_id=[128001, 128008, 128009],
+         streamer=streamer,
+     )

+     with torch.no_grad():
+         thread = Thread(target=text_model.generate, kwargs=generate_kwargs)
+         thread.start()

+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
+         yield buffer

+ def process_vision_query(image, text_input):
+     prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"
+     image = Image.fromarray(image).convert("RGB")
+     inputs = vision_processor(prompt, image, return_tensors="pt").to(device)

+     with torch.no_grad():
+         generate_ids = vision_model.generate(
+             **inputs,
+             max_new_tokens=1000,
+             eos_token_id=vision_processor.tokenizer.eos_token_id
+         )

+     generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
+     response = vision_processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     return response
+
+ # Gradio interface
+ with gr.Blocks() as demo:
+     gr.HTML(TITLE)
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Tab("Text Model (Phi-3.5-mini)"):
+         chatbot = gr.Chatbot(height=600)
+         gr.ChatInterface(
+             fn=stream_text_chat,
+             chatbot=chatbot,
+             additional_inputs=[
+                 gr.Textbox(value="You are a helpful assistant", label="System Prompt"),
+                 gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature"),
+                 gr.Slider(minimum=128, maximum=8192, step=1, value=1024, label="Max new tokens"),
+                 gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p"),
+                 gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k"),
+             ],
+         )
+
+     with gr.Tab("Vision Model (Phi-3.5-vision)"):
+         with gr.Row():
+             with gr.Column():
+                 vision_input_img = gr.Image(label="Input Picture")
+                 vision_text_input = gr.Textbox(label="Question")
+                 vision_submit_btn = gr.Button(value="Submit")
+             with gr.Column():
+                 vision_output_text = gr.Textbox(label="Output Text")

+     vision_submit_btn.click(process_vision_query, [vision_input_img, vision_text_input], [vision_output_text])

+ if __name__ == "__main__":
+     demo.launch()
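
Note on the main technique introduced by this commit: the new `stream_text_chat` streams tokens by running `generate()` in a background thread while a `TextIteratorStreamer` yields decoded text pieces to Gradio. Below is a minimal standalone sketch of that pattern; it reuses the `microsoft/Phi-3.5-mini-instruct` checkpoint from the diff but, for brevity, omits the 4-bit `BitsAndBytesConfig` and the hard-coded `eos_token_id` list that the committed app.py uses.

```python
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "microsoft/Phi-3.5-mini-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")


def stream_reply(messages, max_new_tokens=256):
    # Render the chat history with the model's chat template.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks, so it runs in a worker thread while we consume the streamer here.
    Thread(
        target=model.generate,
        kwargs=dict(input_ids=input_ids, max_new_tokens=max_new_tokens, streamer=streamer),
    ).start()

    buffer = ""
    for piece in streamer:   # decoded text chunks arrive as they are generated
        buffer += piece
        yield buffer         # each yield is the full reply so far


# Example usage:
# for partial in stream_reply([{"role": "user", "content": "Hello"}]):
#     print(partial)
```

Wired into a Gradio `ChatInterface` (as in the Text Model tab above), each yielded `buffer` replaces the previous partial reply, which is what produces the incremental "typing" effect.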