IBounhas commited on
Commit
fe4625b
·
1 Parent(s): 9ceb0bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -3
app.py CHANGED
@@ -1,15 +1,70 @@
1
  import gradio as gr
 
2
 
3
  # Define a function that takes a text input and returns the result
4
- def analyze_text(input_text):
5
  # Your processing or model inference code here
6
- result = f"You entered: {input_text}"
7
  return result
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # Create a Gradio interface with a text input zone
10
  iface = gr.Interface(
11
  fn=analyze_text, # The function to be called with user input
12
- inputs=gr.Textbox(), # Textbox component for text input
13
  outputs="text" # Display the result as text
14
  )
15
 
 
1
  import gradio as gr
2
+ import torch
3
 
4
  # Define a function that takes a text input and returns the result
5
+ def analyze_text(input):
6
  # Your processing or model inference code here
7
+ result = predict_similarity(input)
8
  return result
9
 
10
+ param_model_name="CAMeL-Lab/bert-base-arabic-camelbert-msa-sixteenth"
11
+ tokenizer = CustomBertTokenizer.from_pretrained(param_model_name)
12
+ class BertForSTS(torch.nn.Module):
13
+
14
+ def __init__(self):
15
+ super(BertForSTS, self).__init__()
16
+ #self.bert = models.Transformer('bert-base-uncased', max_seq_length=128)
17
+ #self.bert = AutoModelForSequenceClassification.from_pretrained("CAMeL-Lab/bert-base-arabic-camelbert-msa-sixteenth")
18
+ self.bert = models.Transformer(param_model_name, max_seq_length=param_max_length)
19
+
20
+ if(param_freeze):
21
+ for p in self.bert.parameters():
22
+ p.requires_grad = False
23
+ dimension= self.bert.get_word_embedding_dimension()
24
+ #print(dimension)
25
+ self.pooling_layer = models.Pooling(dimension)
26
+ self.dropout = torch.nn.Dropout(0.1)
27
+
28
+ # relu activation function
29
+ self.relu = torch.nn.ReLU()
30
+
31
+ # dense layer 1
32
+ self.fc1 = torch.nn.Linear(dimension,512)
33
+
34
+ # dense layer 2 (Output layer)
35
+ self.fc2 = torch.nn.Linear(512,512)
36
+ #self.pooling_layer = models.Pooling(self.bert.config.hidden_size)
37
+ self.sts_bert = SentenceTransformer(modules=[self.bert,self.pooling_layer, self.fc1])
38
+ #self.sts_bert = SentenceTransformer(modules=[self.bert,self.pooling_layer, self.fc1, self.relu, self.dropout,self.fc2])
39
+ def forward(self, input_data):
40
+ #print(input_data)
41
+ x=self.bert(input_data)
42
+ x=self.pooling_layer(x)
43
+ x=self.fc1(x['sentence_embedding'])
44
+ x = self.relu(x)
45
+ x = self.dropout(x)
46
+ #x = self.fc2(x)
47
+
48
+ return x
49
+ model_load_path = "IBounhas/riadh/bert-sts-15.pt"
50
+ model = BertForSTS()
51
+ model.load_state_dict(torch.load(model_save_path, map_location=torch.device('cpu')))
52
+ model.to(device)
53
+
54
+ def predict_similarity(sentence_pair):
55
+ test_input = tokenizer(sentence_pair, padding='max_length', max_length = param_max_length, truncation=True, return_tensors="pt").to(device)
56
+ test_input['input_ids'] = test_input['input_ids']
57
+ test_input['attention_mask'] = test_input['attention_mask']
58
+ del test_input['token_type_ids']
59
+ output = model(test_input)
60
+ sim = torch.nn.functional.cosine_similarity(output[0], output[1], dim=0).item()*2-1
61
+
62
+ return sim
63
+
64
  # Create a Gradio interface with a text input zone
65
  iface = gr.Interface(
66
  fn=analyze_text, # The function to be called with user input
67
+ inputs=[gr.Textbox(), gr.Textbox()],
68
  outputs="text" # Display the result as text
69
  )
70