import os
import re
import functools

import requests
import pandas as pd
import plotly.express as px
import torch
import gradio as gr
from transformers import pipeline, WhisperProcessor
from pyannote.audio import Pipeline
from librosa import load, resample
from rpunct import RestorePuncts

from utils import split_into_sentences

os.environ["TOKENIZERS_PARALLELISM"] = "false"

device = 0 if torch.cuda.is_available() else -1
# summarization is done over the Inference API
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
summarization_url = (
    "https://api-inference.huggingface.co/models/knkarthick/MEETING_SUMMARY"
)
# A downstream step raised an error when non-English text was detected, so this
# regular expression strips any stray character outside basic English text.
# It may be completely unnecessary.
eng_pattern = r"[^\d\s\w'\.\,\?]"
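# Illustrative example (not from the original code): the pattern keeps digits,
# whitespace, word characters, apostrophes, periods, commas, and question marks, so
# re.sub(eng_pattern, "", "great, thanks! :-)") -> "great, thanks "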
def summarize(diarized, check):
    """
    diarized: a list of tuples. Each tuple has a string to be displayed and a label for highlighting.
    The start/end times are not highlighted [(speaker text, speaker id), (start time/end time, None)]

    check is a list of speaker ids whose speech will get summarized
    """
    if len(check) == 0:
        return ""

    text = ""
    for d in diarized:
        # if both speakers are selected, keep the speaker labels in the text
        if len(check) == 2 and d[1] is not None:
            text += f"\n{d[1]}: {d[0]}"
        # otherwise keep only the speech of the selected speaker
        elif d[1] in check:
            text += f"\n{d[0]}"

    # inner function cached because the outer function cannot be cached
    # (its list arguments are not hashable)
    @functools.lru_cache()
    def call_summarize_api(text):
        payload = {
            "inputs": text,
            "options": {
                "use_gpu": False,
                "wait_for_model": True,
            },
        }
        response = requests.post(summarization_url, headers=headers, json=payload)
        return response.json()[0]["summary_text"]

    return call_summarize_api(text)
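# The Inference API is expected to return JSON shaped like [{"summary_text": "..."}]
# on success, which is why call_summarize_api indexes response.json()[0]["summary_text"];
# while the model is still loading or on an auth error, the body is a dict with an
# "error" key instead, so that lookup would raise.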
# Audio components
asr_model = "openai/whisper-large"
processor = WhisperProcessor.from_pretrained(asr_model)
asr = pipeline(
    "automatic-speech-recognition",
    model=asr_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    chunk_length_s=30,  # chunk long calls so audio beyond 30 s is transcribed too
    device=device,
)
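# With return_timestamps="word", the pipeline output consumed below is expected
# to look roughly like (values illustrative):
# {"text": " hi thanks for calling",
#  "chunks": [{"text": " hi", "timestamp": (0.02, 0.30)},
#             {"text": " thanks", "timestamp": (0.30, 0.62)}, ...]}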
speaker_segmentation = Pipeline.from_pretrained("pyannote/speaker-segmentation")
rpunct = RestorePuncts()
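# rpunct restores punctuation and casing on the lower-cased ASR text, e.g.
# rpunct.punctuate("hello how are you", lang="en") is expected to return
# something like "Hello, how are you?" (output illustrative).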
# Text components
emotion_pipeline = pipeline(
    "text-classification",
    model="bhadresh-savani/distilbert-base-uncased-emotion",
    device=device,
)
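# For a list of sentences, the text-classification pipeline returns one dict per
# input, e.g. [{"label": "joy", "score": 0.998}, {"label": "sadness", "score": 0.91}]
# (scores illustrative); sentiment() compares each score against `thresholds`.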
EXAMPLES = [["example_audio.wav"], ["Customer_Support_Call.wav"]]

# highlight a sentence only if its emotion score is above these per-label thresholds
thresholds = {
    "joy": 0.99,
    "anger": 0.95,
    "surprise": 0.95,
    "sadness": 0.98,
    "fear": 0.95,
    "love": 0.99,
}
def speech_to_text(speech):
    speaker_output = speaker_segmentation(speech)
    speech, sampling_rate = load(speech)
    if sampling_rate != 16000:
        # the ASR model expects 16 kHz audio
        speech = resample(speech, orig_sr=sampling_rate, target_sr=16000)
    text = asr(speech, return_timestamps="word")
    chunks = text["chunks"]
    diarized_output = []
    i = 0
    speaker_counter = 0

    # New iteration every time the speaker changes
    for turn, _, _ in speaker_output.itertracks(yield_label=True):
        speaker = "Customer" if speaker_counter % 2 == 0 else "Support"
        diarized = ""
        # collect every word whose end timestamp falls within this speaker turn
        while i < len(chunks) and chunks[i]["timestamp"][1] <= turn.end:
            diarized += chunks[i]["text"].lower() + " "
            i += 1

        if diarized != "":
            diarized = rpunct.punctuate(re.sub(eng_pattern, "", diarized), lang="en")
            diarized_output.extend(
                [
                    (diarized, speaker),
                    ("from {:.2f}-{:.2f}".format(turn.start, turn.end), None),
                ]
            )

        speaker_counter += 1

    return diarized_output
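# Example of the structure speech_to_text returns and that summarize() and
# sentiment() consume (text and times illustrative):
# [("hi, i'm having trouble with my order.", "Customer"),
#  ("from 0.00-3.20", None),
#  ("sorry to hear that, let me take a look.", "Support"),
#  ("from 3.20-6.80", None)]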
def sentiment(diarized):
    """
    diarized: a list of tuples. Each tuple has a string to be displayed and a label for highlighting.
    The start/end times are not highlighted [(speaker text, speaker id), (start time/end time, None)]

    This function gets the customer's sentiment and returns a list for highlighted text as well
    as a plot of sentiment over time.
    """
    customer_sentiments = []

    to_plot = []
    plot_sentences = []

    # used to set the x range of ticks on the plot
    x_min = 100
    x_max = 0

    for i in range(0, len(diarized), 2):
        speaker_speech, speaker_id = diarized[i]
        times, _ = diarized[i + 1]

        sentences = split_into_sentences(speaker_speech)
        start_time, end_time = times[5:].split("-")
        start_time, end_time = float(start_time), float(end_time)
        interval_size = (end_time - start_time) / len(sentences)

        if "Customer" in speaker_id:
            outputs = emotion_pipeline(sentences)

            for idx, (o, t) in enumerate(zip(outputs, sentences)):
                sent = "neutral"
                if o["score"] > thresholds[o["label"]]:
                    customer_sentiments.append(
                        (t + f"({round(idx*interval_size+start_time,1)} s)", o["label"])
                    )
                    if o["label"] in {"joy", "love", "surprise"}:
                        sent = "positive"
                    elif o["label"] in {"sadness", "anger", "fear"}:
                        sent = "negative"
                if sent != "neutral":
                    to_plot.append((start_time + idx * interval_size, sent))
                    plot_sentences.append(t)

        if start_time < x_min:
            x_min = start_time
        if end_time > x_max:
            x_max = end_time

    x_min -= 5
    x_max += 5

    x, y = list(zip(*to_plot))
    plot_df = pd.DataFrame(
        data={
            "x": x,
            "y": y,
            "sentence": plot_sentences,
        }
    )

    fig = px.line(
        plot_df,
        x="x",
        y="y",
        hover_data={
            "sentence": True,
            "x": True,
            "y": False,
        },
        labels={"x": "time (seconds)", "y": "sentiment"},
        title="Customer sentiment over time",
    )
    fig = fig.update_yaxes(categoryorder="category ascending")
    fig = fig.update_layout(
        font=dict(
            size=18,
        ),
        xaxis_range=[x_min, x_max],
    )

    return customer_sentiments, fig
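# customer_sentiments is a list of (sentence + "(time s)", emotion label) pairs,
# e.g. [("that is great news.(12.3 s)", "joy")] (illustrative); gr.HighlightedText
# colors each pair with the color_map defined below.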
demo = gr.Blocks(enable_queue=True)
demo.encrypt = False

# for highlighting purposes
color_map = {
    "joy": "green",
    "anger": "red",
    "surprise": "yellow",
    "sadness": "blue",
    "fear": "orange",
    "love": "purple",
}
with demo:
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Audio file", type="filepath")
            with gr.Row():
                btn = gr.Button("Transcribe")
            with gr.Row():
                examples = gr.components.Dataset(
                    components=[audio], samples=EXAMPLES, type="index"
                )
        with gr.Column():
            gr.Markdown("**Call Transcript:**")
            diarized = gr.HighlightedText(label="Call Transcript")
            gr.Markdown("Choose speaker to summarize:")
            check = gr.CheckboxGroup(
                choices=["Customer", "Support"], show_label=False, type="value"
            )
            summary = gr.Textbox(lines=4)
            sentiment_btn = gr.Button("Get Customer Sentiment")
            analyzed = gr.HighlightedText(color_map=color_map)
            plot = gr.Plot(label="Sentiment over time", type="plotly")

    # when the transcribe button is clicked, convert the audio file to text and diarize
    btn.click(
        speech_to_text,
        audio,
        [diarized],
        status_tracker=gr.StatusTracker(cover_container=True),
    )

    # when the summarize checkboxes change, create a summary
    check.change(summarize, [diarized, check], summary)

    # when the sentiment button is clicked, display highlighted text and the plot
    sentiment_btn.click(sentiment, [diarized], [analyzed, plot])

    def cache_example(example):
        processed_examples = audio.preprocess_example(example)
        diarized_output = speech_to_text(example)
        return processed_examples, diarized_output

    cache = [cache_example(e[0]) for e in EXAMPLES]

    def load_example(example_id):
        return cache[example_id]

    # when an example is clicked, load its cached audio and transcript
    examples._click_no_postprocess(
        load_example, inputs=[examples], outputs=[audio, diarized], queue=False
    )

demo.launch(debug=1)