|
import streamlit as st |
|
import torch |
|
from transformers import AlbertTokenizer, AlbertForSequenceClassification, AlbertConfig |
|
import plotly.graph_objects as go |
|
|
|
|
|
# App branding: the Dejan logo, clickable through to the company site.
logo_url = "https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png"
st.logo(logo_url, link="https://dejan.ai")

# Page heading plus a short explanation of what the classifier is for.
st.title("Search Query Form Classifier")
st.write(
    "Ambiguous search queries are candidates for query expansion. "
    "Our model identifies such queries with an 80 percent accuracy and is "
    "deployed in a batch processing pipeline directly connected with Google "
    "Search Console API. In this demo you can test the model capability by "
    "testing individual queries."
)
st.write("Enter a query to check if it's well-formed:")
|
|
|
|
|
model_name = 'dejanseo/Query-Quality-Classifier'


@st.cache_resource
def _load_model(name):
    """Load the tokenizer and classifier for *name*, ready for inference.

    Streamlit re-executes this whole script on every widget interaction.
    Without ``st.cache_resource``, the weights would be reloaded via
    ``from_pretrained`` on each rerun. With it, they are loaded once and reused.

    Returns:
        tuple: ``(tokenizer, model, device)``. The model is in eval mode and
        already moved to ``device`` (CUDA when available, otherwise CPU).
    """
    tok = AlbertTokenizer.from_pretrained(name)
    clf = AlbertForSequenceClassification.from_pretrained(name)
    # Inference only: eval() disables dropout.
    clf.eval()
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    clf.to(dev)
    return tok, clf, dev


# Keep the same module-level names the rest of the script reads.
tokenizer, model, device = _load_model(model_name)
|
|
|
|
|
# Free-text query box. It is prefilled with an example, so the page shows a
# result as soon as it loads.
user_input = st.text_input(label="Query:", value="What is?")

# Attribution footer linking to the write-up about the model.
st.write("Developed by [Dejan AI](https://dejan.ai/blog/search-query-quality-classifier/)")
|
|
|
def classify_query(query, max_length=32):
    """Score how likely *query* is to be a well-formed search query.

    The text is tokenized with the module-level ``tokenizer``, scored by the
    module-level ``model`` on ``device``, and the softmax probability of class
    index 1 (the "well-formed" class, per how the result is reported) is
    returned as a percentage.

    Args:
        query: Raw query text to classify.
        max_length: Token length the input is padded/truncated to. Defaults to
            32, presumably the sequence length the model was fine-tuned with
            (TODO confirm).

    Returns:
        float: Probability of the well-formed class, scaled to the 0-100 range.
    """
    # Calling the tokenizer directly is the supported replacement for the
    # deprecated ``encode_plus``. The arguments are the same, so the encoding
    # is unchanged.
    inputs = tokenizer(
        query,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits

    # Batch of one, so take row 0. Column 1 is the well-formed class.
    # .item() yields a plain Python float rather than a NumPy scalar.
    probabilities = torch.softmax(logits, dim=-1)[0]
    return probabilities[1].item() * 100
|
|
|
|
|
# Decision boundary, in percent, shared by the gauge colouring and the verdict.
WELL_FORMED_THRESHOLD = 50

# Strip surrounding whitespace so a blank or whitespace-only entry is treated
# like an empty box instead of being scored by the model. Non-blank queries
# should score the same, since ALBERT's tokenizer strips outer whitespace by
# default (remove_space=True).
query = user_input.strip()

if query:

    confidence = classify_query(query)

    # Gauge: red below the decision boundary, green above it.
    fig = go.Figure(go.Indicator(
        mode="gauge+number",
        value=confidence,
        title={'text': "Well-formedness Confidence"},
        gauge={
            'axis': {'range': [0, 100]},
            'bar': {'color': "darkblue"},
            'steps': [
                {'range': [0, WELL_FORMED_THRESHOLD], 'color': "red"},
                {'range': [WELL_FORMED_THRESHOLD, 100], 'color': "green"},
            ],
            # Black marker line drawn at the score itself, acting as a needle.
            'threshold': {
                'line': {'color': "black", 'width': 4},
                'thickness': 0.75,
                'value': confidence,
            },
        },
    ))

    st.plotly_chart(fig)

    # When the query is judged not well-formed, report the complementary
    # probability, i.e. confidence in the "not well-formed" verdict.
    if confidence >= WELL_FORMED_THRESHOLD:
        st.success(f"The query is likely well-formed with {confidence:.2f}% confidence.")
    else:
        st.error(f"The query is likely not well-formed with {100 - confidence:.2f}% confidence.")
|
|