# Page header residue from the hosting site ("Spaces: Sleeping"); not part of the program.
from fastapi import FastAPI, HTTPException, UploadFile, File | |
from pydantic import BaseModel | |
import torch | |
import librosa | |
import numpy as np | |
import os | |
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
import tempfile | |
import shutil | |
from dotenv import load_dotenv | |
import uvicorn | |
import scipy.spatial.distance as distance | |
# Load environment variables
load_dotenv()

# Access token handed to the model loader at startup; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Application object that uvicorn serves.
app = FastAPI(title="Quran Recitation Comparer API")
class ComparisonResult(BaseModel):
    """Result of comparing two recitations, as returned by the comparer."""

    # Numeric similarity score for the pair of recordings.
    similarity_score: float
    # Human-readable explanation that accompanies the score.
    interpretation: str
# Custom implementation of DTW to replace librosa.sequence.dtw
def custom_dtw(X, Y, metric='euclidean'):
    """
    Align two sequences with Dynamic Time Warping.

    Both sequences are indexed along their first axis: ``X[i]`` is the
    i-th step of the first sequence and may be a scalar or a feature
    vector. All steps of ``X`` and ``Y`` must have the same feature size.

    Args:
        X: First sequence (length ``n``), array-like.
        Y: Second sequence (length ``m``), array-like.
        metric: Distance metric, 'euclidean' (true Euclidean distance,
            i.e. with the square root) or 'cosine' (1 - cosine similarity;
            a zero-norm vector is treated as having similarity 0).

    Returns:
        D: Accumulated cost matrix of shape ``(n + 1, m + 1)``. Row 0 and
            column 0 are padding (``inf`` except ``D[0, 0] == 0``), so
            ``D[i, j]`` is the cost of aligning ``X[:i]`` with ``Y[:j]``
            and ``D[-1, -1]`` is the total alignment cost.
        wp: Warping path as a list of ``(i, j)`` index pairs into ``D``,
            running from ``(0, 0)`` to ``(n, m)``.

    Raises:
        ValueError: If ``metric`` is not supported or the per-step feature
            sizes of ``X`` and ``Y`` differ.
    """
    # Fail fast: an unknown metric previously surfaced as an unbound
    # local variable in the middle of the fill loop.
    if metric not in ('euclidean', 'cosine'):
        raise ValueError(
            f"Unsupported metric {metric!r}; expected 'euclidean' or 'cosine'"
        )

    # Get sequence lengths
    n, m = len(X), len(Y)

    # Initialize cost matrix: everything unreachable except the origin.
    D = np.full((n + 1, m + 1), np.inf)
    D[0, 0] = 0.0

    if n > 0 and m > 0:
        # One row per step, features flattened, so scalars and vectors
        # are handled uniformly.
        A = np.asarray(X, dtype=float).reshape(n, -1)
        B = np.asarray(Y, dtype=float).reshape(m, -1)
        if A.shape[1] != B.shape[1]:
            raise ValueError(
                "X and Y must have the same feature size per step "
                f"(got {A.shape[1]} and {B.shape[1]})"
            )

        # Pairwise local costs, computed once instead of per DP cell.
        if metric == 'euclidean':
            cost = np.empty((n, m))
            for i in range(n):
                # Direct differences keep identical steps at exactly 0.
                diff = B - A[i]
                cost[i] = np.sqrt(np.sum(diff * diff, axis=1))
        else:
            norms = np.outer(np.linalg.norm(A, axis=1), np.linalg.norm(B, axis=1))
            similarity = np.zeros((n, m))
            # Leave similarity at 0 where a norm is zero (avoids 0/0 -> nan).
            np.divide(A @ B.T, norms, out=similarity, where=norms > 0)
            cost = 1.0 - similarity

        # Fill cost matrix
        for i in range(1, n + 1):
            prev_row, row, step_cost = D[i - 1], D[i], cost[i - 1]
            for j in range(1, m + 1):
                row[j] = step_cost[j - 1] + min(
                    prev_row[j], row[j - 1], prev_row[j - 1]
                )

    # Backtracking from the end of both sequences to the origin
    wp = [(n, m)]
    i, j = n, m
    while i > 0 or j > 0:
        if i == 0:
            j -= 1
        elif j == 0:
            i -= 1
        else:
            # Ties prefer the diagonal, then the step up, then the step left.
            best = int(np.argmin((D[i - 1, j - 1], D[i - 1, j], D[i, j - 1])))
            if best == 0:
                i -= 1
                j -= 1
            elif best == 1:
                i -= 1
            else:
                j -= 1
        wp.append((i, j))

    wp.reverse()
    return D, wp
