# ASR_arena / app.py
# chinmayc3's picture
# updated ui and file saving logic
# 695d9ae
# raw
# history blame
# 27 kB
import streamlit as st
import os
import re
import tempfile
from audio_recorder_streamlit import audio_recorder
import numpy as np
import time
import requests
import io
import base64
import random
import librosa
import fsspec
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import boto3
# Two S3 handles built from the same AWS_ACCESS_KEY / AWS_SECRET_KEY environment
# variables: `fs` gives file-like access, `s3_client` is the boto3 client.
fs = fsspec.filesystem(
    "s3",
    key=os.environ.get("AWS_ACCESS_KEY"),
    secret=os.environ.get("AWS_SECRET_KEY"),
)
s3_client = boto3.client(
    "s3",
    aws_access_key_id=os.environ.get("AWS_ACCESS_KEY"),
    aws_secret_access_key=os.environ.get("AWS_SECRET_KEY"),
)

# S3 locations, each composed from the bucket name and a per-purpose key.
SAVE_PATH = "s3://{}/{}".format(os.environ.get("AWS_BUCKET_NAME"), os.environ.get("RESULTS_KEY"))
EMAIL_PATH = "s3://{}/{}".format(os.environ.get("AWS_BUCKET_NAME"), os.environ.get("EMAILS_KEY"))
TEMP_DIR = "s3://{}/{}".format(os.environ.get("AWS_BUCKET_NAME"), os.environ.get("AUDIOS_KEY"))

# Endpoint that send_task() posts its payloads to.
CREATE_TASK_URL = os.environ.get("CREATE_TASK_URL")
def write_email(email):
    """Append ``email`` as a new line to the object stored at ``EMAIL_PATH``.

    The current content is read in full when the object exists (otherwise it
    is treated as empty), the address plus a newline is added in memory, and
    the whole text is written back, UTF-8 encoded.
    """
    previous = ''
    if fs.exists(EMAIL_PATH):
        with fs.open(EMAIL_PATH, 'rb') as reader:
            previous = reader.read().decode('utf-8')

    updated = ''.join((previous, email, '\n'))
    with fs.open(EMAIL_PATH, 'wb') as writer:
        writer.write(updated.encode('utf-8'))
class ResultWriter:
    """Stores arena votes as rows of a CSV file located at ``save_path``.

    Every row holds the voter's email, the audio path and, per model, a
    ``_score`` flag, an ``_appearance`` flag and a ``_duration`` value.
    """

    def __init__(self, save_path):
        """Keep ``save_path`` and create a header-only CSV there if it is missing."""
        self.save_path = save_path
        self.headers = [
            'email',
            'path',
            'Ori Apex_score', 'Ori Apex XT_score', 'deepgram_score', 'Ori Swift_score', 'Ori Prime_score',
            'Ori Apex_appearance', 'Ori Apex XT_appearance', 'deepgram_appearance', 'Ori Swift_appearance', 'Ori Prime_appearance',
            'Ori Apex_duration', 'Ori Apex XT_duration', 'deepgram_duration', 'Ori Swift_duration', 'Ori Prime_duration',
            'azure_score', 'azure_appearance', 'azure_duration',
        ]
        if fs.exists(save_path):
            return
        print("CSV File not found in s3 bucket creating a new one",save_path)
        with fs.open(save_path, 'wb') as f:
            pd.DataFrame(columns=self.headers).to_csv(f, index=False)

    def write_result(self,user_email ,audio_path,option_1_duration_info,option_2_duration_info ,winner_model=None, loser_model=None, both_preferred=False, none_preferred=False):
        """Build one vote row and hand it to ``write_to_s3``.

        Both named models get their ``_appearance`` flag set. The ``_score``
        flag is set for both models when ``both_preferred`` is true, for
        nobody when ``none_preferred`` is true, and for ``winner_model`` only
        otherwise. Each ``*_duration_info`` argument is expected to be a list
        whose first item is a ``(column_name, seconds)`` pair; the value is
        stored as a float when the column name is one of ``self.headers``.
        """
        row = {'email': user_email, 'path': audio_path}
        # Zero-fill every per-model column in the same order as the CSV header.
        for suffix in ('score', 'appearance', 'duration'):
            for model in ('Ori Apex', 'Ori Apex XT', 'deepgram', 'Ori Swift', 'Ori Prime'):
                row[f'{model}_{suffix}'] = 0
        for suffix in ('score', 'appearance', 'duration'):
            row[f'azure_{suffix}'] = 0

        for model in (winner_model, loser_model):
            if model:
                row[f'{model}_appearance'] = 1

        if both_preferred:
            scoring_models = (winner_model, loser_model)
        elif not none_preferred:
            scoring_models = (winner_model,)
        else:
            scoring_models = ()
        for model in scoring_models:
            if model:
                row[f'{model}_score'] = 1

        for duration_info in (option_1_duration_info, option_2_duration_info):
            if duration_info and duration_info[0]:
                column, seconds = duration_info[0]
                if column in self.headers:
                    row[column] = float(seconds)

        self.write_to_s3(row)

    def write_to_s3(self,result):
        """Append ``result`` to the stored CSV by reading it, extending it and rewriting it."""
        with fs.open(self.save_path, 'rb') as f:
            stored = pd.read_csv(f)
        all_rows = stored.to_dict('records') + [result]
        # The frame is built before the file is reopened for writing.
        updated = pd.DataFrame(all_rows)
        with fs.open(self.save_path, 'wb') as f:
            updated.to_csv(f, index=False)
