# deploy-s2s-api / main.py
from fastapi import FastAPI, Response
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from s2smodels import Base, Audio_segment, AudioGeneration
from pydub import AudioSegment
from io import BytesIO
import os
import torch
import soundfile as sf
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from pyannote.audio import Pipeline
from utils.prompt_making import make_prompt
from utils.generation import SAMPLE_RATE, generate_audio, preload_models
DATABASE_URL = "sqlite:///./sql_app.db"
engine = create_engine(DATABASE_URL)
Session = sessionmaker(bind=engine)
app = FastAPI()
"""
origins = ["*"]
app.add_middleware(
CORS,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
"""
Base.metadata.create_all(engine)
@app.get("/")
def root():
return {"message": "No result"}
# Add audio segments to the Audio_segment table
def create_segment(start_time: float, end_time: float, audio: AudioSegment, type: str):
session = Session()
audio_bytes = BytesIO()
audio.export(audio_bytes, format='wav')
audio_bytes = audio_bytes.getvalue()
segment = Audio_segment(start_time=start_time, end_time=end_time, type=type, audio=audio_bytes)
session.add(segment)
session.commit()
session.close()
return {"status_code": 200, "message": "success"}
# Add the target audio to the AudioGeneration table
def generate_target(audio: AudioSegment):
session = Session()
audio_bytes = BytesIO()
audio.export(audio_bytes, format='wav')
audio_bytes = audio_bytes.getvalue()
target_audio = AudioGeneration(audio=audio_bytes)
session.add(target_audio)
session.commit()
session.close()
return {"status_code": 200, "message": "success"}
"""
audio segmentation into speech and non-speech using segmentation model
"""
def audio_speech_nonspeech_detection(audio_url):
    # Note: pyannote/speaker-diarization-3.0 is a gated model and typically requires a
    # Hugging Face access token (use_auth_token=...) to download.
    diarization_pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.0"
    )
    diarization = diarization_pipeline(audio_url)
    speaker_regions = []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        speaker_regions.append({"start": turn.start, "end": turn.end})
sound = AudioSegment.from_wav(audio_url)
speaker_regions.sort(key=lambda x: x['start'])
non_speech_regions = []
for i in range(1, len(speaker_regions)):
start = speaker_regions[i-1]['end']
end = speaker_regions[i]['start']
if end > start:
non_speech_regions.append({'start': start, 'end': end})
first_speech_start = speaker_regions[0]['start']
if first_speech_start > 0:
non_speech_regions.insert(0, {'start': 0, 'end': first_speech_start})
    last_speech_end = speaker_regions[-1]['end']
    # pydub reports length in milliseconds; convert to seconds to match the region times
    total_audio_duration = len(sound) / 1000.0
    if last_speech_end < total_audio_duration:
        non_speech_regions.append({'start': last_speech_end, 'end': total_audio_duration})
return speaker_regions,non_speech_regions
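# A minimal usage sketch (hypothetical helper, not called by the API): it only illustrates
# the shape of the values returned by audio_speech_nonspeech_detection; "example.wav" is
# an assumed local test file, and all times are in seconds.
def _demo_segmentation(example_wav: str = "example.wav"):
    speech, non_speech = audio_speech_nonspeech_detection(example_wav)
    print("speech regions:", speech)          # e.g. [{"start": 0.5, "end": 4.1}, ...]
    print("non-speech regions:", non_speech)  # e.g. [{"start": 0.0, "end": 0.5}, ...]
    return speech, non_speech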
"""
save speech and non-speech segments in audio_segment table
"""
def split_audio_segments(audio_url):
sound = AudioSegment.from_wav(audio_url)
    speech_segments, non_speech_segments = audio_speech_nonspeech_detection(audio_url)
# Process speech segments
for i, speech_segment in enumerate(speech_segments):
start = int(speech_segment['start'] * 1000)
end = int(speech_segment['end'] * 1000)
segment = sound[start:end]
create_segment(start_time=start/1000,
end_time=end/1000,
type="speech",audio=segment)
# Process non-speech segments
    for i, non_speech_segment in enumerate(non_speech_segments):
start = int(non_speech_segment['start'] * 1000)
end = int(non_speech_segment['end'] * 1000)
segment = sound[start:end]
create_segment(start_time=start/1000,
end_time=end/1000,
type="non-speech",audio=segment)
#@app.post("/translate_en_ar/")
def en_text_to_ar_text_translation(text):
    # NLLB-200 uses FLORES-200 language codes: "eng_Latn" (English) and "arz_Arab" (Egyptian Arabic)
    pipe = pipeline("translation", model="facebook/nllb-200-distilled-600M")
    result = pipe(text, src_lang="eng_Latn", tgt_lang="arz_Arab")
    return result[0]['translation_text']
def make_prompt_audio(name,audio_path):
make_prompt(name=name, audio_prompt_path=audio_path)
# Whisper model for speech-to-text (English)
#@app.post("/en_speech_ar_text/")
def en_speech_to_en_text_process(segment):
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=30,
batch_size=16,
return_timestamps=True,
torch_dtype=torch_dtype,
device=device,
)
result = pipe(segment)
return result["text"]
# Text-to-speech using the VALL-E-X model
#@app.post("/text_to_speech/")
def text_to_speech(segment_id, target_text, audio_prompt):
preload_models()
session = Session()
segment = session.query(Audio_segment).get(segment_id)
make_prompt_audio(name=f"audio_{segment_id}",audio_path=audio_prompt)
audio_array = generate_audio(target_text,f"audio_{segment_id}")
temp_file = BytesIO()
sf.write(temp_file, audio_array, SAMPLE_RATE, format='wav')
temp_file.seek(0)
segment.audio = temp_file.getvalue()
session.commit()
session.close()
temp_file.close()
#os.remove(temp_file)
"""
reconstruct target audio using all updated segment
in audio_segment table and then remove all audio_Segment records
"""
def construct_audio():
session = Session()
# Should be ordered by start_time
segments = session.query(Audio_segment).order_by('start_time').all()
audio_files = []
for segment in segments:
audio_files.append(AudioSegment.from_file(BytesIO(segment.audio), format='wav'))
target_audio = sum(audio_files, AudioSegment.empty())
generate_target(audio=target_audio)
# Delete all records in Audio_segment table
session.query(Audio_segment).delete()
session.commit()
session.close()
"""
source => english speech
target => arabic speeech
"""
#@app.post("/en_speech_ar_speech/")
def speech_to_speech_translation_en_ar(audio_url):
session=Session()
target_text=None
split_audio_segments(audio_url)
#filtering by type
speech_segments = session.query(Audio_segment).filter(Audio_segment.type == "speech").all()
    for segment in speech_segments:
        audio_data = segment.audio
        text = en_speech_to_en_text_process(audio_data)
        if text:
            target_text = en_text_to_ar_text_translation(text)
        else:
            target_text = None
            print("speech_to_text_process did not return a result.")
        if target_text is None:
            print("Target text is None.")
        else:
            segment_id = segment.id
            segment_duration = segment.end_time - segment.start_time
            if segment_duration <= 15:
                # make_prompt expects a file path, so write the prompt segment to a temporary wav file
                prompt_path = f"prompt_{segment_id}.wav"
                with open(prompt_path, "wb") as prompt_file:
                    prompt_file.write(segment.audio)
                text_to_speech(segment_id, target_text, prompt_path)
                os.remove(prompt_path)
            else:
                audio_data = extract_15_seconds(segment.audio, segment.start_time, segment.end_time)
                text_to_speech(segment_id, target_text, audio_data)
                os.remove(audio_data)
    session.close()
    construct_audio()
    return JSONResponse(status_code=200, content={"status_code": "success"})
@app.get("/get_ar_audio/")
async def get_ar_audio(audio_url):
speech_to_speech_translation_en_ar(audio_url)
session = Session()
# Get target audio from AudioGeneration
target_audio = session.query(AudioGeneration).order_by(AudioGeneration.id.desc()).first()
# Remove target audio from database
#session.query(AudioGeneration).delete()
#session.commit()
#session.close()
if target_audio is None:
raise ValueError("No audio found in the database")
audio_bytes = target_audio.audio
return Response(content=audio_bytes, media_type="audio/wav")
# Speech-to-speech (Arabic to English) processing
#@app.post("/ar_speech_to_en_text/")
def ar_speech_to_ar_text_process(segment):
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=30,
batch_size=16,
return_timestamps=True,
torch_dtype=torch_dtype,
device=device,
)
result = pipe(segment,generate_kwargs={"language": "arabic"})
return result["text"]
#@app.post("/ar_translate/")
def ar_text_to_en_text_translation(text):
    # NLLB-200 uses FLORES-200 language codes: "arz_Arab" (Egyptian Arabic) and "eng_Latn" (English)
    pipe = pipeline("translation", model="facebook/nllb-200-distilled-600M")
    result = pipe(text, src_lang="arz_Arab", tgt_lang="eng_Latn")
    return result[0]['translation_text']
"""
source => arabic speech
target => english speeech
"""
def speech_to_speech_translation_ar_en(audio_url):
session=Session()
target_text=None
split_audio_segments(audio_url)
#filtering by type
speech_segments = session.query(Audio_segment).filter(Audio_segment.type == "speech").all()
    for segment in speech_segments:
        audio_data = segment.audio
        text = ar_speech_to_ar_text_process(audio_data)
        if text:
            target_text = ar_text_to_en_text_translation(text)
        else:
            target_text = None
            print("speech_to_text_process did not return a result.")
        if target_text is None:
            print("Target text is None.")
        else:
            segment_id = segment.id
            segment_duration = segment.end_time - segment.start_time
            if segment_duration <= 15:
                # make_prompt expects a file path, so write the prompt segment to a temporary wav file
                prompt_path = f"prompt_{segment_id}.wav"
                with open(prompt_path, "wb") as prompt_file:
                    prompt_file.write(segment.audio)
                text_to_speech(segment_id, target_text, prompt_path)
                os.remove(prompt_path)
            else:
                audio_data = extract_15_seconds(segment.audio, segment.start_time, segment.end_time)
                text_to_speech(segment_id, target_text, audio_data)
                os.remove(audio_data)
    session.close()
    construct_audio()
    return JSONResponse(status_code=200, content={"status_code": "success"})
@app.get("/get_en_audio/")
async def get_en_audio(audio_url):
speech_to_speech_translation_ar_en(audio_url)
session = Session()
# Get target audio from AudioGeneration
target_audio = session.query(AudioGeneration).order_by(AudioGeneration.id.desc()).first()
# Remove target audio from database
#session.query(AudioGeneration).delete()
#session.commit()
#session.close()
if target_audio is None:
raise ValueError("No audio found in the database")
audio_bytes = target_audio.audio
return Response(content=audio_bytes, media_type="audio/wav")
@app.get("/audio_segments/")
def get_all_audio_segments():
    session = Session()
    segments = session.query(Audio_segment).all()
    segment_dicts = []
    # Make sure the output directory exists before writing segment files
    os.makedirs("segments", exist_ok=True)
    for segment in segments:
        if segment.audio is None:
            raise ValueError("No audio found in the database")
        audio_bytes = segment.audio
        file_path = f"segments/segment{segment.id}_audio.wav"
        with open(file_path, "wb") as file:
            file.write(audio_bytes)
segment_dicts.append({
"id": segment.id,
"start_time": segment.start_time,
"end_time": segment.end_time,
"type": segment.type,
"audio_url":file_path
})
session.close()
return {"segments":segment_dicts}
def extract_15_seconds(audio_data, start_time, end_time):
audio_segment = AudioSegment.from_file(BytesIO(audio_data), format='wav')
start_ms = start_time * 1000
end_ms = min((start_time + 15) * 1000, end_time * 1000)
extracted_segment = audio_segment[start_ms:end_ms]
temp_wav_path = "temp.wav"
extracted_segment.export(temp_wav_path, format="wav")
return temp_wav_path
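# Local-run sketch, assuming uvicorn is installed (it is not imported elsewhere in this
# file); the host and port values below are illustrative defaults, not part of the original
# deployment configuration.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)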