class QuranRecitationComparer:
    """Compare two recitation recordings with Wav2Vec2 embeddings and DTW.

    The model and processor are loaded once in ``__init__``; ``predict``
    can then be called repeatedly. Frame-wise embeddings are cached per
    file path in ``embedding_cache`` (bounded by ``max_cache_entries``).
    """

    # (exclusive upper bound on the normalized distance, message, score),
    # checked in order after the exact-zero case.
    _SIMILARITY_BANDS = (
        (1, "The recitations are extremely similar.", 95.0),
        (5, "The recitations are very similar with minor differences.", 80.0),
        (10, "The recitations show moderate similarity.", 60.0),
        (20, "The recitations show some noticeable differences.", 40.0),
    )

    def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", token=None,
                 max_cache_entries=64):
        """Initialize the Quran recitation comparer with a specific Wav2Vec2 model.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            token: Optional access token; when truthy it is forwarded as
                ``use_auth_token`` to both ``from_pretrained`` calls.
            max_cache_entries: Maximum number of embeddings kept in the
                cache (oldest entries are evicted first). ``None`` means
                unbounded; ``0`` or less disables caching.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load model and processor once during initialization
        if token:
            print(f"Loading model {model_name} with token...")
            auth_kwargs = {"use_auth_token": token}
        else:
            print(f"Loading model {model_name} without token...")
            auth_kwargs = {}
        self.processor = Wav2Vec2Processor.from_pretrained(model_name, **auth_kwargs)
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name, **auth_kwargs)

        self.model = self.model.to(self.device)
        self.model.eval()

        # Cache for embeddings to avoid recomputation
        self.embedding_cache = {}
        self.max_cache_entries = max_cache_entries
        print("Model loaded successfully!")

    def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
        """Load and preprocess an audio file.

        Args:
            file_path: Path of the audio file to read.
            target_sr: Sampling rate requested from ``librosa.load``.
            trim_silence: Trim leading/trailing silence (``top_db=30``).
            normalize: Apply ``librosa.util.normalize`` before trimming.

        Returns:
            The preprocessed waveform.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        print(f"Loading audio: {file_path}")
        y, sr = librosa.load(file_path, sr=target_sr)

        if normalize:
            y = librosa.util.normalize(y)

        if trim_silence:
            # Use librosa.effects.trim which should be available in most versions
            y, _ = librosa.effects.trim(y, top_db=30)

        return y

    def get_deep_embedding(self, audio, sr=16000):
        """Extract frame-wise deep embeddings using the pretrained model.

        Returns:
            The last hidden state with the batch axis squeezed out, as a
            NumPy array (per the Transformers docs this is
            ``(frames, hidden_size)`` for a single input).
        """
        input_values = self.processor(
            audio,
            sampling_rate=sr,
            return_tensors="pt"
        ).input_values.to(self.device)

        with torch.no_grad():
            outputs = self.model(input_values, output_hidden_states=True)

        hidden_states = outputs.hidden_states[-1]
        embedding_seq = hidden_states.squeeze(0).cpu().numpy()
        return embedding_seq

    def compute_dtw_distance(self, features1, features2):
        """Compute the DTW distance between two sequences of features.

        Both inputs must be indexed by time step along axis 0, which is
        what ``custom_dtw`` expects.

        Returns:
            Total alignment cost divided by the length of the warping path.
        """
        D, wp = custom_dtw(X=features1, Y=features2, metric='euclidean')
        total_cost = D[-1, -1]
        # wp always contains at least one pair, so this never divides by zero.
        return total_cost / len(wp)

    def interpret_similarity(self, norm_distance):
        """Interpret the normalized distance value.

        Returns:
            ``(interpretation, score)``. The score never increases as the
            distance grows: beyond the last band it is ``100 - distance``
            clamped to ``[0, 40]``, so a "quite different" pair can no
            longer outscore a closer one.
        """
        if norm_distance == 0:
            return "The recitations are identical based on the deep embeddings.", 100.0

        for upper_bound, message, band_score in self._SIMILARITY_BANDS:
            if norm_distance < upper_bound:
                return message, band_score

        lowest_band_score = self._SIMILARITY_BANDS[-1][2]
        score = min(lowest_band_score, max(0.0, 100.0 - float(norm_distance)))
        return "The recitations are quite different.", score

    def _cache_embedding(self, file_path, embedding):
        """Store an embedding, evicting the oldest entries when full."""
        limit = self.max_cache_entries
        if limit is not None:
            if limit <= 0:
                return
            # Dicts keep insertion order, so the first key is the oldest.
            while len(self.embedding_cache) >= limit:
                del self.embedding_cache[next(iter(self.embedding_cache))]
        self.embedding_cache[file_path] = embedding

    def get_embedding_for_file(self, file_path):
        """Get embedding for a file, using cache if available."""
        if file_path in self.embedding_cache:
            print(f"Using cached embedding for {file_path}")
            return self.embedding_cache[file_path]

        print(f"Computing new embedding for {file_path}")
        audio = self.load_audio(file_path)
        embedding = self.get_deep_embedding(audio)

        # Store in cache for future use
        self._cache_embedding(file_path, embedding)
        print(f"Embedding shape: {embedding.shape}")

        return embedding

    def predict(self, file_path1, file_path2):
        """
        Predict the similarity between two audio files.
        This method can be called repeatedly without reloading the model.

        Args:
            file_path1 (str): Path to first audio file
            file_path2 (str): Path to second audio file

        Returns:
            float: Similarity score
            str: Interpretation of similarity
        """
        print(f"Comparing {file_path1} and {file_path2}")

        # Get embeddings (using cache if available)
        embedding1 = self.get_embedding_for_file(file_path1)
        embedding2 = self.get_embedding_for_file(file_path2)

        # Compute DTW distance. The embeddings are passed as-is (time along
        # axis 0): transposing them made DTW align feature dimensions and
        # fail whenever the two recordings had different frame counts.
        print("Computing DTW distance...")
        norm_distance = self.compute_dtw_distance(embedding1, embedding2)
        print(f"Normalized distance: {norm_distance}")

        # Interpret results
        interpretation, similarity_score = self.interpret_similarity(norm_distance)
        print(f"Similarity score: {similarity_score}, Interpretation: {interpretation}")

        return similarity_score, interpretation

    def clear_cache(self):
        """Clear the embedding cache to free memory."""
        self.embedding_cache = {}
        print("Embedding cache cleared")