def decode_audio_array(base64_string):
    """Decode a base64 string holding ``np.save`` output back into an array."""
    raw = base64.b64decode(base64_string)
    return np.load(io.BytesIO(raw))
def send_task(payload):
    """POST ``payload`` to ``CREATE_TASK_URL`` and unpack the task-specific reply.

    Args:
        payload: dict with a ``"task"`` key; it is sent as the JSON body.

    Returns:
        * ``"transcribe_with_fastapi"``: the ``"text"`` field of the JSON reply.
        * ``"fetch_audio"``: a tuple ``(array, sample_rate, filepath)`` where
          ``array`` is the decoded ``"array"`` field.
        * The string ``"error please try again"`` when the request fails, the
          reply is not JSON, or an expected field is missing or malformed.
        * ``None`` for any other task name.
    """
    error_message = "error please try again"
    header = {
        "Authorization": f"Bearer {os.getenv('CREATE_TASK_API_KEY')}"
    }
    try:
        response = requests.post(CREATE_TASK_URL, json=payload, headers=header)
        response = response.json()
    except (requests.RequestException, ValueError):
        # Connection problems and non-JSON bodies are reported the same way.
        return error_message

    task = payload["task"]
    try:
        if task == "transcribe_with_fastapi":
            return response["text"]
        if task == "fetch_audio":
            array = decode_audio_array(response["array"])
            return array, response["sample_rate"], response["filepath"]
    except (KeyError, TypeError, ValueError):
        # The reply was JSON but lacked a field or carried an undecodable array.
        return error_message
    return None
def encode_audio_array(audio_array):
    """Serialise ``audio_array`` with ``np.save`` and return it as a base64 string."""
    stream = io.BytesIO()
    np.save(stream, audio_array)
    return base64.b64encode(stream.getvalue()).decode('utf-8')
def call_function(model_name):
    """Ask ``model_name`` to transcribe the audio held in the session state.

    For recorded audio the file at ``audio_path`` is loaded with librosa
    (22050 Hz, mono) and sent base64-encoded with ``audio_b64`` set to True;
    otherwise ``audio_path`` itself is sent with ``audio_b64`` set to False.
    Returns whatever ``send_task`` returns.
    """
    is_recording = st.session_state.current_audio_type == "recorded"
    if is_recording:
        samples, _ = librosa.load(st.session_state.audio_path, sr=22050, mono=True)
        audio_reference = encode_audio_array(samples)
    else:
        audio_reference = st.session_state.audio_path

    request = {
        "task": "transcribe_with_fastapi",
        "payload": {
            "file_path": audio_reference,
            "model_name": model_name,
            "audio_b64": is_recording,
        },
    }
    return send_task(request)
