dejanseo committed on
Commit
b857ab4
1 Parent(s): 108c945

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import AlbertTokenizer, AlbertForSequenceClassification, AlbertConfig
4
+ import plotly.graph_objects as go
5
+
6
# URL of the logo
logo_url = "https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png"

# Display the logo at the top using st.logo
st.logo(logo_url, link="https://dejan.ai")

# Streamlit app title and description
st.title("Search Query Form Classifier")
st.write("Ambiguous search queries are candidates for query expansion. Our model identifies such queries with an 80 percent accuracy and is deployed in a batch processing pipeline directly connected with Google Search Console API. In this demo you can test the model capability by testing individual queries.")
st.write("Enter a query to check if it's well-formed:")

# Load the model and tokenizer from the /model/ directory
model_dir = 'model'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@st.cache_resource
def _load_classifier(path, device_name):
    """Load the ALBERT tokenizer, config and classifier exactly once.

    Streamlit re-executes this script from top to bottom on every widget
    interaction; without caching, the model would be re-read from disk each
    time the user submits a query. ``st.cache_resource`` keeps the loaded
    objects alive across reruns and sessions.

    Args:
        path: Directory containing the saved tokenizer and model files.
        device_name: Device string (e.g. ``"cpu"`` or ``"cuda"``) the model
            is moved to. Passed as a plain string so it can be used as a
            cache key.

    Returns:
        A ``(tokenizer, config, model)`` tuple; the model is in evaluation
        mode and already placed on ``device_name``.
    """
    loaded_tokenizer = AlbertTokenizer.from_pretrained(path)
    loaded_config = AlbertConfig.from_pretrained(path)
    loaded_model = AlbertForSequenceClassification.from_pretrained(
        path, config=loaded_config
    )
    # Set the model to evaluation mode before it is shared between reruns.
    loaded_model.eval()
    loaded_model.to(device_name)
    return loaded_tokenizer, loaded_config, loaded_model


tokenizer, config, model = _load_classifier(model_dir, str(device))

# User input
user_input = st.text_input("Query:", "What is?")
st.write("Developed by [Dejan AI](https://dejan.ai/blog/search-query-quality-classifier/)")
31
+
32
def classify_query(query):
    """Score how likely ``query`` is to be a well-formed search query.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        query: The raw query string to classify.

    Returns:
        The softmax probability of class index 1 (treated by this app as the
        "well-formed" class), scaled to a percentage in ``[0, 100]`` and
        returned as a plain Python ``float``.
    """
    # Tokenize input. Calling the tokenizer directly replaces the deprecated
    # ``encode_plus`` method; the arguments and resulting tensors are the same.
    inputs = tokenizer(
        query,
        add_special_tokens=True,
        max_length=32,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Perform inference without tracking gradients.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        probabilities = torch.softmax(outputs.logits, dim=1)

    # Single-item batch: row 0, column 1 is the well-formed class.
    # ``.item()`` yields a built-in float rather than a NumPy scalar.
    confidence = probabilities[0, 1].item() * 100

    return confidence
55
+
56
# Classify the query (when one was entered) and render the result.
if user_input:
    confidence = classify_query(user_input)

    # Gauge appearance: a 0-100 axis split into a red lower half and a green
    # upper half, with a black threshold marker placed at the score itself.
    threshold_marker = {
        'line': {'color': "black", 'width': 4},
        'thickness': 0.75,
        'value': confidence,
    }
    colour_bands = [
        {'range': [0, 50], 'color': "red"},
        {'range': [50, 100], 'color': "green"},
    ]
    gauge_settings = {
        'axis': {'range': [0, 100]},
        'bar': {'color': "darkblue"},
        'steps': colour_bands,
        'threshold': threshold_marker,
    }

    # Plotly gauge showing the confidence as both a dial and a number.
    indicator = go.Indicator(
        mode="gauge+number",
        value=confidence,
        title={'text': "Well-formedness Confidence"},
        gauge=gauge_settings,
    )
    fig = go.Figure(indicator)

    st.plotly_chart(fig)

    # Scores of 50 or more count as well-formed; otherwise report the
    # complementary confidence for the "not well-formed" verdict.
    is_well_formed = confidence >= 50
    if is_well_formed:
        st.success(f"The query is likely well-formed with {confidence:.2f}% confidence.")
    else:
        st.error(f"The query is likely not well-formed with {100 - confidence:.2f}% confidence.")