Spaces:
Sleeping
Sleeping
import datetime | |
import math | |
import os | |
import numpy as np | |
import torch | |
import torchaudio | |
from funasr import AutoModel | |
from pyannote.audio import Audio, Pipeline | |
from pyannote.core import Segment | |
# Load models once at import time; the functions below reuse them.
# SenseVoiceSmall speech-recognition model from the Hugging Face hub.
# Optional VAD settings (currently disabled):
#     vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
#     vad_kwargs={"max_single_segment_time": 30000},
model = AutoModel(
    model="FunAudioLLM/SenseVoiceSmall",
    hub="hf",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# pyannote speaker-diarization pipeline, authenticated with the HF_TOKEN env var.
pyannote_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_TOKEN")
)
# NOTE(review): pyannote's from_pretrained can return None when the download fails
# (e.g. missing/unauthorized token) - verify against the installed version. Guard so
# the import does not die with an AttributeError; callers check for None instead.
if pyannote_pipeline is not None and torch.cuda.is_available():
    pyannote_pipeline.to(torch.device("cuda"))
# Emoji dictionaries and formatting functions
# Emotion tags in the SenseVoice output and the emoji each one is rendered as.
emo_dict = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}
# Audio-event tags and the emoji prepended to a segment when they occur.
# NOTE(review): Cough renders as 🤧 here (same as Sneeze) while emoji_dict lists 😷;
# only this dict's values are emitted.
event_dict = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}
# Every special token that may appear in the raw output. Only the keys are used
# (for counting and stripping tokens); the values document the intended rendering.
emoji_dict = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}
# Language markers; each is normalised to a common "<|lang|>" split point.
lang_dict = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}
# Single-code-point emoji used to recognise annotations in formatted text.
# The two sets must stay disjoint so an emoji is never both an event and an emotion.
emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷"}


def clean_and_emoji_annotate_speech(text):
    """Convert raw SenseVoice output into readable text annotated with emoji.

    The raw text is split at language markers (``<|zh|>``, ``<|en|>``, ...).
    Each segment is processed as follows:

    * all special tokens are removed;
    * emoji for any detected audio events are prepended;
    * the emoji of the most frequent emotion tag is appended.

    When consecutive segments repeat the same leading event or trailing emotion,
    the duplicate is dropped. Each run is therefore annotated once: event at its
    start, emotion at its end.

    Args:
        text: Raw model output containing special tokens.

    Returns:
        The cleaned, emoji-annotated string. ``<|nospeech|><|Event_UNK|>`` becomes ``❓``.
    """

    def leading_event(s):
        # format_text_with_emojis prepends event emoji, so only the first char matters.
        return s[0] if s and s[0] in event_set else None

    def trailing_emotion(s):
        # format_text_with_emojis appends the emotion emoji as the last char.
        return s[-1] if s and s[-1] in emo_set else None

    def format_text_with_emojis(s):
        """Strip special tokens from one segment and add event/emotion emoji."""
        # Count tokens before removing any of them.
        token_counts = {sptk: s.count(sptk) for sptk in emoji_dict}
        for sptk in emoji_dict:
            s = s.replace(sptk, "")
        # Dominant emotion: strictly most frequent tag; ties keep the earlier one.
        emo = "<|NEUTRAL|>"
        for e in emo_dict:
            if token_counts.get(e, 0) > token_counts.get(emo, 0):
                emo = e
        events = "".join(
            event_dict[e] for e in event_dict if token_counts.get(e, 0) > 0
        )
        s = events + s + emo_dict[emo]
        # Remove spaces around emojis
        for emoji in emo_set | event_set:
            s = s.replace(f" {emoji}", emoji).replace(f"{emoji} ", emoji)
        return s.strip()

    # Replace special tags and language markers
    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang, replacement in lang_dict.items():
        text = text.replace(lang, replacement)

    segments = [
        format_text_with_emojis(segment.strip()) for segment in text.split("<|lang|>")
    ]

    # Combine segments, dropping emoji that merely repeat the previous segment's.
    formatted_segments = []
    prev_event = prev_emotion = None
    for segment in segments:
        if not segment:
            continue
        current_event = leading_event(segment)
        current_emotion = trailing_emotion(segment)
        # Same event as the previous segment: it is already shown, drop this copy.
        if current_event is not None and current_event == prev_event:
            segment = segment[len(current_event):]
        # Same emotion as the previous segment: move it to the end of the run by
        # removing it from the previous segment (this segment keeps its own).
        if (
            current_emotion is not None
            and current_emotion == prev_emotion
            and formatted_segments
            and formatted_segments[-1].endswith(current_emotion)
        ):
            formatted_segments[-1] = formatted_segments[-1][: -len(current_emotion)].rstrip()
        formatted_segments.append(segment.strip())
        prev_event, prev_emotion = current_event, current_emotion

    # Join non-empty segments and remove every occurrence of "The."
    # (presumably a spurious model artifact - TODO confirm).
    result = " ".join(s for s in formatted_segments if s).replace("The.", "").strip()
    return result
