File size: 6,946 Bytes
df4965a
2d610a5
 
e0491f4
2d610a5
 
df4965a
 
52a3fd6
 
4b98fcb
df4965a
6dc35f3
 
df4965a
 
 
6dc35f3
df4965a
2d610a5
df4965a
e0491f4
2d610a5
 
df4965a
 
6dc35f3
 
 
df4965a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d610a5
 
 
df4965a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2882d6
df4965a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2882d6
 
df4965a
 
 
 
 
 
 
 
 
 
 
 
 
 
e0491f4
df4965a
 
 
2d610a5
df4965a
 
 
 
 
 
 
 
e2882d6
df4965a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import logging
import os
import uuid
from datetime import timedelta

import torch
from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema
from rest_framework import status
from rest_framework.authentication import TokenAuthentication
from rest_framework.exceptions import NotFound as NOT_FOUND
from rest_framework.generics import CreateAPIView, ListAPIView
from rest_framework.parsers import MultiPartParser
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from TTS.api import TTS  # Ensure this import is correct based on your TTS library/package

from texttovoice.models import TextToSpeech
from .minio_utils import get_minio_client  # Ensure this import matches your file structure
from .serializers import TextToSpeechSerializer, TextToSpeechSerializerResponse, TextToSpeechSerializerResponseWithURL  # Ensure this import matches your file structure

# MinIO client created once at import time; the views in this module use it
# for uploads and for generating pre-signed download URLs.
minio_client = get_minio_client()

# Bucket that receives uploaded objects (both the 'speakers/' and 'output/' prefixes).
BUCKET_NAME = "voice-clone"

