# Author: alperugurcan
# Last change: "Update app.py" (commit 413c8a1, verified)
import gradio as gr
import torch
from transformers import DistilBertTokenizer, DistilBertModel
class SimilarityPredictor:
    """Score the similarity of two short phrases on a 0-1 scale.

    Each phrase is embedded with the pretrained ``distilbert-base-uncased``
    encoder. Token states are mean-pooled over real (non-padding) tokens, and
    the score is the cosine similarity of the two embeddings clamped to [0, 1].

    Why not a classification head: a freshly created ``Linear(768, 1)`` has
    random, untrained weights. Its output would carry no similarity signal
    and would change every time the class is constructed. Cosine similarity
    needs no trained weights and gives ~1.0 for identical phrases.
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Pretrained base encoder only; no fine-tuned similarity weights are
        # loaded, so the score comes purely from embedding geometry.
        self.model = DistilBertModel.from_pretrained('distilbert-base-uncased').to(self.device)
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        # Inference only: switch off dropout once instead of on every call.
        self.model.eval()

    def _embed(self, phrases):
        """Return one mean-pooled embedding per phrase, shape (len(phrases), hidden)."""
        encoded = self.tokenizer(
            phrases,
            padding=True,
            truncation=True,
            max_length=64,
            return_tensors='pt'
        ).to(self.device)
        hidden = self.model(**encoded)[0]  # (batch, seq_len, hidden)
        # Average only over real tokens so padding added to the shorter phrase
        # does not pull its embedding toward the pad-token state.
        mask = encoded['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    def predict(self, anchor, target):
        """Return the similarity of ``anchor`` and ``target`` as a float in [0, 1].

        Both phrases are encoded in a single batch (each truncated to 64
        tokens). Identical phrases score ~1.0.
        """
        with torch.no_grad():
            embeddings = self._embed([anchor, target])
            cosine = torch.nn.functional.cosine_similarity(
                embeddings[0:1], embeddings[1:2], dim=-1
            )
        # Cosine lies in [-1, 1]; clamp to the advertised 0-1 range. This
        # also absorbs float rounding that lands just above 1.0.
        return float(cosine.clamp(0.0, 1.0))
# One shared predictor instance, built once when this module is executed.
predictor = SimilarityPredictor()

# Sample (anchor, target) inputs for the Gradio examples table: one
# two-element list per row, matching the two Textbox inputs of the UI.
example_pairs = [
    list(pair)
    for pair in (
        ("mobile phone", "cellphone"),
        ("artificial intelligence", "machine learning"),
        ("electric vehicle", "battery powered car"),
        ("wireless communication", "radio transmission"),
        ("solar panel", "photovoltaic cell"),
    )
]
def predict_similarity(anchor, target):
    """Gradio callback: similarity of two phrases, rounded to 3 decimals.

    Surrounding whitespace is stripped. If either phrase is missing, empty,
    or whitespace-only, ``None`` is returned so no misleading score is shown
    for input that contains no phrase at all.
    """
    anchor = (anchor or "").strip()
    target = (target or "").strip()
    if not anchor or not target:
        return None
    score = predictor.predict(anchor, target)
    return round(score, 3)
# Web UI: two phrase text boxes in, one numeric similarity score out.
iface = gr.Interface(
    fn=predict_similarity,
    inputs=[
        gr.Textbox(label="Anchor Phrase", placeholder="Enter first phrase..."),
        gr.Textbox(label="Target Phrase", placeholder="Enter second phrase..."),
    ],
    outputs=gr.Number(label="Similarity Score (0-1)"),
    title="Patent Phrase Similarity Checker",
    description="Compare the similarity between two patent phrases (0: Different, 1: Identical)",
    examples=example_pairs,
    # NOTE(review): "huggingface" looks like a legacy Gradio 3 theme name;
    # newer Gradio releases appear to treat unknown theme strings as Hub theme
    # ids and fall back to the default theme with a warning. Confirm against
    # the Gradio version this Space pins.
    theme="huggingface",
)

# Launch the server only when run as a script (e.g. `python app.py`), so
# importing this module does not block on a running web server.
if __name__ == "__main__":
    iface.launch()