import glob
import io
import os
import shutil
import zipfile
from typing import List

import fastapi
import uvicorn


class ModelAPI:
    """HTTP wrapper around a speech-enhancement model.

    Noisy ``.wav`` files are uploaded into ``~/.modelapi/noisy_audio``.
    ``_enhance`` writes enhanced output into ``~/.modelapi/enhanced_audio``.
    The enhanced files are then downloaded back as a single zip archive.
    """

    def __init__(self, host, port):
        """Create (and empty) the working directories and build the FastAPI app.

        Args:
            host: Interface address passed to ``uvicorn.run``.
            port: TCP port passed to ``uvicorn.run``.

        Raises:
            OSError: If a working directory cannot be created or emptied.
        """
        self.host = host
        self.port = port
        self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
        self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
        self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")

        # Start from empty directories so leftovers from a previous run are
        # never re-enhanced or included in a later download.
        for audio_path in (self.noisy_audio_path, self.enhanced_audio_path):
            os.makedirs(audio_path, exist_ok=True)
            self._clear_directory(audio_path)

        self.app = fastapi.FastAPI()
        self._setup_routes()

    @staticmethod
    def _clear_directory(directory):
        """Remove every file, symlink and subdirectory directly inside ``directory``."""
        for filename in os.listdir(directory):
            file_path = os.path.join(directory, filename)
            # Symlinks (even ones pointing at directories) are unlinked rather
            # than followed, so rmtree never touches data outside `directory`.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)

    def _prepare(self):
        """Miners should modify this function to fit their fine-tuned models.

        This function will make any preparations necessary to initialize the
        speech enhancement model (i.e. downloading checkpoint files, etc.)
        """
        # Continue from here
        pass

    def _enhance(self):
        """Miners should modify this function to fit their fine-tuned models.

        This function will:
        1. Open each noisy .wav file
        2. Enhance the audio with the model
        3. Save the enhanced audio in .wav format to ``self.enhanced_audio_path``
        """
        # Define file paths for all noisy files to be enhanced
        noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, '*.wav')))
        for noisy_file in noisy_files:
            # Continue from here
            pass

    def _setup_routes(self):
        """Register the API routes on ``self.app``.

        /status/            : Communicates API status
        /prepare/           : Run model preparation (``_prepare``)
        /upload-audio/      : Upload audio files, save to noisy audio directory
        /enhance/           : Enhance audio files, save to enhanced audio directory
        /download-enhanced/ : Download enhanced audio files as a zip archive
        """
        self.app.get("/status/")(self.get_status)
        self.app.post("/prepare/")(self.prepare)
        self.app.post("/upload-audio/")(self.upload_audio)
        self.app.post("/enhance/")(self.enhance_audio)
        self.app.get("/download-enhanced/")(self.download_enhanced)

    def get_status(self):
        """Report that the API process is up and serving requests."""
        return {"container_running": True}

    def prepare(self):
        """Run ``_prepare`` and report success.

        Raises:
            fastapi.HTTPException: 500 if ``_prepare`` fails.
        """
        try:
            self._prepare()
        except Exception as e:
            # The exception must be *raised* (not returned) for FastAPI to
            # turn it into an HTTP 500 response.
            raise fastapi.HTTPException(
                status_code=500,
                detail="An error occurred while preparing the model.",
            ) from e
        return {'preparations': True}

    def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
        """Save each uploaded file into the noisy audio directory.

        Returns:
            dict with ``uploaded_files`` (the names saved on disk) and ``status``.

        Raises:
            fastapi.HTTPException: 400 for an unusable filename, 500 if saving fails.
        """
        uploaded_files = []
        for file in files:
            try:
                # Client-supplied names are untrusted: keep only the final
                # path component so "../x" or "/abs/x" cannot escape the
                # upload directory.
                filename = os.path.basename(file.filename or "")
                if filename in ("", ".", ".."):
                    raise fastapi.HTTPException(
                        status_code=400,
                        detail=f"Invalid upload filename: {file.filename!r}",
                    )
                file_path = os.path.join(self.noisy_audio_path, filename)
                # Copy in 1 MiB chunks so large uploads are never fully buffered.
                with open(file_path, "wb") as f:
                    shutil.copyfileobj(file.file, f, 1024 * 1024)
                uploaded_files.append(filename)
            except fastapi.HTTPException:
                raise
            except Exception as e:
                raise fastapi.HTTPException(
                    status_code=500,
                    detail="An error occurred while uploading the noisy files.",
                ) from e
            finally:
                file.file.close()
        return {"uploaded_files": uploaded_files, "status": True}

    def enhance_audio(self):
        """Run ``_enhance`` over the uploaded noisy files.

        Raises:
            fastapi.HTTPException: 500 if enhancement fails.
        """
        try:
            self._enhance()
        except Exception as e:
            raise fastapi.HTTPException(
                status_code=500,
                detail="An error occurred while enhancing the noisy files.",
            ) from e
        return {"status": True}

    def download_enhanced(self):
        """Return every enhanced ``.wav`` file bundled into one zip archive.

        Raises:
            fastapi.HTTPException: 500 if the archive cannot be built.
        """
        try:
            # Build the archive in memory; entries are stored by base name only.
            zip_buffer = io.BytesIO()
            with zipfile.ZipFile(zip_buffer, "w") as zip_file:
                # Sorted so the archive layout is deterministic across calls.
                for wav_file in sorted(glob.glob(os.path.join(self.enhanced_audio_path, '*.wav'))):
                    zip_file.write(wav_file, arcname=os.path.basename(wav_file))
        except Exception as e:
            raise fastapi.HTTPException(
                status_code=500,
                detail=f"An error occurred while creating the download file: {str(e)}",
            ) from e

        # The whole payload is already in memory, so a plain Response is
        # sufficient and also sets Content-Length for the client.
        return fastapi.Response(
            content=zip_buffer.getvalue(),
            media_type="application/zip",
            headers={"Content-Disposition": "attachment; filename=enhanced_audio_files.zip"},
        )

    def run(self):
        """Serve the app with uvicorn (blocks until the server stops)."""
        uvicorn.run(self.app, host=self.host, port=self.port)