import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity

# Load the SwissBERT sentence-embedding model and its tokenizer
model_name = "jgrosjean-mathesis/sentence-swissbert"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
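
# Note: SwissBERT is an X-MOD-style model with per-language adapters
# (de_CH, fr_CH, it_CH, rm_CH); set_default_language() below selects which
# adapter is active before a sentence is encoded.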

def generate_sentence_embedding(sentence, language):
    # Activate the language adapter matching the selected language code
    if "de" in language:
        model.set_default_language("de_CH")
    elif "fr" in language:
        model.set_default_language("fr_CH")
    elif "it" in language:
        model.set_default_language("it_CH")
    elif "rm" in language:
        model.set_default_language("rm_CH")

    # Tokenize the sentence and encode it without tracking gradients
    inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors="pt", max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)

    # Mean-pool the token embeddings, excluding padding via the attention mask
    token_embeddings = outputs.last_hidden_state
    attention_mask = inputs["attention_mask"].unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * attention_mask, 1)
    sum_mask = torch.clamp(attention_mask.sum(1), min=1e-9)
    embedding = sum_embeddings / sum_mask
    return embedding
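
# Illustrative usage (comments only, not executed by the app): each call returns
# a tensor of shape (1, hidden_size) that can be compared with cosine similarity:
#   emb_de = generate_sentence_embedding("Der Zug fährt um 9 Uhr ab.", "de")
#   emb_fr = generate_sentence_embedding("Le train part à 9 heures.", "fr")
#   score = cosine_similarity(emb_de, emb_fr)[0][0]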

def calculate_cosine_similarities(source_sentence, source_language,
                                  target_sentence_1, target_language_1,
                                  target_sentence_2, target_language_2,
                                  target_sentence_3, target_language_3):
    # Embed the source sentence and each target sentence
    source_embedding = generate_sentence_embedding(source_sentence, source_language)
    target_embedding_1 = generate_sentence_embedding(target_sentence_1, target_language_1)
    target_embedding_2 = generate_sentence_embedding(target_sentence_2, target_language_2)
    target_embedding_3 = generate_sentence_embedding(target_sentence_3, target_language_3)

    # Cosine similarity between the source and each target
    cosine_score_1 = cosine_similarity(source_embedding, target_embedding_1)
    cosine_score_2 = cosine_similarity(source_embedding, target_embedding_2)
    cosine_score_3 = cosine_similarity(source_embedding, target_embedding_3)

    # Sort the targets by similarity, highest first
    cosine_scores = {
        target_sentence_1: cosine_score_1[0][0],
        target_sentence_2: cosine_score_2[0][0],
        target_sentence_3: cosine_score_3[0][0]
    }
    cosine_scores_dict = dict(sorted(cosine_scores.items(), key=lambda item: item[1], reverse=True))

    # Build the output text and wrap the best match in ** markers
    cosine_scores_output = ""
    for key, value in cosine_scores_dict.items():
        cosine_scores_output += key + ": " + str(value) + "\n"
    cosine_scores_output = "**" + cosine_scores_output.replace("\n", "**\n", 1)
    return cosine_scores_output

def main():
    # Build the Gradio interface: one source sentence and three target sentences,
    # each paired with a language dropdown
    demo = gr.Interface(
        fn=calculate_cosine_similarities,
        inputs=[
            gr.Textbox(lines=1, placeholder="Enter source sentence", label="Source Sentence"),
            gr.Dropdown(["de", "fr", "it", "rm"], label="Source Language"),
            gr.Textbox(lines=1, placeholder="Enter target sentence 1", label="Target Sentence 1"),
            gr.Dropdown(["de", "fr", "it", "rm"], label="Target Language 1"),
            gr.Textbox(lines=1, placeholder="Enter target sentence 2", label="Target Sentence 2"),
            gr.Dropdown(["de", "fr", "it", "rm"], label="Target Language 2"),
            gr.Textbox(lines=1, placeholder="Enter target sentence 3", label="Target Sentence 3"),
            gr.Dropdown(["de", "fr", "it", "rm"], label="Target Language 3")
        ],
        outputs=gr.Textbox(label="Cosine Similarity Scores", type="text", lines=3),
        title="Sentence Similarity Calculator",
        description="Enter a source sentence and three target sentences; the targets are ranked by their cosine similarity to the source.",
        examples=[
            ["Der Zug fährt um 9 Uhr in Zürich ab.", "de", "Le train arrive à Lausanne à 11 heures.", "fr", "Alla stazione di Lugano ci sono diversi binari.", "it", "A Cuera van biars trens ellas muntognas.", "rm"]
        ]
    )
    demo.launch(share=True)

if __name__ == "__main__":
    main()
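
# To run this demo locally (assuming gradio, torch, transformers and scikit-learn
# are installed), execute this file with Python; launch(share=True) additionally
# creates a temporary public Gradio link.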