Spaces: Running on Zero
Upload app.py
app.py CHANGED
@@ -13,6 +13,7 @@ MODEL_ID = "Qwen/Qwen-Audio-Chat"
 QWEN_CHAT_TEMPLATE = """{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] }}{% endif %}{{ eos_token }}{% endfor %}"""
 
 def load_model():
+    print("Loading model and tokenizer...")
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
         torch_dtype=torch.float16,
@@ -21,14 +22,17 @@ def load_model():
     )
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
     tokenizer.chat_template = QWEN_CHAT_TEMPLATE
+    print("Model and tokenizer loaded successfully")
     return model, tokenizer
 
 def process_audio(audio_path):
     """Process audio file and return the appropriate format for the model."""
     try:
+        print(f"Processing audio file: {audio_path}")
         audio_data, sample_rate = sf.read(audio_path)
         if len(audio_data.shape) > 1:
             audio_data = audio_data.mean(axis=1) # Convert stereo to mono if necessary
+        print(f"Audio processed successfully. Sample rate: {sample_rate}, Shape: {audio_data.shape}")
         return True
     except Exception as e:
         print(f"Error processing audio: {e}")
@@ -44,6 +48,8 @@ def analyze_audio(audio_path: str, question: str = None) -> str:
     Returns:
         str: Model's response about the audio
     """
+    print(f"\nStarting analysis with audio_path: {audio_path}, question: {question}")
+
     # Input validation
     if audio_path is None or not isinstance(audio_path, str):
         return "Please provide a valid audio file."
@@ -54,35 +60,58 @@ def analyze_audio(audio_path: str, question: str = None) -> str:
     if not process_audio(audio_path):
         return "Failed to process the audio file. Please ensure it's a valid audio format."
 
-    model, tokenizer = load_model()
-    query = question if question else "Please describe what you hear in this audio clip."
-
     try:
+        model, tokenizer = load_model()
+        query = question if question else "Please describe what you hear in this audio clip."
+
+        print("Preparing messages...")
         messages = [
             {
                 "role": "user",
-                "content": f"<audio>{audio_path}</audio>"
+                "content": f"Here is an audio clip: <audio>{audio_path}</audio>\n{query}"
             }
         ]
 
+        print("Applying chat template...")
         text = tokenizer.apply_chat_template(
             messages,
             tokenize=False,
             add_generation_prompt=True
         )
+        print(f"Generated prompt text: {text[:200]}...") # Print first 200 chars of prompt
+
+        print("Tokenizing input...")
         model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-
+
+        print("Generating response...")
         with torch.no_grad():
             outputs = model.generate(
                 **model_inputs,
                 max_new_tokens=512,
                 temperature=0.7,
-                do_sample=True
+                do_sample=True,
+                pad_token_id=tokenizer.pad_token_id,
+                bos_token_id=tokenizer.bos_token_id,
+                eos_token_id=tokenizer.eos_token_id
             )
-
-
-
+
+        if outputs is None:
+            print("Model generated None output")
+            return "The model failed to generate a response. Please try again."
+
+        print(f"Output shape: {outputs.shape}")
+        if len(outputs.shape) != 2 or outputs.shape[0] == 0:
+            print(f"Unexpected output shape: {outputs.shape}")
+            return "The model generated an invalid response. Please try again."
+
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        print(f"Generated response: {response[:200]}...") # Print first 200 chars of response
+        return response
+
     except Exception as e:
+        print(f"Error during processing: {str(e)}")
+        import traceback
+        traceback.print_exc()
         return f"An error occurred while processing: {str(e)}"
 
 # Create Gradio interface with clear input/output specifications
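
Note on the prompt construction: QWEN_CHAT_TEMPLATE above is a Jinja template that tokenizer.apply_chat_template renders with the message list and the tokenizer's special tokens. The sketch below renders it with jinja2 directly so the prompt layout is visible; the bos/eos strings are placeholders (the real values come from the Qwen-Audio-Chat tokenizer), so this is illustrative only, not part of the app.

    # Sketch: render the template outside transformers to see the prompt layout.
    from jinja2 import Template

    QWEN_CHAT_TEMPLATE = (
        "{{ bos_token }}{% for message in messages %}"
        "{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] }}"
        "{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] }}"
        "{% endif %}{{ eos_token }}{% endfor %}"
    )

    messages = [{
        "role": "user",
        "content": "Here is an audio clip: <audio>clip.wav</audio>\nWhat do you hear?",
    }]

    prompt = Template(QWEN_CHAT_TEMPLATE).render(
        bos_token="<bos>",  # placeholder; apply_chat_template fills this from the tokenizer
        eos_token="<eos>",  # placeholder
        messages=messages,
    )
    print(prompt)
    # <bos>User: Here is an audio clip: <audio>clip.wav</audio>
    # What do you hear?<eos>

Since the template has no add_generation_prompt branch, passing add_generation_prompt=True in the diff does not append an "Assistant:" cue; generation simply continues after the user turn.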
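
The Gradio wiring that follows this comment is outside the diff. As a rough sketch of how analyze_audio is typically exposed (component names and labels here are assumptions, not the Space's actual code), an interface with a filepath audio input matches the string-path validation above:

    # Sketch only: illustrative wiring, not taken from this commit.
    import gradio as gr

    demo = gr.Interface(
        fn=analyze_audio,  # defined earlier in app.py
        inputs=[
            gr.Audio(type="filepath", label="Audio clip"),  # passes a file path string
            gr.Textbox(label="Question (optional)"),
        ],
        outputs=gr.Textbox(label="Model response"),
        title="Qwen-Audio-Chat demo",
    )

    if __name__ == "__main__":
        demo.launch()

With type="filepath", Gradio hands analyze_audio a path string (or None when no file is supplied), which is what the isinstance(audio_path, str) check expects; an empty question falls back to the default prompt via the query = question if question else ... line.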
|