import os
import tempfile

import librosa
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from librosa.sequence import dtw
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

app = FastAPI(
    title="Quran Recitation Comparer API",
    description="Compares two Quran recitations using a deep wav2vec2 model.",
    version="1.0",
)

# Module-level handle for the comparer, populated by the startup event below.
comparer = None


# --- Core Class Definition ---
class QuranRecitationComparer:
    def __init__(self, model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=None):
        """Initialize the Quran recitation comparer with a specific Wav2Vec2 model."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load the model and processor once during initialization.
        if auth_token:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name, token=auth_token)
            self.model = Wav2Vec2ForCTC.from_pretrained(model_name, token=auth_token)
        else:
            self.processor = Wav2Vec2Processor.from_pretrained(model_name)
            self.model = Wav2Vec2ForCTC.from_pretrained(model_name)

        self.model = self.model.to(self.device)
        self.model.eval()

        # Cache embeddings by file path to avoid recomputation.
        self.embedding_cache = {}

    def load_audio(self, file_path, target_sr=16000, trim_silence=True, normalize=True):
        """Load and preprocess an audio file."""
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")

        y, sr = librosa.load(file_path, sr=target_sr)

        if normalize:
            y = librosa.util.normalize(y)

        if trim_silence:
            y, _ = librosa.effects.trim(y, top_db=30)

        return y

    def get_deep_embedding(self, audio, sr=16000):
        """Extract frame-wise deep embeddings using the pretrained model."""
        input_values = self.processor(
            audio,
            sampling_rate=sr,
            return_tensors="pt",
        ).input_values.to(self.device)

        with torch.no_grad():
            outputs = self.model(input_values, output_hidden_states=True)

        # Use the last hidden layer as the frame-wise embedding sequence.
        hidden_states = outputs.hidden_states[-1]
        embedding_seq = hidden_states.squeeze(0).cpu().numpy()
        return embedding_seq

    def compute_dtw_distance(self, features1, features2):
        """Compute the DTW distance between two sequences of features."""
        D, wp = dtw(X=features1, Y=features2, metric="euclidean")
        distance = D[-1, -1]
        # Normalize by the warping-path length so longer recitations do not
        # accumulate a larger distance simply because they have more frames.
        normalized_distance = distance / len(wp)
        return normalized_distance

    def interpret_similarity(self, norm_distance):
        """Interpret the normalized distance value."""
        if norm_distance == 0:
            result = "The recitations are identical based on the deep embeddings."
            score = 100
        elif norm_distance < 1:
            result = "The recitations are extremely similar."
            score = 95
        elif norm_distance < 5:
            result = "The recitations are very similar with minor differences."
            score = 80
        elif norm_distance < 10:
            result = "The recitations show moderate similarity."
            score = 60
        elif norm_distance < 20:
            result = "The recitations show some noticeable differences."
            score = 40
        else:
            result = "The recitations are quite different."
            # Decay linearly from 40 at the bucket boundary (norm_distance == 20)
            # so the score stays continuous with the previous bucket.
            score = max(0, 60 - norm_distance)

        return result, score

    def get_embedding_for_file(self, file_path):
        """Get the embedding for a file, using the cache if available."""
        if file_path in self.embedding_cache:
            return self.embedding_cache[file_path]

        audio = self.load_audio(file_path)
        embedding = self.get_deep_embedding(audio)

        # Store in the cache for future use.
        self.embedding_cache[file_path] = embedding
        return embedding

    def predict(self, file_path1, file_path2):
        """
        Predict the similarity between two audio files.

        Args:
            file_path1 (str): Path to the first audio file.
            file_path2 (str): Path to the second audio file.

        Returns:
            (float, str): Similarity score and interpretation.
        """
        embedding1 = self.get_embedding_for_file(file_path1)
        embedding2 = self.get_embedding_for_file(file_path2)

        # librosa's dtw expects feature matrices of shape (d, N), so transpose
        # the (frames, features) embeddings before comparison.
        norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T)
        interpretation, similarity_score = self.interpret_similarity(norm_distance)

        # In production, prefer logging these results over printing.
        print(f"Similarity Score: {similarity_score:.1f}/100")
        print(f"Interpretation: {interpretation}")

        return similarity_score, interpretation

    def clear_cache(self):
        """Clear the embedding cache to free memory."""
        self.embedding_cache = {}
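# Illustrative direct (non-API) usage of the comparer, shown as a sketch; the
# file paths below are hypothetical placeholders:
#
#     comparer = QuranRecitationComparer()
#     score, interpretation = comparer.predict("recitation_a.wav", "recitation_b.wav")
#     comparer.clear_cache()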
""" embedding1 = self.get_embedding_for_file(file_path1) embedding2 = self.get_embedding_for_file(file_path2) norm_distance = self.compute_dtw_distance(embedding1.T, embedding2.T) interpretation, similarity_score = self.interpret_similarity(norm_distance) # Optionally log the results instead of printing in production print(f"Similarity Score: {similarity_score:.1f}/100") print(f"Interpretation: {interpretation}") return similarity_score, interpretation def clear_cache(self): """Clear the embedding cache to free memory.""" self.embedding_cache = {} # --- FastAPI Startup Event --- # In production, consider loading sensitive tokens from environment variables or configuration files. @app.on_event("startup") def startup_event(): global comparer # For production, do not hardcode tokens; use os.environ.get(...) or a configuration system. auth_token = os.environ.get("HF_TOKEN") comparer = QuranRecitationComparer( model_name="jonatasgrosman/wav2vec2-large-xlsr-53-arabic", auth_token=auth_token ) print("Model initialized and ready for predictions!") # --- API Endpoints --- @app.get("/", summary="Health Check") async def root(): return {"message": "Quran Recitation Comparer API is up and running."} @app.post("/predict", summary="Compare Two Audio Files", response_model=dict) async def predict(file1: UploadFile = File(...), file2: UploadFile = File(...)): """ Compare two uploaded audio files and return a similarity score along with an interpretation. - **file1**: The first audio file. - **file2**: The second audio file. """ tmp1_path = None tmp2_path = None try: # Save first file to a temporary location suffix1 = os.path.splitext(file1.filename)[1] or ".wav" with tempfile.NamedTemporaryFile(delete=False, suffix=suffix1) as tmp1: content1 = await file1.read() tmp1.write(content1) tmp1_path = tmp1.name # Save second file to a temporary location suffix2 = os.path.splitext(file2.filename)[1] or ".wav" with tempfile.NamedTemporaryFile(delete=False, suffix=suffix2) as tmp2: content2 = await file2.read() tmp2.write(content2) tmp2_path = tmp2.name similarity_score, interpretation = comparer.predict(tmp1_path, tmp2_path) return {"similarity_score": similarity_score, "interpretation": interpretation} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) finally: # Clean up temporary files if tmp1_path and os.path.exists(tmp1_path): os.remove(tmp1_path) if tmp2_path and os.path.exists(tmp2_path): os.remove(tmp2_path) @app.post("/clear_cache", summary="Clear Embedding Cache", response_model=dict) async def clear_cache(): """ Clear the embedding cache. This can help free memory if many comparisons have been made. """ comparer.clear_cache() return {"message": "Cache cleared."}