import gradio as gr
from transformers import AutoTokenizer
import torch
from sentence_transformers import SentenceTransformer, models
param_max_length = 256

# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Gradio callback: receives the two textbox values and returns their similarity.
def analyze_text(sentence_1, sentence_2):
    result = predict_similarity([sentence_1, sentence_2])
    return result
param_model_name = "CAMeL-Lab/bert-base-arabic-camelbert-msa-sixteenth"
tokenizer = AutoTokenizer.from_pretrained(param_model_name)
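# The tokenizer maps raw text to the tensors the encoder expects, e.g.
# tokenizer("مثال", return_tensors="pt") yields input_ids, attention_mask and
# token_type_ids (the last of which is dropped before encoding below).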
class BertForSTS(torch.nn.Module):
    """Sentence embedding model: CAMeL-BERT encoder, mean pooling, and a dense head."""

    def __init__(self):
        super(BertForSTS, self).__init__()
        self.bert = models.Transformer(param_model_name, max_seq_length=param_max_length)
        dimension = self.bert.get_word_embedding_dimension()
        # Pool the token embeddings into a single sentence embedding.
        self.pooling_layer = models.Pooling(dimension)
        self.dropout = torch.nn.Dropout(0.1)
        # ReLU activation function
        self.relu = torch.nn.ReLU()
        # Dense layer 1
        self.fc1 = torch.nn.Linear(dimension, 512)
        # Dense layer 2 (output layer): unused in forward, but kept so the
        # saved checkpoint still loads with strict state_dict matching.
        self.fc2 = torch.nn.Linear(512, 512)
        # Assembled SentenceTransformer pipeline (forward below chains the
        # same modules manually instead of calling it).
        self.sts_bert = SentenceTransformer(modules=[self.bert, self.pooling_layer, self.fc1])

    def forward(self, input_data):
        x = self.bert(input_data)
        x = self.pooling_layer(x)
        x = self.fc1(x['sentence_embedding'])
        x = self.relu(x)
        x = self.dropout(x)
        #x = self.fc2(x)
        return x
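# For reference, the tensor flow through forward (assuming the standard
# 768-dimensional BERT-base encoder; shapes are illustrative):
#   tokenized batch -> Transformer   -> token embeddings      (batch, seq_len, 768)
#   Pooling (mean)  -> sentence_embedding                     (batch, 768)
#   fc1 -> ReLU -> dropout -> final embedding                 (batch, 512)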
# Load the fine-tuned STS weights on CPU first, then move to the target device.
model_load_path = "IBounhas/riadh/bert-sts-15.pt"
model = BertForSTS()
model.load_state_dict(torch.load(model_load_path, map_location=torch.device('cpu')))
model.to(device)
# Inference mode: eval() disables dropout so the embeddings are deterministic.
model.eval()
def predict_similarity(sentence_pair):
    # Tokenize both sentences together as a batch of two.
    test_input = tokenizer(sentence_pair, padding='max_length', max_length=param_max_length, truncation=True, return_tensors="pt").to(device)
    # The forward pass does not use token_type_ids, so drop them.
    del test_input['token_type_ids']
    with torch.no_grad():
        output = model(test_input)
    # Cosine similarity between the two sentence embeddings; the *2-1
    # rescaling presumably matches the score scale used in training.
    sim = torch.nn.functional.cosine_similarity(output[0], output[1], dim=0).item() * 2 - 1
    return sim
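# Example usage (hypothetical sentence pair); uncomment to sanity-check at startup:
# print(predict_similarity(["ذهب الولد إلى المدرسة", "الولد ذهب إلى المدرسة"]))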
# Create a Gradio interface with two text input boxes, one per sentence
iface = gr.Interface(
    fn=analyze_text,  # The function called with the two user inputs
    inputs=[gr.Textbox(label="Sentence 1"), gr.Textbox(label="Sentence 2")],
    outputs="text"  # Display the similarity score as text
)
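# Optional: launching with share=True instead would also create a temporary
# public link when the app is run locally.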
# Launch the Gradio interface
iface.launch()