def transcribe_audio():
    """Transcribe the session audio with two randomly chosen, distinct models.

    The chosen names and each call's wall-clock time (rounded to 3 decimals)
    are stored in the session state. Returns ``(transcript1, transcript2)``.
    """
    candidates = ["Ori Apex", "Ori Apex XT", "deepgram", "Ori Swift", "Ori Prime","azure"]
    first_model, second_model = random.sample(candidates, 2)
    st.session_state.option_1_model_name = first_model
    st.session_state.option_2_model_name = second_model

    started = time.time()
    first_transcript = call_function(first_model)
    between = time.time()
    second_transcript = call_function(second_model)
    finished = time.time()

    st.session_state.option_2_response_time = round(finished - between, 3)
    st.session_state.option_1_response_time = round(between - started, 3)
    return first_transcript, second_transcript
def reset_state():
    """Reset the audio, transcription and voting fields of the session state."""
    cleared_values = {
        "audio": None,
        "current_audio_type": None,
        "audio_path": "",
        "option_selected": False,
        "transcribed": False,
        "option_2_model_name": "",
        "option_1_model_name": "",
        "option_1": "",
        "option_2": "",
        "option_1_model_name_state": "",
        "option_2_model_name_state": "",
    }
    for field, value in cleared_values.items():
        setattr(st.session_state, field, value)
def on_option_1_click():
    """Record a vote for Option 1 and reveal both model names.

    Does nothing unless a transcription exists and no vote was cast yet.
    """
    state = st.session_state
    if not state.transcribed or state.option_selected:
        return

    first_model = state.option_1_model_name
    second_model = state.option_2_model_name
    state.option_1_model_name_state = f"πŸ‘‘ {first_model} πŸ‘‘"
    state.option_2_model_name_state = f"πŸ‘Ž {second_model} πŸ‘Ž"
    state.choice = f"You chose Option 1. Option 1 was {first_model} Option 2 was {second_model}"
    result_writer.write_result(
        state.user_email,
        state.audio_path,
        winner_model=first_model,
        loser_model=second_model,
        option_1_duration_info=[(f"{first_model}_duration", state.option_1_response_time)],
        option_2_duration_info=[(f"{second_model}_duration", state.option_2_response_time)],
    )
    state.option_selected = True
def on_option_2_click():
    """Record a vote for Option 2 and reveal both model names.

    Does nothing unless a transcription exists and no vote was cast yet.
    """
    state = st.session_state
    if not state.transcribed or state.option_selected:
        return

    first_model = state.option_1_model_name
    second_model = state.option_2_model_name
    state.option_2_model_name_state = f"πŸ‘‘ {second_model} πŸ‘‘"
    state.option_1_model_name_state = f"πŸ‘Ž {first_model} πŸ‘Ž"
    state.choice = f"You chose Option 2. Option 1 was {first_model} Option 2 was {second_model}"
    result_writer.write_result(
        state.user_email,
        state.audio_path,
        winner_model=second_model,
        loser_model=first_model,
        option_1_duration_info=[(f"{first_model}_duration", state.option_1_response_time)],
        option_2_duration_info=[(f"{second_model}_duration", state.option_2_response_time)],
    )
    state.option_selected = True
def on_option_both_click():
    """Record a "prefer both" vote and reveal both model names.

    Does nothing unless a transcription exists and no vote was cast yet.
    """
    state = st.session_state
    if not state.transcribed or state.option_selected:
        return

    first_model = state.option_1_model_name
    second_model = state.option_2_model_name
    state.option_2_model_name_state = f"πŸ‘‘ {second_model} πŸ‘‘"
    state.option_1_model_name_state = f"πŸ‘‘ {first_model} πŸ‘‘"
    state.choice = f"You chose Prefer both. Option 1 was {first_model} Option 2 was {second_model}"
    result_writer.write_result(
        state.user_email,
        state.audio_path,
        winner_model=first_model,
        loser_model=second_model,
        option_1_duration_info=[(f"{first_model}_duration", state.option_1_response_time)],
        option_2_duration_info=[(f"{second_model}_duration", state.option_2_response_time)],
        both_preferred=True,
    )
    state.option_selected = True
def on_option_none_click():
    """Record a "prefer none" vote and reveal both model names.

    Does nothing unless a transcription exists and no vote was cast yet.
    """
    state = st.session_state
    if not state.transcribed or state.option_selected:
        return

    first_model = state.option_1_model_name
    second_model = state.option_2_model_name
    state.option_1_model_name_state = f"πŸ‘Ž {first_model} πŸ‘Ž"
    state.option_2_model_name_state = f"πŸ‘Ž {second_model} πŸ‘Ž"
    state.choice = f"You chose none option. Option 1 was {first_model} Option 2 was {second_model}"
    result_writer.write_result(
        state.user_email,
        state.audio_path,
        winner_model=first_model,
        loser_model=second_model,
        option_1_duration_info=[(f"{first_model}_duration", state.option_1_response_time)],
        option_2_duration_info=[(f"{second_model}_duration", state.option_2_response_time)],
        none_preferred=True,
    )
    state.option_selected = True
