Spaces:
Configuration error
Configuration error
import os | |
import uuid | |
import time | |
import logging # Import the logging module | |
import torch | |
from django.http import FileResponse | |
from rest_framework import status | |
from rest_framework.response import Response | |
from rest_framework.generics import CreateAPIView | |
from TTS.api import TTS | |
from rest_framework.authentication import TokenAuthentication | |
from rest_framework.permissions import IsAuthenticated | |
from texttovoice.models import TextToSpeech | |
from .serializers import TextToSpeechSerializer | |
from rest_framework.parsers import MultiPartParser | |
from drf_yasg import openapi | |
from drf_yasg.utils import swagger_auto_schema | |
# Initialize logger at module level | |
logger = logging.getLogger(__name__) | |
class TextToSpeechCreateView(CreateAPIView):
    """Synthesize speech from text using an uploaded speaker sample.

    Accepts a multipart POST validated by ``TextToSpeechSerializer`` and reads
    the ``text``, ``speaker_wav`` and ``language`` fields from it. On success
    the generated WAV file is streamed back as an attachment and a
    ``TextToSpeech`` row is recorded for the authenticated user. On invalid
    input the serializer errors are returned with HTTP 400.
    """

    serializer_class = TextToSpeechSerializer
    authentication_classes = [TokenAuthentication]  # Apply token authentication
    permission_classes = [IsAuthenticated]  # Require authentication for this view
    parser_classes = [MultiPartParser]

    # Size of the pieces read from disk while streaming the generated audio.
    _STREAM_CHUNK_SIZE = 64 * 1024

    @staticmethod
    def _remove_file(path):
        """Delete *path*, tolerating a missing file and logging other failures."""
        try:
            os.remove(path)
        except FileNotFoundError:
            # Nothing to clean up (e.g. synthesis failed before writing output).
            pass
        except OSError:
            logger.warning("Could not delete temporary file %s", path, exc_info=True)

    @classmethod
    def _stream_then_delete(cls, file_name):
        """Yield the contents of *file_name* in fixed-size chunks, then delete it.

        The deletion sits in ``finally`` so it also runs when the generator is
        closed before exhaustion (for example when the client disconnects
        mid-download), not only after the last chunk has been sent.
        """
        try:
            with open(file_name, "rb") as audio:
                # Fixed-size reads: iterating a binary file directly would
                # split it on newline bytes, giving arbitrary chunk sizes.
                yield from iter(lambda: audio.read(cls._STREAM_CHUNK_SIZE), b"")
        finally:
            cls._remove_file(file_name)

    def create(self, request, *args, **kwargs):
        """Validate the request, run synthesis and stream the resulting audio.

        Returns:
            A ``FileResponse`` streaming the generated WAV on success, or a
            ``Response`` carrying the serializer errors with HTTP 400.

        Raises:
            Any exception raised while saving the sample, running the model or
            writing the database row is propagated after temporary files have
            been removed.
        """
        serializer = self.get_serializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        text = serializer.validated_data.get("text")
        speaker_wav = serializer.validated_data.get("speaker_wav")
        language = serializer.validated_data.get("language")

        output_filename = f"output_{uuid.uuid4()}.wav"
        # Prefix with a UUID so concurrent uploads that share a filename do not
        # overwrite each other; basename() keeps the path inside /tmp.
        speaker_file_path = os.path.join(
            "/tmp", f"{uuid.uuid4()}_{os.path.basename(speaker_wav.name)}"
        )

        start_time = time.time()
        try:
            # Save the uploaded speaker file to a temporary location, because
            # tts_to_file is given a filesystem path for the speaker sample.
            with open(speaker_file_path, "wb") as destination:
                for chunk in speaker_wav.chunks():
                    destination.write(chunk)

            # NOTE(review): the model is constructed on every request, which is
            # likely expensive; consider sharing one instance once it is
            # confirmed safe to use across concurrent requests.
            tts = TTS(
                "tts_models/multilingual/multi-dataset/xtts_v2",
                gpu=torch.cuda.is_available(),
            )
            tts.tts_to_file(
                text=text,
                file_path=output_filename,
                speaker_wav=speaker_file_path,
                language=language,
            )
            end_time = time.time()

            TextToSpeech.objects.create(
                text=text,
                speaker_wav=speaker_wav,
                output_wav=output_filename,
                language=language,
                created_by=request.user,  # Assign the authenticated user here
            )
        except Exception:
            # Do not leave a partial or orphaned output file behind on failure.
            self._remove_file(output_filename)
            raise
        finally:
            # The speaker sample is only needed while synthesis runs.
            self._remove_file(speaker_file_path)

        processing_time = end_time - start_time
        response = FileResponse(
            self._stream_then_delete(output_filename),
            as_attachment=True,
            content_type='audio/wav',
        )
        logger.info(
            "start time: %s , end time: %s and Processing time: %s seconds",
            start_time,
            end_time,
            processing_time,
        )
        return response