Spaces:
Configuration error
Configuration error
File size: 6,946 Bytes
df4965a 2d610a5 e0491f4 2d610a5 df4965a 52a3fd6 4b98fcb df4965a 6dc35f3 df4965a 6dc35f3 df4965a 2d610a5 df4965a e0491f4 2d610a5 df4965a 6dc35f3 df4965a 2d610a5 df4965a e2882d6 df4965a e2882d6 df4965a e0491f4 df4965a 2d610a5 df4965a e2882d6 df4965a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import logging
import os
import uuid
from datetime import timedelta

import torch
from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema
from rest_framework import status
from rest_framework.authentication import TokenAuthentication
from rest_framework.exceptions import NotFound as NOT_FOUND
from rest_framework.generics import CreateAPIView, ListAPIView
from rest_framework.parsers import MultiPartParser
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from TTS.api import TTS  # Ensure this import is correct based on your TTS library/package

from texttovoice.models import TextToSpeech

from .minio_utils import get_minio_client  # Ensure this import matches your file structure
from .serializers import (  # Ensure this import matches your file structure
    TextToSpeechSerializer,
    TextToSpeechSerializerResponse,
    TextToSpeechSerializerResponseWithURL,
)
# Shared MinIO client for this module; note that it is created at import time.
minio_client = get_minio_client()
# Bucket that upload_file_to_minio writes to and signs download URLs against.
BUCKET_NAME = "voice-clone"
class TextToSpeechCreateView(CreateAPIView):
    """Synthesize speech for a text in the voice of an uploaded speaker sample.

    The view accepts a multipart form validated by ``TextToSpeechSerializer``
    and reads the ``text``, ``speaker_wav`` and ``language`` values from it.
    It runs the XTTS v2 model, uploads both the speaker sample and the
    generated audio to MinIO, stores a ``TextToSpeech`` row owned by the
    requesting user, and answers with the serialized row in which
    ``speaker_wav`` and ``output_wav`` are replaced by pre-signed URLs.
    """

    serializer_class = TextToSpeechSerializer
    authentication_classes = [TokenAuthentication]
    permission_classes = [IsAuthenticated]
    parser_classes = [MultiPartParser]

    # Class-level logger so every method reports through the same channel.
    _log = logging.getLogger(__name__)

    # No manual form parameters are declared here: the previous ones described
    # an unrelated document-upload endpoint. drf-yasg is expected to derive the
    # form fields from ``serializer_class`` (confirm in the rendered schema).
    @swagger_auto_schema(
        operation_id='Create a text-to-speech clip',
        operation_description=(
            'Generate speech for the given text in the voice of the uploaded '
            'speaker sample and return pre-signed URLs for both audio files'
        ),
        responses={
            status.HTTP_201_CREATED: openapi.Response(
                'Created', TextToSpeechSerializerResponse
            ),
            status.HTTP_400_BAD_REQUEST: 'Validation errors',
            status.HTTP_500_INTERNAL_SERVER_ERROR: 'Unexpected processing error',
        },
    )
    def create(self, request, *args, **kwargs):
        """Handle POST: validate, synthesize, upload, persist and respond.

        Returns:
            Response: 201 with the created record and pre-signed URLs,
            400 with the serializer errors when validation fails, or
            500 with a generic error message when any processing step raises.
        """
        serializer = self.get_serializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        # Bound before the ``try`` so the ``finally`` clause can always run,
        # even when the failure happens before the paths are computed.
        speaker_file_path = None
        output_filename = None
        try:
            gpu_available = torch.cuda.is_available()
            text = serializer.validated_data.get("text")
            speaker_wav = serializer.validated_data.get("speaker_wav")
            language = serializer.validated_data.get("language")

            # Temporary file paths; a UUID prefix keeps concurrent requests
            # apart. ``basename`` is defence in depth against a client-supplied
            # name that contains directory components.
            speaker_file_path = os.path.join(
                '/tmp', f"{uuid.uuid4()}{os.path.basename(speaker_wav.name)}"
            )
            output_filename = os.path.join('/tmp', f"{uuid.uuid4()}.wav")

            # Save speaker WAV file
            with open(speaker_file_path, "wb") as destination:
                for chunk in speaker_wav.chunks():
                    destination.write(chunk)

            # TTS processing.
            # NOTE(review): the model object is constructed on every request;
            # consider reusing one instance if request latency matters.
            tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=gpu_available)
            tts.tts_to_file(
                text=text,
                file_path=output_filename,
                speaker_wav=speaker_file_path,
                language=language,
            )

            # Upload both files to MinIO
            public_url, speaker_wav_path = self.upload_file_to_minio(
                speaker_file_path, 'speakers/'
            )
            public_url_output, output_wav_path = self.upload_file_to_minio(
                output_filename, 'output/'
            )

            # Create DB entry; the row stores "bucket/object" paths, not URLs.
            tts_instance = TextToSpeech.objects.create(
                text=text,
                speaker_wav=speaker_wav_path,
                output_wav=output_wav_path,
                language=language,
                created_by=request.user,
            )

            # Serialize and return the created instance, swapping the stored
            # paths for the freshly generated pre-signed URLs.
            response_serializer = TextToSpeechSerializerResponse(tts_instance)
            response_data = {
                **response_serializer.data,
                "speaker_wav": public_url,
                "output_wav": public_url_output,
            }
            return Response(response_data, status=status.HTTP_201_CREATED)
        except Exception:
            # Top-level boundary: keep the client message generic, but record
            # the full traceback for operators.
            self._log.exception("Text-to-speech request failed")
            return Response(
                {"error": "An error occurred processing your request."},
                status=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        finally:
            # Ensure cleanup happens; unset (None) paths are skipped.
            self.cleanup_files([speaker_file_path, output_filename])

    def upload_file_to_minio(self, file_path, prefix):
        """Upload a local file to MinIO and sign a temporary download URL.

        Args:
            file_path: Path of the local file to upload.
            prefix: Object-name prefix inside the bucket, e.g. ``'speakers/'``.

        Returns:
            tuple: ``(pre_signed_url, "BUCKET_NAME/object_name")``. The URL is
            requested with a one-day expiry.
        """
        file_name = os.path.basename(file_path)
        object_name = f"{prefix}{file_name}"
        with open(file_path, "rb") as file_data:
            minio_client.put_object(
                BUCKET_NAME, object_name, file_data, os.path.getsize(file_path)
            )
        # Generate a pre-signed URL for the uploaded object
        pre_signed_url = minio_client.presigned_get_object(
            BUCKET_NAME, object_name, expires=timedelta(days=1)
        )
        return pre_signed_url, f"{BUCKET_NAME}/{object_name}"

    def cleanup_files(self, file_paths):
        """Remove the given files from the filesystem on a best-effort basis.

        Args:
            file_paths: Iterable of paths; ``None`` or empty entries are
                skipped.

        Never raises for a file that is already gone or cannot be removed.
        """
        for file_path in file_paths:
            if not file_path:
                continue
            try:
                os.remove(file_path)
            except FileNotFoundError:
                # Expected when processing failed before the file was written.
                continue
            except OSError:
                self._log.warning(
                    "Could not remove temporary file %s", file_path, exc_info=True
                )
class TextToSpeechListView(ListAPIView):
    """List the text-to-speech records created by the authenticated user."""

    serializer_class = TextToSpeechSerializerResponseWithURL
    authentication_classes = [TokenAuthentication]
    permission_classes = [IsAuthenticated]

    # Class-level logger so URL-signing problems are recorded, not printed.
    _log = logging.getLogger(__name__)

    def get_queryset(self):
        """Return only the records whose ``created_by`` is the request user."""
        return TextToSpeech.objects.filter(created_by=self.request.user)

    def list(self, request, *args, **kwargs):
        """Return all of the user's records, or 404 when there are none.

        Raises:
            NotFound: when the user has no text-to-speech records.
        """
        queryset = self.get_queryset()
        if not queryset.exists():
            raise NOT_FOUND('No text-to-speech data found for the current user.')

        # The view is handed to the serializer through the context; the
        # serializer presumably calls ``generate_presigned_url`` on it
        # (serializer not visible here -- confirm).
        # NOTE(review): passing ``context`` explicitly appears to replace the
        # default serializer context (request/format/view) -- confirm that the
        # serializer does not need ``request``.
        serializer = self.get_serializer(queryset, many=True, context={'view': self})
        return Response(serializer.data, status=status.HTTP_200_OK)

    def generate_presigned_url(self, object_path):
        """Create a one-hour pre-signed download URL for a stored object.

        Args:
            object_path: Stored location in the form
                ``"bucket_name/object_name"``.

        Returns:
            The pre-signed URL, or ``None`` when ``object_path`` is malformed
            or the MinIO client fails to sign it.
        """
        try:
            # Only the first "/" separates bucket from object name; the object
            # name itself may contain further slashes.
            bucket, object_name = object_path.split('/', 1)
        except (AttributeError, ValueError):
            self._log.warning(
                "Cannot derive bucket and object name from %r", object_path
            )
            return None

        try:
            return minio_client.presigned_get_object(
                bucket, object_name, expires=timedelta(hours=1)
            )
        except Exception:
            # Best effort: one unsignable object must not break the listing.
            self._log.exception(
                "Failed to generate pre-signed URL for %r", object_path
            )
            return None
|