def on_click_transcribe():
    """Transcribe the current audio and store both transcripts in the session.

    Does nothing when no audio is loaded. The revealed model-name labels are
    cleared so the names stay hidden until a vote is cast.
    """
    state = st.session_state
    if not state.has_audio:
        return

    first_text, second_text = transcribe_audio()
    state.option_1 = first_text
    state.option_2 = second_text
    state.transcribed = True
    state.option_1_model_name_state = ""
    state.option_2_model_name_state = ""
def on_random_click():
    """Fetch a random audio sample from the task service into the session state.

    The session is reset first. ``send_task`` may return an error string
    instead of the ``(array, sample_rate, filepath)`` tuple; in that case the
    session is marked as having no audio and an error is shown, rather than
    failing while unpacking the reply.
    """
    reset_state()
    fetched = send_task({"task": "fetch_audio"})
    if not isinstance(fetched, tuple) or len(fetched) != 3:
        # reset_state() cleared `audio`; has_audio must follow or the player
        # would be rendered with no data on the next run.
        st.session_state.has_audio = False
        st.error("Could not fetch a random audio sample, please try again.")
        return

    array, sampling_rate, filepath = fetched
    st.session_state.audio = {"data":array,"sample_rate":sampling_rate,"format":"audio/wav"}
    st.session_state.has_audio = True
    st.session_state.current_audio_type = "random"
    st.session_state.audio_path = filepath
    st.session_state.option_selected = None
# Module-level writer used by the voting callbacks defined above; constructing it
# creates the results CSV at SAVE_PATH when that file does not exist yet.
result_writer = ResultWriter(SAVE_PATH)
def validate_email(email):
    """Return True when ``email`` looks like ``local@domain.tld``.

    The whole string must match: ``re.fullmatch`` is used because ``re.match``
    with a trailing ``$`` also accepts an address followed by a newline.
    """
    pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
    return re.fullmatch(pattern, email) is not None
def get_model_abbreviation(model_name):
    """Return the display label for ``model_name``.

    Only 'deepgram' and 'azure' have a label that differs from their name;
    every other name, known or not, is returned unchanged.
    """
    relabelled = {
        'deepgram': 'DG',
        'azure': 'Azure',
    }
    return relabelled.get(model_name, model_name)
def calculate_metrics(df):
    """Summarise the results frame per model.

    For each model the returned dict holds the summed appearances and wins,
    the win rate in percent, and the mean and standard deviation of the
    durations of rows where the model appeared. A model that never appeared
    gets 0 for the rate and both duration statistics.
    """
    summary = {}
    for name in ('Ori Apex', 'Ori Apex XT', 'deepgram', 'Ori Swift', 'Ori Prime', 'azure'):
        appearance_flags = df[f'{name}_appearance']
        shown = appearance_flags.sum()
        won = df[f'{name}_score'].sum()
        timings = df.loc[appearance_flags == 1, f'{name}_duration']

        if shown > 0:
            rate = (won / shown) * 100
            mean_time = timings.mean()
            time_spread = timings.std()
        else:
            rate = mean_time = time_spread = 0

        summary[name] = {
            'appearances': shown,
            'wins': won,
            'win_rate': rate,
            'avg_response_time': mean_time,
            'response_time_std': time_spread,
        }
    return summary
def create_win_rate_chart(metrics):
    """Build a bar chart of each model's win rate on a fixed 0-100 y-axis."""
    model_names = list(metrics.keys())
    rates = [metrics[name]['win_rate'] for name in model_names]

    bars = go.Bar(
        x=[get_model_abbreviation(name) for name in model_names],
        y=rates,
        text=[f'{rate:.1f}%' for rate in rates],
        textposition='auto',
        hovertext=model_names,
    )
    fig = go.Figure(data=[bars])
    fig.update_layout(
        title='Win Rate by Model',
        xaxis_title='Model',
        yaxis_title='Win Rate (%)',
        yaxis_range=[0, 100],
    )
    return fig