class TextToSpeechCreateView(CreateAPIView):
    """Synthesize speech for submitted text in the voice of an uploaded sample.

    The view validates the multipart payload with ``TextToSpeechSerializer``,
    runs the XTTS v2 model, uploads the speaker sample and the generated audio
    to MinIO, stores a ``TextToSpeech`` row owned by the requesting user, and
    answers with the serialized row whose audio fields are replaced by
    pre-signed URLs.
    """

    serializer_class = TextToSpeechSerializer
    authentication_classes = [TokenAuthentication]
    permission_classes = [IsAuthenticated]
    parser_classes = [MultiPartParser]

    _log = logging.getLogger(__name__)

    # The request parameters are intentionally not listed by hand: drf_yasg
    # derives them from ``serializer_class``.  The previous hand-written
    # 'file' / 's3_key' parameters and 'doc_id' response described a different
    # (document upload) endpoint.
    @swagger_auto_schema(
        operation_id='Create a text-to-speech clip',
        operation_description=(
            'Synthesize speech for the submitted text using the uploaded speaker '
            'sample. Both audio files are stored in MinIO and returned as '
            'pre-signed URLs.'
        ),
        responses={
            # NOTE(review): "speaker_wav" and "output_wav" in the real payload
            # are pre-signed URLs that override the serializer's values.
            status.HTTP_201_CREATED: openapi.Response('Created', TextToSpeechSerializerResponse),
            status.HTTP_400_BAD_REQUEST: 'Validation errors reported by the serializer',
            status.HTTP_500_INTERNAL_SERVER_ERROR: openapi.Response(
                'Processing failed',
                schema=openapi.Schema(type=openapi.TYPE_OBJECT, properties={
                    'error': openapi.Schema(type=openapi.TYPE_STRING, description='Generic error message'),
                }),
            ),
        },
    )
    def create(self, request, *args, **kwargs):
        """Handle POST: validate, synthesize, upload, persist and respond.

        Returns:
            Response: 201 with the serialized ``TextToSpeech`` row (audio
            fields replaced by pre-signed URLs), 400 with the serializer
            errors when validation fails, or 500 with a generic error body
            when any processing step raises.
        """
        serializer = self.get_serializer(data=request.data)
        if not serializer.is_valid():
            return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

        # Defined before the try block so the finally clause can always read
        # it, even when a failure happens before any temp path is chosen.
        temp_files = []
        try:
            gpu_available = torch.cuda.is_available()
            text = serializer.validated_data.get("text")
            speaker_wav = serializer.validated_data.get("speaker_wav")
            language = serializer.validated_data.get("language")

            # Temporary file paths.  The upload name is client-supplied, so
            # only its final path component is used to keep the write in /tmp.
            speaker_file_path = os.path.join(
                '/tmp', f"{uuid.uuid4()}{os.path.basename(speaker_wav.name)}"
            )
            temp_files.append(speaker_file_path)
            output_filename = os.path.join('/tmp', f"{uuid.uuid4()}.wav")
            temp_files.append(output_filename)

            # Save speaker WAV file
            with open(speaker_file_path, "wb") as destination:
                for chunk in speaker_wav.chunks():
                    destination.write(chunk)

            # TTS processing
            # NOTE(review): the model is constructed on every request; caching
            # it would avoid the repeated load, but first confirm the TTS
            # object is safe to share between requests.
            tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=gpu_available)
            tts.tts_to_file(text=text, file_path=output_filename, speaker_wav=speaker_file_path, language=language)

            # Upload both files to MinIO
            public_url, speaker_wav_path = self.upload_file_to_minio(speaker_file_path, 'speakers/')
            public_url_output, output_wav_path = self.upload_file_to_minio(output_filename, 'output/')

            # Create DB entry (stores the "bucket/object" paths, not the URLs)
            tts_instance = TextToSpeech.objects.create(
                text=text,
                speaker_wav=speaker_wav_path,
                output_wav=output_wav_path,
                language=language,
                created_by=request.user
            )

            # Serialize the created instance, exposing pre-signed URLs instead
            # of the stored object paths.
            response_serializer = TextToSpeechSerializerResponse(tts_instance)
            response_data = {
                **response_serializer.data,
                "speaker_wav": public_url,
                "output_wav": public_url_output
            }

            return Response(response_data, status=status.HTTP_201_CREATED)
        except Exception:
            # Top-level boundary: keep the traceback in the log, return a
            # generic message to the client.
            self._log.exception("Text-to-speech request failed")
            return Response({"error": "An error occurred processing your request."}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
        finally:
            # Ensure cleanup happens on success and on failure alike
            self.cleanup_files(temp_files)

    def upload_file_to_minio(self, file_path, prefix):
        """Upload a local file to MinIO and return access information for it.

        Args:
            file_path: Path of the local file to upload; its base name becomes
                the object name.
            prefix: String prepended to the base name to form the object key
                (for example ``'speakers/'``).

        Returns:
            tuple: ``(pre_signed_url, "<BUCKET_NAME>/<object_name>")`` where
            the URL is requested with a one-day expiry.
        """
        file_name = os.path.basename(file_path)
        object_name = f"{prefix}{file_name}"
        with open(file_path, "rb") as file_data:
            minio_client.put_object(BUCKET_NAME, object_name, file_data, os.path.getsize(file_path))

        # Generate a pre-signed URL for the uploaded object
        pre_signed_url = minio_client.presigned_get_object(BUCKET_NAME, object_name, expires=timedelta(days=1))
        return pre_signed_url, f"{BUCKET_NAME}/{object_name}"

    def cleanup_files(self, file_paths):
        """Remove the given files from the filesystem, best effort.

        Empty entries and files that do not exist are skipped silently (a
        temp file may never have been written if processing failed early);
        any other removal error is logged and does not propagate.

        Args:
            file_paths: Iterable of file paths to delete.
        """
        for file_path in file_paths:
            if not file_path:
                continue
            try:
                os.remove(file_path)
            except FileNotFoundError:
                # Nothing to clean up: the file was never created.
                continue
            except (OSError, TypeError, ValueError):
                self._log.warning("Could not remove temporary file %r", file_path, exc_info=True)

class TextToSpeechListView(ListAPIView):
    """List the ``TextToSpeech`` records created by the authenticated user."""

    serializer_class = TextToSpeechSerializerResponseWithURL
    authentication_classes = [TokenAuthentication]
    permission_classes = [IsAuthenticated]

    def get_queryset(self):
        """Return only the rows whose ``created_by`` is the requesting user."""
        current_user = self.request.user
        return TextToSpeech.objects.filter(created_by=current_user)

    def list(self, request, *args, **kwargs):
        """Respond with the user's records, or raise a 404 when there are none."""
        records = self.get_queryset()

        if records.exists():
            # The serializer receives this view through its context.
            # NOTE(review): presumably so it can call generate_presigned_url;
            # confirm against the serializer.
            serializer = self.get_serializer(records, many=True, context={'view': self})
            payload = serializer.data
            return Response(payload, status=status.HTTP_200_OK)

        raise NOT_FOUND('No text-to-speech data found for the current user.')

    def generate_presigned_url(self, object_path):
        """Build a one-hour pre-signed URL for ``object_path``.

        ``object_path`` is expected in the form ``"bucket_name/object_name"``;
        it is split on the first ``/`` only.  Any failure (including a path
        without a ``/``) is printed and results in ``None``.
        """
        try:
            parts = object_path.split('/', 1)
            bucket, object_name = parts
            return minio_client.presigned_get_object(
                bucket,
                object_name,
                expires=timedelta(hours=1),
            )
        except Exception as exc:
            print(exc)
        return None