def time_to_seconds(time_str):
    """Convert a strict ``HH:MM:SS[.fff]`` string to seconds.

    Exactly three colon-separated fields are required (otherwise unpacking
    raises ``ValueError``). The result is rounded to 9 decimal places to
    suppress float noise.
    """
    hh, mm, ss = time_str.split(":")
    whole_part = 3600 * int(hh) + 60 * int(mm)
    return round(whole_part + float(ss), 9)
def parse_time(time_str):
    """Parse ``SS.sss[s]``, ``MM:SS.sss`` or ``HH:MM:SS.sss`` into seconds.

    Trailing ``s`` characters are ignored. Missing hour/minute fields count as
    zero. If the string has more than three fields, only the first one is used,
    as seconds.
    """
    pieces = time_str.rstrip("s").split(":")
    if len(pieces) not in (2, 3):
        pieces = pieces[:1]
    # Left-pad with zero hours/minutes so there are always three fields.
    hours, minutes, secs = ["0"] * (3 - len(pieces)) + pieces
    return int(hours) * 3600 + int(minutes) * 60 + float(secs)
def format_time(seconds, use_short_format=True, always_use_seconds=False):
    """Render a time offset as a zero-padded, millisecond-precision string.

    Args:
        seconds: Offset in seconds (int/float) or a ``datetime.timedelta``.
        use_short_format: When true, offsets under one minute render as
            ``"SS.sss"`` followed by ``s`` (e.g. ``"05.500s"``).
        always_use_seconds: When true, render the *total* number of seconds
            with an ``s`` suffix (e.g. 75 -> ``"75.000s"``), regardless of length.

    Returns:
        One of ``"SS.sss"`` + ``s``, ``"MM:SS.sss"`` or ``"HH:MM:SS.sss"``.
        The hours field is always shown once the offset reaches one hour.
    """
    if isinstance(seconds, datetime.timedelta):
        seconds = seconds.total_seconds()
    # Round to milliseconds first so e.g. 59.9996 renders as "01:00.000",
    # not with a seconds field of "60.000".
    total = round(seconds, 3)
    if always_use_seconds:
        # Use the full total; formatting the value left after divmod would wrap at 60.
        return f"{total:06.3f}s"
    minutes, secs = divmod(total, 60)
    hours, minutes = divmod(int(minutes), 60)
    if use_short_format and hours == 0 and minutes == 0:
        return f"{secs:06.3f}s"
    if hours == 0:
        return f"{minutes:02d}:{secs:06.3f}"
    return f"{hours:02d}:{minutes:02d}:{secs:06.3f}"
def _find_audio_file(audio_path):
    """Return the first existing audio path, trying the caller's *audio_path* first.

    The hard-coded ``mtr.mp3`` demo locations are only fallbacks. If nothing
    exists, prints the attempted paths and raises ``FileNotFoundError``.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    possible_paths = [
        os.path.join(script_dir, "example", "mtr.mp3"),
        os.path.join(script_dir, "..", "example", "mtr.mp3"),
        os.path.join(script_dir, "mtr.mp3"),
        "mtr.mp3",
    ]
    # The explicit argument goes first so a demo file can never shadow the
    # recording the caller asked for.
    if audio_path:
        possible_paths.insert(0, audio_path)
    for path in possible_paths:
        if os.path.exists(path):
            return path
    print("Debugging information:")
    print(f"Current working directory: {os.getcwd()}")
    print(f"Script directory: {script_dir}")
    print("Attempted paths:")
    for path in possible_paths:
        print(f" {path}")
    raise FileNotFoundError(
        "Could not find the audio file. Please ensure it's in the correct location."
    )


def generate_diarization(audio_path, txt_file="mtr_dn.txt"):
    """Run pyannote speaker diarization on an audio file.

    Args:
        audio_path: Path to the recording. Used when it exists; otherwise a few
            demo ``mtr.mp3`` locations are tried.
        txt_file: Where the human-readable segment list is written.

    Returns:
        A list of ``(start_seconds, end_seconds, duration_seconds, speaker_label)``
        tuples, one per diarization turn, with times as floats.

    Raises:
        ValueError: ``HF_TOKEN`` is not set.
        RuntimeError: the diarization pipeline could not be loaded.
        FileNotFoundError: no candidate audio path exists.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise ValueError(
            "HF_TOKEN environment variable is not set. Please set it with your Hugging Face token."
        )

    # pyannote 3.x expects mono="downmix" or "random"; a bool does not downmix.
    audio = Audio(sample_rate=16000, mono="downmix")

    # Reuse the pipeline loaded at import time instead of reloading it per call.
    pipeline = pyannote_pipeline
    if pipeline is None:
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1", use_auth_token=hf_token
        )
        if pipeline is None:
            raise RuntimeError(
                "Could not load 'pyannote/speaker-diarization-3.1'; "
                "check that HF_TOKEN grants access to the model."
            )
        if torch.cuda.is_available():
            pipeline.to(torch.device("cuda"))

    file_path = _find_audio_file(audio_path)
    print(f"Using audio file: {file_path}")

    waveform, sample_rate = audio(file_path)
    uri = os.path.splitext(os.path.basename(file_path))[0]
    output = pipeline({"waveform": waveform, "sample_rate": sample_rate, "uri": uri})

    diarization_segments = []
    with open(txt_file, "w", encoding="utf-8") as f:
        for turn, _, speaker in output.itertracks(yield_label=True):
            duration = turn.end - turn.start
            line = (
                f"{format_time(turn.start)} - {format_time(turn.end)} "
                f"({format_time(duration)}): {speaker}\n"
            )
            f.write(line)
            print(line.strip())
            # Keep exact float times: re-parsing the formatted strings lost precision.
            diarization_segments.append((turn.start, turn.end, duration, speaker))
    print(f"\nHuman-readable diarization results saved to {txt_file}")
    return diarization_segments