def create_appearance_chart(metrics):
    """Build a pie chart of how often each model appeared."""
    model_names = list(metrics.keys())
    counts = [metrics[name]['appearances'] for name in model_names]
    labels = [get_model_abbreviation(name) for name in model_names]
    return px.pie(
        values=counts,
        names=labels,
        title='Model Appearances Distribution',
        hover_data=[model_names],
    )
def create_head_to_head_matrix(df):
    """Build a heatmap of pairwise win rates between models.

    Cell (row, column) is the percentage of rows featuring both models in
    which the row model has its score flag set. The diagonal and pairs that
    never met stay at 0, and cells equal to 0 are left unlabelled.
    """
    models = ['Ori Apex', 'Ori Apex XT', 'deepgram', 'Ori Swift', 'Ori Prime', 'azure']
    size = len(models)
    matrix = np.zeros((size, size))

    for row_idx, contender in enumerate(models):
        contender_present = df[f'{contender}_appearance'] == 1
        for col_idx, opponent in enumerate(models):
            if row_idx == col_idx:
                continue
            shared = df[contender_present & (df[f'{opponent}_appearance'] == 1)]
            if len(shared) > 0:
                matrix[row_idx][col_idx] = (shared[f'{contender}_score'].sum() / len(shared)) * 100

    cell_labels = [[f'{val:.1f}%' if val > 0 else '' for val in row] for row in matrix]
    heatmap = go.Heatmap(
        z=matrix,
        x=[get_model_abbreviation(model) for model in models],
        y=[get_model_abbreviation(model) for model in models],
        text=cell_labels,
        texttemplate='%{text}',
        colorscale='RdYlBu',
        zmin=0,
        zmax=100,
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title='Head-to-Head Win Rates',
        xaxis_title='Opponent Model',
        yaxis_title='Model',
    )
    return fig
def create_metric_container(label, value, full_name=None):
    """Render a bold ``label`` above ``value`` as a heading.

    When ``full_name`` is given, a "Full name: ..." caption is added below.
    """
    with st.container():
        st.markdown(f"**{label}**")
        # The heading is the same with or without a caption.
        st.markdown(f"<h3 style='margin-top: 0;'>{value}</h3>", unsafe_allow_html=True)
        if full_name:
            st.caption(f"Full name: {full_name}")
def on_refresh_click():
    """Reload the results CSV from ``SAVE_PATH`` into ``st.session_state.df``."""
    with fs.open(SAVE_PATH, 'rb') as handle:
        reloaded = pd.read_csv(handle)
    st.session_state.df = reloaded
def dashboard():
    """Render the scoreboard from the results CSV.

    The CSV is loaded into the session state on first use and can be
    reloaded with the Refresh button. With no rows, only a "No Data to show"
    note is written. Otherwise the page shows model descriptions, headline
    metrics, a metrics table, three charts and the raw table without the
    email, path and duration columns.
    """
    st.title('Model Arena Scoreboard')
    if "df" not in st.session_state:
        with fs.open(SAVE_PATH, 'rb') as f:
            st.session_state.df = pd.read_csv(f)
    st.button("Refresh",on_click=on_refresh_click)

    results = st.session_state.df
    if len(results) == 0:
        st.write("No Data to show")
        return

    metrics = calculate_metrics(results)

    MODEL_DESCRIPTIONS = {
        "Ori Prime": "Foundational, large, and stable.",
        "Ori Swift": "Lighter and faster than Ori Prime.",
        "Ori Apex": "The top-performing model, fast and stable.",
        "Ori Apex XT": "Enhanced with more training, though slightly less stable than Ori Apex.",
        "DG" : "Deepgram Nova-2 API",
        "Azure" : "Azure Speech Services API"
    }
    st.header('Model Descriptions')
    cols = st.columns(2)
    # Cards alternate between the two columns.
    for idx, (model, description) in enumerate(MODEL_DESCRIPTIONS.items()):
        with cols[idx % 2]:
            st.markdown(f"""
<div style='padding: 1rem; border: 1px solid #e1e4e8; border-radius: 6px; margin-bottom: 1rem;'>
<h3 style='margin: 0; margin-bottom: 0.5rem;'>{model}</h3>
<p style='margin: 0; color: #6e7681;'>{description}</p>
</div>
""", unsafe_allow_html=True)

    st.header('Overall Performance')
    best_model = max(metrics, key=lambda name: metrics[name]['win_rate'])
    most_appearances = max(metrics, key=lambda name: metrics[name]['appearances'])
    col1, col2, col3 = st.columns(3)
    with col1:
        create_metric_container("Total Matches", len(results))
    with col2:
        create_metric_container(
            "Best Model",
            get_model_abbreviation(best_model),
            full_name=best_model
        )
    with col3:
        create_metric_container(
            "Most Used",
            get_model_abbreviation(most_appearances),
            full_name=most_appearances
        )

    metrics_df = pd.DataFrame.from_dict(metrics, orient='index')
    metrics_df['win_rate'] = metrics_df['win_rate'].round(2)
    metrics_df = metrics_df.drop(columns=["avg_response_time","response_time_std"])
    metrics_df.index = [get_model_abbreviation(model) for model in metrics_df.index]
    st.dataframe(metrics_df,use_container_width=True)

    st.header('Win Rates')
    st.plotly_chart(create_win_rate_chart(metrics), use_container_width=True)

    st.header('Appearance Distribution')
    st.plotly_chart(create_appearance_chart(metrics), use_container_width=True)

    st.header('Head-to-Head Analysis')
    st.plotly_chart(create_head_to_head_matrix(results), use_container_width=True)

    st.header('Full Dataframe')
    hidden_columns = ['path','Ori Apex_duration', 'Ori Apex XT_duration', 'deepgram_duration', 'Ori Swift_duration', 'Ori Prime_duration','azure_duration','email']
    st.dataframe(results.drop(columns=hidden_columns),use_container_width=True)
