File size: 3,046 Bytes
bf4f788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import streamlit as st
import torch
from transformers import AlbertTokenizer, AlbertForSequenceClassification, AlbertConfig
import plotly.graph_objects as go

# URL of the logo
logo_url = "https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png"

# Display the logo at the top using st.logo
st.logo(logo_url, link="https://dejan.ai")

# Streamlit app title and description
st.title("Search Query Form Classifier")
st.write("Ambiguous search queries are candidates for query expansion. Our model identifies such queries with an 80 percent accuracy and is deployed in a batch processing pipeline directly connected with Google Search Console API. In this demo you can test the model capability by testing individual queries.")
st.write("Enter a query to check if it's well-formed:")

# Hugging Face Model Hub repository of the classifier
model_name = 'dejanseo/Query-Quality-Classifier'


@st.cache_resource
def load_classifier(name):
    """Load the tokenizer and model once and reuse them across reruns.

    Streamlit re-executes this script from the top on every interaction;
    without caching, the weights would be reloaded each time the user
    submits a query.

    Args:
        name: Hugging Face Model Hub identifier of the classifier.

    Returns:
        A ``(tokenizer, model, device)`` tuple, with the model switched to
        evaluation mode and moved onto ``device``.
    """
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    loaded_tokenizer = AlbertTokenizer.from_pretrained(name)
    loaded_model = AlbertForSequenceClassification.from_pretrained(name)

    # Set the model to evaluation mode
    loaded_model.eval()
    loaded_model.to(target_device)
    return loaded_tokenizer, loaded_model, target_device


tokenizer, model, device = load_classifier(model_name)

# User input
user_input = st.text_input("Query:", "What is?")
st.write("Developed by [Dejan AI](https://dejan.ai/blog/search-query-quality-classifier/)")

def classify_query(query, max_length=32):
    """Return the model's confidence that *query* is well-formed.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        query: Search query text to score.
        max_length: Token length the encoded query is padded or truncated
            to. Defaults to 32, the value this function previously
            hard-coded.

    Returns:
        The softmax probability of class index 1 (the well-formed class),
        as a plain Python float percentage between 0 and 100.
    """
    # Calling the tokenizer directly replaces the deprecated encode_plus().
    encoded = tokenizer(
        query,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # Perform inference without tracking gradients
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits

    # One query per call, so row 0 holds its class probabilities.
    probabilities = torch.softmax(logits, dim=1)[0]

    # .item() yields a native float, so callers never see a tensor or
    # NumPy scalar.
    return probabilities[1].item() * 100

# Check and display classification
# Strip first: a whitespace-only string is truthy, so it would otherwise be
# scored and reported as if it were a real query.
query = user_input.strip()
if query:
    confidence = classify_query(query)

    # Plotly gauge: red below 50, green from 50 up, with the bar and the
    # black threshold marker both placed at the confidence value.
    fig = go.Figure(go.Indicator(
        mode="gauge+number",
        value=confidence,
        title={'text': "Well-formedness Confidence"},
        gauge={
            'axis': {'range': [0, 100]},
            'bar': {'color': "darkblue"},
            'steps': [
                {'range': [0, 50], 'color': "red"},
                {'range': [50, 100], 'color': "green"}
            ],
            'threshold': {
                'line': {'color': "black", 'width': 4},
                'thickness': 0.75,
                'value': confidence
            }
        }
    ))

    st.plotly_chart(fig)

    # 50% is the decision boundary; below it, report the complementary
    # confidence for the "not well-formed" verdict.
    if confidence >= 50:
        st.success(f"The query is likely well-formed with {confidence:.2f}% confidence.")
    else:
        st.error(f"The query is likely not well-formed with {100 - confidence:.2f}% confidence.")