desiree committed
Commit 21a98db · verified · 1 Parent(s): bb4fe56

Update app.py

Files changed (1):
app.py +4 -50
app.py CHANGED
@@ -12,7 +12,6 @@ import spaces
 MODEL_ID = "Qwen/Qwen-Audio-Chat"
 
 def load_model():
-    print("Loading model and tokenizer...")
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
         torch_dtype=torch.float16,
@@ -21,7 +20,6 @@ def load_model():
     )
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 
-    # Define a custom chat template
     chat_template = """<s>[INST] <<SYS>>
 You are a helpful assistant.
 <</SYS>>
@@ -29,67 +27,39 @@ You are a helpful assistant.
 {{ message['role'] }}: {{ message['content'] }}
 {% endfor %}[/INST]"""
 
-    # Assign the custom chat template to the tokenizer
     tokenizer.chat_template = chat_template
-
-    print("Model and tokenizer loaded successfully")
     return model, tokenizer
 
 def process_audio(audio_path):
-    """Process audio file for the model."""
     try:
-        print(f"Processing audio file: {audio_path}")
-        # Read audio file
         audio_data, sample_rate = sf.read(audio_path)
 
-        # Convert to mono if stereo
         if len(audio_data.shape) > 1:
             audio_data = audio_data.mean(axis=1)
 
-        # Ensure float32 format
         audio_data = audio_data.astype(np.float32)
 
-        # Create in-memory buffer
         audio_buffer = BytesIO()
-
-        # Write audio to buffer in WAV format
         sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
 
-        # Get the buffer content and encode to base64
         audio_buffer.seek(0)
         audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
 
-        print(f"Audio processed successfully. Sample rate: {sample_rate}, Shape: {audio_data.shape}")
        return {
             "audio": audio_base64,
             "sampling_rate": sample_rate
         }
-    except Exception as e:
-        print(f"Error processing audio: {e}")
-        import traceback
-        traceback.print_exc()
+    except Exception:
         return None
 
 @spaces.GPU
 def analyze_audio(audio_path: str, question: str = None) -> str:
-    """
-    Main function for audio analysis that will be exposed as a tool.
-    Args:
-        audio_path: Path to the audio file
-        question: Optional question about the audio
-    Returns:
-        str: Model's response about the audio
-    """
-    print(f"\nStarting analysis with audio_path: {audio_path}, question: {question}")
-
-    # Input validation
     if audio_path is None or not isinstance(audio_path, str):
         return "Please provide a valid audio file."
 
     if not os.path.exists(audio_path):
         return f"Audio file not found: {audio_path}"
 
-    # Process audio
     audio_data = process_audio(audio_path)
     if not audio_data or "audio" not in audio_data or "sampling_rate" not in audio_data:
         return "Failed to process the audio file. Please ensure it's a valid audio format."
@@ -98,7 +68,6 @@ def analyze_audio(audio_path: str, question: str = None) -> str:
         model, tokenizer = load_model()
         query = question if question else "Please describe what you hear in this audio clip."
 
-        print("Preparing messages...")
         messages = [
             {
                 "role": "user",
@@ -106,7 +75,6 @@ def analyze_audio(audio_path: str, question: str = None) -> str:
             }
         ]
 
-        print("Applying chat template...")
         if tokenizer.chat_template:
             text = tokenizer.apply_chat_template(
                 messages,
@@ -116,12 +84,8 @@ def analyze_audio(audio_path: str, question: str = None) -> str:
         else:
             raise ValueError("Tokenizer chat_template is not set.")
 
-        print(f"Generated prompt text: {text[:200]}...")
-
-        print("Tokenizing input...")
         model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
-        print("Generating response...")
         with torch.no_grad():
             outputs = model.generate(
                 **model_inputs,
@@ -134,24 +98,14 @@ def analyze_audio(audio_path: str, question: str = None) -> str:
             )
 
         if outputs is None or len(outputs) == 0:
-            print("Model generated None or empty output")
             return "The model failed to generate a response. Please try again."
 
-        print(f"Output shape: {outputs.shape}")
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        print(f"Generated response: {response[:200]}...")
         return response
 
-    except TypeError as te:
-        print(f"TypeError during processing: {str(te)}")
-        return "An error occurred with the data processing. Please check the inputs."
-    except Exception as e:
-        print(f"Error during processing: {str(e)}")
-        import traceback
-        traceback.print_exc()
-        return f"An error occurred while processing: {str(e)}"
+    except Exception:
+        return "An error occurred while processing. Please check your inputs and try again."
 
 
-# Create Gradio interface with clear input/output specifications
 demo = gr.Interface(
     fn=analyze_audio,
     inputs=[
@@ -159,7 +113,7 @@ demo = gr.Interface(
             type="filepath",
             label="Audio Input",
             sources=["upload", "microphone"],
-            format="mp3"  # Specify format to ensure consistent audio format
+            format="mp3"
         ),
         gr.Textbox(
             label="Question",
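Since process_audio now returns None on any failure instead of printing a traceback, a quick local round-trip is the easiest way to confirm the contract still holds. A minimal sketch, assuming app.py is importable and soundfile/numpy are installed; the sine-wave fixture and temporary file are illustrative, not part of the commit:

import base64
import tempfile

import numpy as np
import soundfile as sf

from app import process_audio  # assumes app.py is on the import path

# Build a one-second mono sine wave so process_audio has a valid input.
rate = 16000
t = np.linspace(0, 1.0, rate, endpoint=False)
tone = (0.1 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
    wav_path = f.name
sf.write(wav_path, tone, rate)

result = process_audio(wav_path)

# Success yields base64-encoded WAV bytes plus the sampling rate;
# any failure now comes back as None rather than a printed traceback.
assert result is not None and result["sampling_rate"] == rate
assert base64.b64decode(result["audio"])[:4] == b"RIFF"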
 
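For an end-to-end check of the updated analyze_audio, a hypothetical smoke run; sample.wav and the question are placeholders, and the @spaces.GPU decorator only requests a GPU when the code runs inside a ZeroGPU Space:

from app import analyze_audio, demo  # assumes app.py is importable

# Errors inside the generation block now surface as a generic message
# instead of an f-string carrying the raw exception text.
print(analyze_audio("sample.wav", "What instrument is playing?"))

# Serve the Gradio interface, as the Space entrypoint would.
demo.launch()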