emiliosheinz commited on
Commit
66f1e76
·
1 Parent(s): 4c9ba47

use fine tuned model

Browse files
Files changed (1) hide show
  1. app.py +6 -11
app.py CHANGED
@@ -1,9 +1,7 @@
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
 
4
- # load the pre-trained model and tokenizer
5
- tokenizer = AutoTokenizer.from_pretrained("distilbert-base-multilingual-cased")
6
- model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-multilingual-cased")
7
 
8
  # set the app title
9
  st.title("Sentence Similarity Checker")
@@ -14,15 +12,12 @@ sentence2 = st.text_input("Enter the second sentence:")
14
 
15
  # check if both sentences are not empty
16
  if sentence1 and sentence2:
17
- # tokenize the sentences and get the output logits for the sentence pair classification task
18
- inputs = tokenizer(sentence1, sentence2, padding=True, truncation=True, max_length=250, return_tensors="pt")
19
- outputs = model(**inputs).logits
20
 
21
- # calculate the softmax probabilities for the two classes (similar or dissimilar)
22
- probs = outputs.softmax(dim=1)
23
-
24
- # the probability of the sentences being similar is the second element of the output array
25
- similarity_score = probs[0][1].item()
26
 
27
  # display the similarity score to the user
28
  st.write("Similarity score:", similarity_score)
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
 
4
+ model = AutoModelForSequenceClassification.from_pretrained("sentence-transformers/all-distilroberta-v1")
 
 
5
 
6
  # set the app title
7
  st.title("Sentence Similarity Checker")
 
12
 
13
  # check if both sentences are not empty
14
  if sentence1 and sentence2:
15
+ # encode the sentences into embeddings
16
+ embeddings1 = model.encode(sentence1, convert_to_tensor=True)
17
+ embeddings2 = model.encode(sentence2, convert_to_tensor=True)
18
 
19
+ # calculate the cosine similarity between the embeddings
20
+ similarity_score = float(embeddings1 @ embeddings2.T)
 
 
 
21
 
22
  # display the similarity score to the user
23
  st.write("Similarity score:", similarity_score)