def _speaker_number(label):
    """Map a pyannote label like ``SPEAKER_02`` to a 1-based display number (``"3"``).

    Labels without a numeric ``_NN`` suffix are returned unchanged.
    """
    _, _, suffix = label.rpartition("_")
    return str(int(suffix) + 1) if suffix.isdigit() else label


def process_audio(audio_path, language="yue", fs=16000):
    """Diarize and transcribe an audio file, one transcript line per speaker turn.

    Args:
        audio_path: Path to the recording.
        language: Language hint passed to the SenseVoice model.
        fs: Sample rate the audio is resampled to before slicing.

    Returns:
        Newline-separated lines of the form
        ``"<start> - <end> (<duration>) Speaker N: <text>"``.
        Times are formatted with ``format_time(..., always_use_seconds=True)``.
        Untranscribable turns read ``[inaudible]`` or ``[too short]``.
    """
    diarization_segments = generate_diarization(audio_path)

    # Load, resample to fs, and average the channels into a 1-D numpy array.
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != fs:
        waveform = torchaudio.transforms.Resample(sample_rate, fs)(waveform)
    input_wav = waveform.mean(0).numpy()

    results = []
    for start_time, end_time, duration, speaker in diarization_segments:
        chunk = input_wav[int(start_time * fs):int(end_time * fs)]
        try:
            output = model.generate(
                input=chunk,
                cache={},
                language=language,
                use_itn=True,
                batch_size_s=500,
                merge_vad=True,
            )
        except AssertionError as e:
            # Very short chunks fail a window-size assertion; record and move on.
            if "choose a window size" not in str(e):
                raise
            print(f"Warning: Audio segment too short to process. Skipping. Error: {e}")
            results.append((speaker, start_time, end_time, duration, "[too short]"))
            continue
        text = output[0]["text"] if output else ""
        print(f"Text before clean_and_emoji_annotate_speech: {text}")
        text = clean_and_emoji_annotate_speech(text)
        if not text.strip():
            text = "[inaudible]"
        results.append((speaker, start_time, end_time, duration, text))

    lines = []
    for speaker, start, end, duration, text in results:
        start_str = format_time(start, always_use_seconds=True)
        end_str = format_time(end, always_use_seconds=True)
        duration_str = format_time(duration, always_use_seconds=True)
        # Number every speaker; previously all non-SPEAKER_00 labels became "2".
        line = (
            f"{start_str} - {end_str} ({duration_str}) "
            f"Speaker {_speaker_number(speaker)}: {text}"
        )
        lines.append(line)
        print(f"Debug: Formatted line: {line}")
    formatted_text = "".join(f"{line}\n" for line in lines)
    print("Debug: Full formatted text:")
    print(formatted_text)
    return formatted_text.strip()
if __name__ == "__main__":
    # Demo configuration - point audio_path at your own recording.
    audio_path = "example/mtr.mp3"  # Replace with your audio file path
    language = "yue"  # Cantonese
    diarization_only = False  # True: run diarization only, skip transcription

    if not diarization_only:
        result = process_audio(audio_path, language)
        output_path = "mtr.txt"
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(result)
        print(f"Diarization and transcription result has been saved to {output_path}")
    else:
        # Segments are returned for further use; generate_diarization also writes its own text file.
        diarization_segments = generate_diarization(audio_path)