Commit 56709e2 (verified) · desiree committed · 1 parent: b50b5f5

Upload app.py

Files changed (1):
  1. app.py (+38 −9)

app.py CHANGED
@@ -13,6 +13,7 @@ MODEL_ID = "Qwen/Qwen-Audio-Chat"
 QWEN_CHAT_TEMPLATE = """{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] }}{% endif %}{{ eos_token }}{% endfor %}"""
 
 def load_model():
+    print("Loading model and tokenizer...")
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
         torch_dtype=torch.float16,
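A note on the template itself: apply_chat_template renders it with Jinja2, and since this template has no `{% if add_generation_prompt %}` branch, passing add_generation_prompt=True later in the file likely has no effect on the rendered prompt. A quick standalone render sketch shows what the prompt looks like (the special-token strings below are placeholders, not the real Qwen values):

from jinja2 import Template

QWEN_CHAT_TEMPLATE = """{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] }}{% endif %}{{ eos_token }}{% endfor %}"""

rendered = Template(QWEN_CHAT_TEMPLATE).render(
    messages=[{"role": "user", "content": "What is in this clip?"}],
    bos_token="<s>",   # placeholder; the real value comes from the tokenizer
    eos_token="</s>",  # placeholder
)
print(rendered)  # -> <s>User: What is in this clip?</s>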
@@ -21,14 +22,17 @@ def load_model():
     )
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
     tokenizer.chat_template = QWEN_CHAT_TEMPLATE
+    print("Model and tokenizer loaded successfully")
     return model, tokenizer
 
 def process_audio(audio_path):
     """Process audio file and return the appropriate format for the model."""
     try:
+        print(f"Processing audio file: {audio_path}")
         audio_data, sample_rate = sf.read(audio_path)
         if len(audio_data.shape) > 1:
             audio_data = audio_data.mean(axis=1)  # Convert stereo to mono if necessary
+        print(f"Audio processed successfully. Sample rate: {sample_rate}, Shape: {audio_data.shape}")
         return True
     except Exception as e:
         print(f"Error processing audio: {e}")
@@ -44,6 +48,8 @@ def analyze_audio(audio_path: str, question: str = None) -> str:
     Returns:
         str: Model's response about the audio
     """
+    print(f"\nStarting analysis with audio_path: {audio_path}, question: {question}")
+
     # Input validation
     if audio_path is None or not isinstance(audio_path, str):
         return "Please provide a valid audio file."
@@ -54,35 +60,58 @@ def analyze_audio(audio_path: str, question: str = None) -> str:
     if not process_audio(audio_path):
         return "Failed to process the audio file. Please ensure it's a valid audio format."
 
-    model, tokenizer = load_model()
-    query = question if question else "Please describe what you hear in this audio clip."
-
     try:
+        model, tokenizer = load_model()
+        query = question if question else "Please describe what you hear in this audio clip."
+
+        print("Preparing messages...")
         messages = [
             {
                 "role": "user",
-                "content": f"<audio>{audio_path}</audio>{query}"
+                "content": f"Here is an audio clip: <audio>{audio_path}</audio>\n{query}"
             }
         ]
 
+        print("Applying chat template...")
         text = tokenizer.apply_chat_template(
             messages,
             tokenize=False,
             add_generation_prompt=True
         )
+        print(f"Generated prompt text: {text[:200]}...")  # Print first 200 chars of prompt
+
+        print("Tokenizing input...")
         model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-
+
+        print("Generating response...")
         with torch.no_grad():
             outputs = model.generate(
                 **model_inputs,
                 max_new_tokens=512,
                 temperature=0.7,
-                do_sample=True
+                do_sample=True,
+                pad_token_id=tokenizer.pad_token_id,
+                bos_token_id=tokenizer.bos_token_id,
+                eos_token_id=tokenizer.eos_token_id
             )
-
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return response
+
+        if outputs is None:
+            print("Model generated None output")
+            return "The model failed to generate a response. Please try again."
+
+        print(f"Output shape: {outputs.shape}")
+        if len(outputs.shape) != 2 or outputs.shape[0] == 0:
+            print(f"Unexpected output shape: {outputs.shape}")
+            return "The model generated an invalid response. Please try again."
+
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        print(f"Generated response: {response[:200]}...")  # Print first 200 chars of response
+        return response
+
     except Exception as e:
+        print(f"Error during processing: {str(e)}")
+        import traceback
+        traceback.print_exc()
         return f"An error occurred while processing: {str(e)}"
 
 # Create Gradio interface with clear input/output specifications
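One caveat on the decoding step: for causal LMs, generate() returns the prompt tokens followed by the continuation, so decoding outputs[0] in full echoes the prompt back into the response. A common pattern (sketched against the variable names used in the hunk above) slices the prompt off first; and since some tokenizers leave pad_token_id unset, falling back to eos_token_id is the usual workaround:

# Inside analyze_audio, after generate(); model_inputs and outputs
# are the variables from the code above.
prompt_len = model_inputs.input_ids.shape[1]
new_tokens = outputs[0][prompt_len:]  # drop the echoed prompt tokens
response = tokenizer.decode(new_tokens, skip_special_tokens=True)

# Fallback if the tokenizer defines no pad token (an assumption to
# verify against the actual Qwen tokenizer config):
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id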
 
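The diff ends heading into the Gradio wiring. A minimal interface matching that comment might look like the following sketch (labels and title are assumptions; the actual definition is not part of this diff):

import gradio as gr

demo = gr.Interface(
    fn=analyze_audio,
    inputs=[
        gr.Audio(type="filepath", label="Audio clip"),
        gr.Textbox(label="Question (optional)"),
    ],
    outputs=gr.Textbox(label="Model response"),
    title="Qwen-Audio-Chat",  # assumed title
)

if __name__ == "__main__":
    demo.launch()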