# Hugging Face Spaces demo: SwissBERT sentence-similarity comparison.
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity

# Load the SwissBERT sentence-embedding model and tokenizer once at import
# time so every request reuses the same weights.
model_name = "jgrosjean-mathesis/sentence-swissbert"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def generate_sentence_embedding(sentence, language):
    """Embed a sentence with SwissBERT using the matching language adapter.

    Args:
        sentence: Input text to embed.
        language: Language selector; any string containing "de", "fr",
            "it" or "rm" activates the corresponding Swiss adapter.
            If several codes match, the last one in (de, fr, it, rm)
            order wins, mirroring the original independent checks.

    Returns:
        torch.Tensor of shape (1, hidden_size): mean-pooled sentence
        embedding over non-padding tokens.
    """
    # Activate the language-specific adapter on the shared global model.
    for code, adapter in (("de", "de_CH"), ("fr", "fr_CH"),
                          ("it", "it_CH"), ("rm", "rm_CH")):
        if code in language:
            model.set_default_language(adapter)

    # Tokenize the input sentence (truncated to the 512-token limit).
    inputs = tokenizer(sentence, padding=True, truncation=True,
                       return_tensors="pt", max_length=512)

    # Forward pass without gradient tracking (inference only).
    with torch.no_grad():
        outputs = model(**inputs)

    # Mean pooling: average token embeddings, weighted by the attention
    # mask so padding tokens do not contribute.
    token_embeddings = outputs.last_hidden_state
    attention_mask = inputs["attention_mask"].unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * attention_mask, 1)
    # Clamp avoids division by zero for an all-padding (empty) input.
    sum_mask = torch.clamp(attention_mask.sum(1), min=1e-9)
    return sum_embeddings / sum_mask
def calculate_cosine_similarities(source_sentence, source_language, target_sentence_1, target_language_1, target_sentence_2, target_language_2, target_sentence_3, target_language_3):
    """Rank three target sentences by cosine similarity to a source sentence.

    Each sentence is embedded via generate_sentence_embedding with its own
    language adapter, then compared to the source embedding.

    Returns:
        str: one "sentence: score" line per target, best match first; the
        top line is wrapped in markdown bold markers ("**...**").
    """
    source_embedding = generate_sentence_embedding(source_sentence, source_language)

    targets = [
        (target_sentence_1, target_language_1),
        (target_sentence_2, target_language_2),
        (target_sentence_3, target_language_3),
    ]

    # Collect (sentence, score) pairs instead of a sentence-keyed dict so
    # that two identical target sentences do not silently overwrite each
    # other's scores.
    scored = []
    for sentence, lang in targets:
        target_embedding = generate_sentence_embedding(sentence, lang)
        score = cosine_similarity(source_embedding, target_embedding)[0][0]
        scored.append((sentence, score))

    # Highest similarity first.
    scored.sort(key=lambda pair: pair[1], reverse=True)

    cosine_scores_output = ""
    for sentence, score in scored:
        cosine_scores_output += sentence + ": " + str(score) + "\n"

    # Bold only the top-ranked line (first newline becomes "**\n").
    return "**" + cosine_scores_output.replace("\n", "**\n", 1)
def main():
    """Build and launch the Gradio interface for the similarity demo."""
    demo = gr.Interface(
        fn=calculate_cosine_similarities,
        inputs=[
            gr.Textbox(lines=1, placeholder="Der Zug fährt um 9 Uhr in Zürich ab.", label="source sentence"),
            gr.Dropdown(["de", "fr", "it", "rm"], value="de", label="language"),
            gr.Textbox(lines=1, placeholder="Le train arrive à Lausanne à 11 heures.", label="target sentence 1"),
            gr.Dropdown(["de", "fr", "it", "rm"], value="fr", label="language"),
            gr.Textbox(lines=1, placeholder="Alla stazione di Lugano ci sono diversi binari.", label="target sentence 2"),
            gr.Dropdown(["de", "fr", "it", "rm"], value="it", label="language"),
            gr.Textbox(lines=1, placeholder="A Cuera van biars trens ellas muntognas.", label="target sentence 3"),
            gr.Dropdown(["de", "fr", "it", "rm"], value="rm", label="language"),
        ],
        outputs=gr.Textbox(label="Cosine similarity scores", type="text", lines=3),
    )
    # share=True exposes a temporary public URL (useful outside Spaces).
    demo.launch(share=True)


if __name__ == "__main__":
    main()