# Global variable for the comparer instance
comparer = None


# NOTE(review): no registration for this hook was visible in the reviewed
# source, so the model was never loaded and `comparer` stayed None. It is
# registered here as a FastAPI startup handler.
@app.on_event("startup")
async def startup_event():
    """Initialize the model when the application starts.

    Builds the module-level ``comparer``. Any failure is logged and
    re-raised so the error is not silently swallowed.
    """
    global comparer
    print("Initializing model... This may take a moment.")
    try:
        comparer = QuranRecitationComparer(
            model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
            token=HF_TOKEN
        )
    except Exception as e:
        print(f"Error initializing model: {str(e)}")
        raise
    else:
        print("Model initialized and ready for predictions!")
# NOTE(review): no route decorator was visible in the reviewed source, so
# this handler was unreachable; it is registered at "/" per its docstring.
@app.get("/")
async def root():
    """Root endpoint to check if the API is running."""
    return {"message": "Quran Recitation Comparer API is running", "status": "active"}
# NOTE(review): no route decorator was visible in the reviewed source, so
# this handler was unreachable. The "/compare" path is an assumption based
# on the function name - confirm against clients before deploying.
@app.post("/compare", response_model=ComparisonResult)
async def compare_files(
    file1: UploadFile = File(...),
    file2: UploadFile = File(...)
):
    """
    Compare two audio files and return similarity metrics.

    - **file1**: First audio file (MP3, WAV, etc.)
    - **file2**: Second audio file (MP3, WAV, etc.)

    Returns similarity score and interpretation.
    """
    if not comparer:
        raise HTTPException(status_code=500, detail="Model not initialized. Please try again later.")

    print(f"Received files: {file1.filename} and {file2.filename}")

    temp_dir = tempfile.mkdtemp()
    print(f"Created temporary directory: {temp_dir}")

    try:
        # Save uploaded files to temporary directory.
        # The client-supplied filename is untrusted: only its extension is
        # kept (so the audio format can still be detected), and each upload
        # gets its own slot-based name. This prevents path traversal out of
        # temp_dir, avoids a crash on a missing filename, and stops two
        # uploads with the same name from overwriting each other.
        saved_paths = []
        for slot, upload in enumerate((file1, file2), start=1):
            client_name = os.path.basename((upload.filename or "").replace("\\", "/"))
            extension = os.path.splitext(client_name)[1]
            destination = os.path.join(temp_dir, f"recitation_{slot}{extension}")
            with open(destination, "wb") as f:
                f.write(await upload.read())
            saved_paths.append(destination)

        temp_file1, temp_file2 = saved_paths
        print(f"Files saved to: {temp_file1} and {temp_file2}")

        # Compare the files.
        # NOTE(review): predict() is synchronous, so the event loop is
        # blocked for the duration of the comparison.
        similarity_score, interpretation = comparer.predict(temp_file1, temp_file2)

        return ComparisonResult(
            similarity_score=similarity_score,
            interpretation=interpretation
        )

    except HTTPException:
        # Already a well-formed API error; do not re-wrap it.
        raise
    except Exception as e:
        print(f"Error processing files: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing files: {str(e)}") from e

    finally:
        # Clean up temporary files
        print(f"Cleaning up temporary directory: {temp_dir}")
        shutil.rmtree(temp_dir, ignore_errors=True)
# NOTE(review): no route decorator was visible in the reviewed source, so
# this handler was unreachable. The "/clear-cache" path is an assumption
# based on the function name - confirm against clients before deploying.
@app.post("/clear-cache")
async def clear_cache():
    """Clear the embedding cache to free memory.

    Raises:
        HTTPException: 500 if the comparer has not been initialized.
    """
    if not comparer:
        raise HTTPException(status_code=500, detail="Model not initialized.")

    comparer.clear_cache()
    return {"message": "Embedding cache cleared successfully"}
if __name__ == "__main__":
    # Serve the app object directly rather than the "main:app" import
    # string: that string only resolves when this file is named main.py,
    # and it makes uvicorn import the module a second time.
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")