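"""Speaker-diarized transcription demo.

Runs the pyannote/speaker-diarization-3.1 pipeline to split an audio file into
speaker turns, transcribes each turn with FunAudioLLM/SenseVoiceSmall (Cantonese
by default), and replaces SenseVoice's rich-transcription tags (emotion and
audio-event tokens) with emojis. Requires an HF_TOKEN environment variable with
access to the pyannote pipeline.
"""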
import datetime
import math
import os
import numpy as np
import torch
import torchaudio
from funasr import AutoModel
from pyannote.audio import Audio, Pipeline
from pyannote.core import Segment
# Load models
model = AutoModel(
    model="FunAudioLLM/SenseVoiceSmall",
    # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    # vad_kwargs={"max_single_segment_time": 30000},
    hub="hf",
    device="cuda" if torch.cuda.is_available() else "cpu",
)
pyannote_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_TOKEN")
)
if torch.cuda.is_available():
    pyannote_pipeline.to(torch.device("cuda"))
# Emoji dictionaries and formatting functions
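# SenseVoiceSmall emits rich-transcription tags (language, emotion, audio-event, and ITN markers);
# the dictionaries below map them to emojis or drop them.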
emo_dict = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}
event_dict = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}
emoji_dict = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}
lang_dict = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}
emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {"🎼", "👏", "😀", "😭", "🤧", "😷"}
def clean_and_emoji_annotate_speech(text):
    """Strip SenseVoice rich-transcription tags from `text` and re-express them as emojis."""

    # Helper function to get the first emoji from a string that belongs to a given set
    def get_emoji(s, emoji_set):
        return next((char for char in s if char in emoji_set), None)

    # Helper function to format text with emojis based on special tokens
    def format_text_with_emojis(s):
        # Count occurrences of special tokens
        sptk_dict = {sptk: s.count(sptk) for sptk in emoji_dict}
        # Remove all special tokens from the text
        for sptk in emoji_dict:
            s = s.replace(sptk, "")
        # Determine the dominant emotion
        emo = "<|NEUTRAL|>"
        for e in emo_dict:
            if sptk_dict.get(e, 0) > sptk_dict.get(emo, 0):
                emo = e
        # Add event emojis at the beginning and emotion emoji at the end
        s = (
            "".join(event_dict[e] for e in event_dict if sptk_dict.get(e, 0) > 0)
            + s
            + emo_dict[emo]
        )
        # Remove spaces around emojis
        for emoji in emo_set.union(event_set):
            s = s.replace(f" {emoji}", emoji).replace(f"{emoji} ", emoji)
        return s.strip()

    # Replace special tags and language markers
    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang, replacement in lang_dict.items():
        text = text.replace(lang, replacement)
    # Process each language segment
    segments = [
        format_text_with_emojis(segment.strip()) for segment in text.split("<|lang|>")
    ]
    formatted_segments = []
    prev_event = prev_emotion = None
    # Combine segments, avoiding duplicate emojis
    for segment in segments:
        if not segment:
            continue
        current_event = get_emoji(segment, event_set)
        current_emotion = get_emoji(segment, emo_set)
        # Drop a leading event emoji (format_text_with_emojis prepends event emojis to each segment)
        if current_event is not None:
            segment = segment[1:] if segment.startswith(current_event) else segment
        # Move the emotion emoji to the end if it differs from the previous segment's
        if current_emotion is not None and current_emotion != prev_emotion:
            segment = segment.replace(current_emotion, "") + current_emotion
        formatted_segments.append(segment.strip())
        prev_event, prev_emotion = current_event, current_emotion
    # Join segments and remove unnecessary "The." at the end
    result = " ".join(formatted_segments).replace("The.", "").strip()
    return result
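
# Illustrative example for clean_and_emoji_annotate_speech (assumed tag layout; actual
# SenseVoice output may differ):
#   clean_and_emoji_annotate_speech("<|en|><|HAPPY|><|Speech|> hello there")
#   -> "hello there😊"   (language and event tags stripped, dominant emotion emoji appended)
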
def time_to_seconds(time_str):
    h, m, s = time_str.split(":")
    return round(int(h) * 3600 + int(m) * 60 + float(s), 9)


def parse_time(time_str):
    # Remove 's' if present at the end of the string
    time_str = time_str.rstrip("s")
    # Split the time string into hours, minutes, and seconds
    parts = time_str.split(":")
    if len(parts) == 3:
        h, m, s = parts
    elif len(parts) == 2:
        h = "0"
        m, s = parts
    else:
        h = m = "0"
        s = parts[0]
    return int(h) * 3600 + int(m) * 60 + float(s)


def format_time(seconds, use_short_format=True, always_use_seconds=False):
    if isinstance(seconds, datetime.timedelta):
        seconds = seconds.total_seconds()
    total_seconds = seconds  # keep the full value so seconds-only output is not truncated modulo 60
    minutes, seconds = divmod(seconds, 60)
    hours, minutes = divmod(int(minutes), 60)
    if always_use_seconds or (use_short_format and hours == 0 and minutes == 0):
        return f"{total_seconds:06.3f}s"
    elif use_short_format or hours == 0:
        return f"{minutes:02d}:{seconds:06.3f}"
    else:
        return f"{hours:02d}:{minutes:02d}:{seconds:06.3f}"
def generate_diarization(audio_path):
    """Run pyannote speaker diarization on an audio file and return a list of
    (start_seconds, end_seconds, duration_seconds, speaker_label) tuples."""
    # Get the Hugging Face token from the environment variable
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise ValueError(
            "HF_TOKEN environment variable is not set. Please set it with your Hugging Face token."
        )
    # Initialize the audio processor
    audio = Audio(sample_rate=16000, mono=True)
    # Load the pretrained pipeline
    pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1", use_auth_token=hf_token
    )
    # Send pipeline to GPU if available
    if torch.cuda.is_available():
        pipeline.to(torch.device("cuda"))
    # Set the correct path for the audio file
    script_dir = os.path.dirname(os.path.abspath(__file__))
    possible_paths = [
        os.path.join(script_dir, "example", "mtr.mp3"),
        os.path.join(script_dir, "..", "example", "mtr.mp3"),
        os.path.join(script_dir, "mtr.mp3"),
        "mtr.mp3",
        audio_path,  # Add the provided audio_path to the list of possible paths
    ]
    file_path = None
    for path in possible_paths:
        if os.path.exists(path):
            file_path = path
            break
    if file_path is None:
        print("Debugging information:")
        print(f"Current working directory: {os.getcwd()}")
        print(f"Script directory: {script_dir}")
        print("Attempted paths:")
        for path in possible_paths:
            print(f" {path}")
        raise FileNotFoundError(
            "Could not find the audio file. Please ensure it's in the correct location."
        )
    print(f"Using audio file: {file_path}")
    # Process the audio file
    waveform, sample_rate = audio(file_path)
    # Create a dictionary with the audio information
    file = {"waveform": waveform, "sample_rate": sample_rate, "uri": "mtr"}
    # Run the diarization
    output = pipeline(file)
    # Save results in human-readable format
    diarization_segments = []
    txt_file = "mtr_dn.txt"
    with open(txt_file, "w") as f:
        for turn, _, speaker in output.itertracks(yield_label=True):
            start_time = format_time(turn.start)
            end_time = format_time(turn.end)
            duration = format_time(turn.end - turn.start)
            line = f"{start_time} - {end_time} ({duration}): {speaker}\n"
            f.write(line)
            print(line.strip())
            diarization_segments.append(
                (
                    parse_time(start_time),
                    parse_time(end_time),
                    parse_time(duration),
                    speaker,
                )
            )
    print(f"\nHuman-readable diarization results saved to {txt_file}")
    return diarization_segments
def process_audio(audio_path, language="yue", fs=16000):
    """Diarize an audio file, transcribe each speaker turn with SenseVoice, and
    return the combined transcript as formatted text."""
    # Generate diarization segments
    diarization_segments = generate_diarization(audio_path)
    # Load and preprocess audio
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != fs:
        resampler = torchaudio.transforms.Resample(sample_rate, fs)
        waveform = resampler(waveform)
    input_wav = waveform.mean(0).numpy()
    # Determine if the audio is less than one minute
    total_duration = sum(duration for _, _, duration, _ in diarization_segments)
    use_short_format = total_duration < 60
    # Process the audio in chunks based on diarization segments
    results = []
    for start_time, end_time, duration, speaker in diarization_segments:
        start_seconds = start_time
        end_seconds = end_time
        # Convert time to sample indices
        start_sample = int(start_seconds * fs)
        end_sample = int(end_seconds * fs)
        chunk = input_wav[start_sample:end_sample]
        try:
            text = model.generate(
                input=chunk,
                cache={},
                language=language,
                use_itn=True,
                batch_size_s=500,
                merge_vad=True,
            )
            text = text[0]["text"]
            # Print the text before clean_and_emoji_annotate_speech
            print(f"Text before clean_and_emoji_annotate_speech: {text}")
            text = clean_and_emoji_annotate_speech(text)
            # Handle empty transcriptions
            if not text.strip():
                text = "[inaudible]"
            results.append((speaker, start_time, end_time, duration, text))
        except AssertionError as e:
            if "choose a window size" in str(e):
                print(
                    f"Warning: Audio segment too short to process. Skipping. Error: {e}"
                )
                results.append((speaker, start_time, end_time, duration, "[too short]"))
            else:
                raise
    # Format the results
    formatted_text = ""
    for speaker, start, end, duration, text in results:
        start_str = format_time(start, always_use_seconds=True)
        end_str = format_time(end, always_use_seconds=True)
        duration_str = format_time(duration, always_use_seconds=True)
        speaker_num = "1" if speaker == "SPEAKER_00" else "2"
        line = f"{start_str} - {end_str} ({duration_str}) Speaker {speaker_num}: {text}"
        formatted_text += line + "\n"
        print(f"Debug: Formatted line: {line}")
    print("Debug: Full formatted text:")
    print(formatted_text)
    return formatted_text.strip()
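
# Each transcript line has the form "<start>s - <end>s (<duration>s) Speaker <n>: <text>",
# with all times rendered in seconds (always_use_seconds=True above).
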
if __name__ == "__main__":
    audio_path = "example/mtr.mp3"  # Replace with your audio file path
    language = "yue"  # Set language to Cantonese
    # Option to run only diarization
    diarization_only = False  # Set this to True if you want only diarization
    if diarization_only:
        diarization_segments = generate_diarization(audio_path)
        # You can add code here to save or process the diarization results as needed
    else:
        result = process_audio(audio_path, language)
        # Save the result to mtr.txt
        output_path = "mtr.txt"
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(result)
        print(f"Diarization and transcription result has been saved to {output_path}")