def about():
    """Render the About tab: project description and the contact e-mail form.

    A valid submitted address is stored in the session, appended to the
    e-mail list via ``write_email`` and marks the session as logged in.
    """
    st.title("About")
    st.markdown(
        """
# Ori Speech-To-Text Arena
"""
    )
    st.markdown(
        """## Arena
"""
    )
    st.markdown(
        """
* The Arena allows a user to record their audios, in which speech will be recognized by two randomly selected models. After listening to the audio, and evaluating the output from both the models, the user can vote on which transcription they prefer. Due to the risks of human bias and abuse, model names are revealed only after a vote is submitted."""
    )
    st.markdown("## Scoreboard")
    st.markdown(
        """ * The Scoreboard shows the performance of the models in the Arena. The user can see the overall performance of the models, the model with the highest win rate, and the model with the most appearances. The user can also see the win rates of each model, as well as the appearance distribution of each model."""
    )
    st.markdown("## Contact Us")
    st.markdown(
        "To inquire about our speech-to-text models and APIs, you can submit your email using the form below."
    )

    with st.form("login_form"):
        st.subheader("Please Enter you Email")
        email = st.text_input("Email")
        submitted = st.form_submit_button("Submit")

        if submitted:
            if not email:
                st.error("Please fill in all fields")
            elif not validate_email(email):
                st.error("Please enter a valid email address")
            else:
                st.session_state.logged_in = True
                st.session_state.user_email = email
                write_email(st.session_state.user_email)
                st.success("Thanks for submitting your email, our team will be in touch with you shortly!")
