import time

from scipy.io.wavfile import read, write

from fastapi import FastAPI, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse

import torch

from transformers import AutoProcessor, SeamlessM4Tv2Model

model_name = "facebook/seamless-m4t-v2-large"
# model_name = "facebook/hf-seamless-m4t-medium"

processor = AutoProcessor.from_pretrained(model_name)
model = SeamlessM4Tv2Model.from_pretrained(model_name)


# Run on the GPU when one is available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model.to(device)

app = FastAPI(docs_url="/api/docs")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

BATCH_SIZE = 8
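
# The service is usually launched with uvicorn. The command below is a sketch:
# it assumes this module is saved as main.py and that port 8000 is free.
#
#     uvicorn main:app --host 0.0.0.0 --port 8000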


@app.get("/device")
def get_device():
    # Report which device the model is running on (e.g. "cuda:0" or "cpu").
    return device


@app.get("/translate")
def translate(inputs: str, src_lang: str = "eng", tgt_lang: str = "por"):
    start_time = time.time()

    if not inputs:
        raise HTTPException(
            status_code=400,
            detail="No text submitted! Please provide text to translate.")

    # Tokenize the source text and run text-to-text translation only.
    text_inputs = processor(text=inputs,
                            src_lang=src_lang, return_tensors="pt").to(device)

    output_tokens = model.generate(
        **text_inputs, tgt_lang=tgt_lang, generate_speech=False)

    translated_text = processor.decode(
        output_tokens[0].tolist()[0], skip_special_tokens=True)

    print("Time taken to process the request and return the response: {:.2f} sec".format(
        time.time() - start_time))
    return translated_text
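
# Example request, a sketch that assumes the server is running locally on
# uvicorn's default port 8000 (adjust host/port to your deployment):
#
#     curl "http://localhost:8000/translate?inputs=Hello%20world&src_lang=eng&tgt_lang=por"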


@app.get("/audio")
def audio(inputs: str, src_lang: str = "eng", tgt_lang: str = "por", speaker_id: int = 5):
    # Plain `def` so FastAPI runs the blocking model call in a worker thread.
    start_time = time.time()

    if not inputs:
        raise HTTPException(
            status_code=400,
            detail="No text submitted! Please provide text to synthesize.")

    text_inputs = processor(text=inputs,
                            src_lang=src_lang, return_tensors="pt").to(device)

    # Text-to-speech translation: generate a waveform in the target language.
    audio_array_from_text = model.generate(
        **text_inputs, tgt_lang=tgt_lang, speaker_id=speaker_id)[0].cpu().numpy().squeeze()

    output_path = f"/tmp/output{start_time}.wav"
    write(output_path, model.config.sampling_rate, audio_array_from_text)

    print("Time taken to process the request and return the response: {:.2f} sec".format(
        time.time() - start_time))

    return FileResponse(output_path, media_type="audio/wav")
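
# Example request (same local-server assumption as above); the response body is
# the generated WAV file:
#
#     curl -o hello_por.wav "http://localhost:8000/audio?inputs=Hello%20world&src_lang=eng&tgt_lang=por&speaker_id=5"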


@app.post("/transcribe-audio")
def transcribe_audio(soundFile: UploadFile, tgt_lang: str = "eng"):
    # Plain `def` so FastAPI runs the blocking file I/O and model call in a worker thread.
    start_time = time.time()

    # Persist the upload so scipy can read it back as a waveform.
    upload_path = f"/tmp/{soundFile.filename}"
    with open(upload_path, "wb") as buffer:
        buffer.write(soundFile.file.read())

    sample_rate, audio_data = read(upload_path)

    # NOTE: SeamlessM4T expects 16 kHz mono audio; uploads recorded at another
    # rate should be resampled before being passed to the processor.
    audio_inputs = processor(
        audios=audio_data, return_tensors="pt").to(device)

    # Speech-to-speech translation: generate a waveform in the target language.
    audio_array_from_audio = model.generate(
        **audio_inputs, tgt_lang=tgt_lang)[0].cpu().numpy().squeeze()

    output_path = f"/tmp/output{start_time}.wav"
    write(output_path, model.config.sampling_rate, audio_array_from_audio)

    print("Time taken to process the request and return the response: {:.2f} sec".format(
        time.time() - start_time))

    return FileResponse(output_path, media_type="audio/wav")
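
# Example request (same local-server assumption as above); the recording is sent
# as multipart form data and the translated speech comes back as a WAV file:
#
#     curl -o translated.wav -F "soundFile=@recording.wav" "http://localhost:8000/transcribe-audio?tgt_lang=eng"


if __name__ == "__main__":
    # Convenience launcher: a minimal sketch that assumes the `uvicorn` package
    # is installed; running `uvicorn main:app` from the shell works just as well.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)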