def main():
    """Render the whole app: the Arena, Scoreboard and About tabs.

    Arena: record or fetch audio, transcribe it with two hidden models and
    vote on the result. Scoreboard: shown once an e-mail address has been
    submitted in this session. About: project description and contact form.

    A recording is written to a local temporary file, which ``call_function``
    later loads from ``audio_path``, and uploaded to the audio prefix of the
    S3 bucket. No local directory is created for ``TEMP_DIR``: it is an
    ``s3://`` URI, so ``os.makedirs`` on it would only create a stray local
    folder tree.
    """
    st.title("βš”οΈ Ori Speech-To-Text Arena βš”οΈ")

    # Seed each session key once; existing values survive reruns.
    session_defaults = {
        "has_audio": False,
        "audio": None,
        "audio_path": "",
        "option_1": "",
        "option_2": "",
        "transcribed": False,
        "option_1_model_name_state": "",
        "option_1_model_name": "",
        "option_2_model_name": "",
        "option_2_model_name_state": "",
        "user_email": "",
        "logged_in": False,
    }
    for key, default in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default

    arena, scoreboard,about_tab = st.tabs(["Arena", "Scoreboard","About"])

    with arena:
        INSTR = """
## Instructions:
* Record audio to recognise speech (or press 🎲 for random Audio).
* Click on transcribe audio button to commence the transcription process.
* Read the two options one after the other while listening to the audio.
* Vote on which transcript you prefer.
* Note:
* Model names are revealed after the vote is cast.
* Currently only Indian Hindi language is supported, and
the results will be in Hinglish (Hindi in Latin script)
* Random audios are only in hindi
* It may take up to 30 seconds for speech recognition in some cases.
""".strip()
        st.markdown(INSTR)

        col1, col2 = st.columns([1, 1])
        with col1:
            st.markdown("### Record Audio")
            with st.container():
                audio_bytes = audio_recorder(
                    text="πŸŽ™οΈ Click to Record",
                    pause_threshold=3,
                    icon_size="2x",
                    key="audio_recorder",
                    sample_rate=16_000
                )
                # The recorder returns the same bytes on every rerun; only a
                # recording that differs from the last one starts a new round.
                if audio_bytes and audio_bytes != st.session_state.get('last_recorded_audio'):
                    reset_state()
                    st.session_state.last_recorded_audio = audio_bytes
                    st.session_state.audio = {"data":audio_bytes,"format":"audio/wav"}
                    st.session_state.current_audio_type = "recorded"
                    st.session_state.has_audio = True
                    # delete=False: the file must outlive this block because it
                    # is read again when the audio is transcribed.
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
                        tmp_file.write(audio_bytes)
                    s3_client.put_object(
                        Bucket=os.getenv('AWS_BUCKET_NAME'),
                        Key=f"{os.getenv('AUDIOS_KEY')}/{os.path.basename(tmp_file.name)}",
                        Body=audio_bytes,
                    )
                    st.session_state.audio_path = tmp_file.name
                    st.session_state.option_selected = None

        with col2:
            st.markdown("### Random Audio Example")
            with st.container():
                st.button("🎲 Random Audio",on_click=on_random_click)

        # `audio` can be None while has_audio is still True (reset_state clears
        # only the former), so both are checked before rendering the player.
        if st.session_state.has_audio and st.session_state.audio:
            st.audio(**st.session_state.audio)

        with st.container():
            st.button("πŸ“ Transcribe Audio",on_click=on_click_transcribe,use_container_width=True)

        text_containers = st.columns([1, 1])
        name_containers = st.columns([1, 1])

        with text_containers[0]:
            st.text_area("Option 1", value=st.session_state.option_1, height=300)
        with text_containers[1]:
            st.text_area("Option 2", value=st.session_state.option_2, height=300)

        # Model names are only filled in after a vote has been cast.
        with name_containers[0]:
            if st.session_state.option_1_model_name_state:
                st.markdown(f"<div style='text-align: center'>{st.session_state.option_1_model_name_state}</div>", unsafe_allow_html=True)
        with name_containers[1]:
            if st.session_state.option_2_model_name_state:
                st.markdown(f"<div style='text-align: center'>{st.session_state.option_2_model_name_state}</div>", unsafe_allow_html=True)

        c1, c2, c3, c4 = st.columns(4)
        with c1:
            st.button("Prefer Option 1",on_click=on_option_1_click)
        with c2:
            st.button("Prefer Option 2",on_click=on_option_2_click)
        with c3:
            st.button("Prefer Both",on_click=on_option_both_click)
        with c4:
            st.button("Prefer None",on_click=on_option_none_click)

    with scoreboard:
        if st.session_state.logged_in:
            dashboard()
        else:
            with st.form("contact_us_form"):
                st.subheader("Please Enter you Email")
                email = st.text_input("Email")
                submit_button = st.form_submit_button("Submit")

                if submit_button:
                    if not email:
                        st.error("Please fill in all fields")
                    elif not validate_email(email):
                        st.error("Please enter a valid email address")
                    else:
                        st.session_state.logged_in = True
                        st.session_state.user_email = email
                        write_email(st.session_state.user_email)
                        st.success("Thanks for submitting your email")

            # Rendered outside the form because the dashboard contains a
            # regular button.
            if st.session_state.logged_in:
                dashboard()

    with about_tab:
        about()
# Script entry point: build the page by calling main() whenever this file is executed.
main()