dejanseo committed on
Commit
8418d26
1 Parent(s): 9da4402

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -61
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import streamlit as st
2
  import torch
3
- from transformers import AlbertTokenizer, AlbertForSequenceClassification, AlbertConfig
4
  import plotly.graph_objects as go
5
 
6
  # URL of the logo
@@ -11,7 +11,9 @@ st.logo(logo_url, link="https://dejan.ai")
11
 
12
  # Streamlit app title and description
13
  st.title("Search Query Form Classifier")
14
- st.write("Ambiguous search queries are candidates for query expansion. Our model identifies such queries with an 80 percent accuracy and is deployed in a batch processing pipeline directly connected with Google Search Console API. In this demo you can test the model capability by testing individual queries.")
 
 
15
  st.write("Enter a query to check if it's well-formed:")
16
 
17
  # Load the model and tokenizer from the Hugging Face Model Hub
@@ -19,66 +21,120 @@ model_name = 'dejanseo/Query-Quality-Classifier'
19
  tokenizer = AlbertTokenizer.from_pretrained(model_name)
20
  model = AlbertForSequenceClassification.from_pretrained(model_name)
21
 
22
- # Set the model to evaluation mode
23
  model.eval()
24
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
  model.to(device)
26
 
27
- # User input
28
- user_input = st.text_input("Query:", "What is?")
29
- st.write("Developed by [Dejan AI](https://dejan.ai/blog/search-query-quality-classifier/)")
30
-
31
- def classify_query(query):
32
- # Tokenize input
33
- inputs = tokenizer.encode_plus(
34
- query,
35
- add_special_tokens=True,
36
- max_length=32,
37
- padding='max_length',
38
- truncation=True,
39
- return_attention_mask=True,
40
- return_tensors='pt'
41
- )
42
-
43
- input_ids = inputs['input_ids'].to(device)
44
- attention_mask = inputs['attention_mask'].to(device)
45
-
46
- # Perform inference
47
- with torch.no_grad():
48
- outputs = model(input_ids, attention_mask=attention_mask)
49
- logits = outputs.logits
50
- softmax_scores = torch.softmax(logits, dim=1).cpu().numpy()[0]
51
- confidence = softmax_scores[1] * 100 # Confidence for well-formed class
52
-
53
- return confidence
54
-
55
- # Check and display classification
56
- if user_input:
57
- confidence = classify_query(user_input)
58
-
59
- # Plotly gauge
60
- fig = go.Figure(go.Indicator(
61
- mode="gauge+number",
62
- value=confidence,
63
- title={'text': "Well-formedness Confidence"},
64
- gauge={
65
- 'axis': {'range': [0, 100]},
66
- 'bar': {'color': "darkblue"},
67
- 'steps': [
68
- {'range': [0, 50], 'color': "red"},
69
- {'range': [50, 100], 'color': "green"}
70
- ],
71
- 'threshold': {
72
- 'line': {'color': "black", 'width': 4},
73
- 'thickness': 0.75,
74
- 'value': confidence
75
- }
76
- }
77
- ))
78
-
79
- st.plotly_chart(fig)
80
-
81
- if confidence >= 50:
82
- st.success(f"The query is likely well-formed with {confidence:.2f}% confidence.")
83
- else:
84
- st.error(f"The query is likely not well-formed with {100 - confidence:.2f}% confidence.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import torch
3
+ from transformers import AlbertTokenizer, AlbertForSequenceClassification
4
  import plotly.graph_objects as go
5
 
6
  # URL of the logo
 
11
 
12
  # Streamlit app title and description
13
  st.title("Search Query Form Classifier")
14
+ st.write(
15
+ "Ambiguous search queries are candidates for query expansion. Our model identifies such queries with an 80 percent accuracy and is deployed in a batch processing pipeline directly connected with Google Search Console API. In this demo you can test the model capability by testing individual queries."
16
+ )
17
  st.write("Enter a query to check if it's well-formed:")
18
 
19
  # Load the model and tokenizer from the Hugging Face Model Hub
 
21
  tokenizer = AlbertTokenizer.from_pretrained(model_name)
22
  model = AlbertForSequenceClassification.from_pretrained(model_name)
23
 
24
+ # Set the model to evaluation mode
25
  model.eval()
26
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27
  model.to(device)
28
 
29
+ # Create tabs for single and bulk queries
30
+ tab1, tab2 = st.tabs(["Single Query", "Bulk Query"])
31
+
32
+ with tab1:
33
+ user_input = st.text_input("Query:", "What is?")
34
+ #st.write("Developed by [Dejan AI](https://dejan.ai/blog/search-query-quality-classifier/)")
35
+
36
def classify_query(query):
    """Score how strongly the classifier rates ``query`` as well-formed.

    The query is tokenized to a fixed length of 32 tokens (padded or
    truncated), passed through the module-level ``model`` on ``device``
    without gradient tracking, and the logits are soft-maxed.

    Args:
        query: The search query text to score.

    Returns:
        The softmax probability at class index 1 (used by this app as the
        "well-formed" class), multiplied by 100 to give a percentage.
    """
    # Call the tokenizer directly: ``encode_plus`` is deprecated in favour of
    # ``__call__``, which accepts the same keyword arguments.
    inputs = tokenizer(
        query,
        add_special_tokens=True,
        max_length=32,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    # Move the encoded tensors to wherever the model lives.
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        # Single-item batch: take row 0 of the (batch, classes) probabilities.
        softmax_scores = torch.softmax(logits, dim=1).cpu().numpy()[0]
        confidence = softmax_scores[1] * 100  # Confidence for well-formed class

    return confidence
59
+
60
+ # Function to determine color based on confidence
61
def get_color(confidence, threshold=50):
    """Return the bar fill color for a well-formedness score.

    Args:
        confidence: Well-formedness score to colorize.
        threshold: Cutoff between the two colors. Scores strictly below it
            are red; scores at or above it are green. Defaults to 50.

    Returns:
        An ``rgba(...)`` color string.
    """
    if confidence < threshold:
        return 'rgba(255, 51, 0, 0.8)'  # Red
    return 'rgba(57, 172, 57, 0.8)'  # Green
66
+
67
+ # Check and display classification for single query
68
if user_input:
    confidence = classify_query(user_input)

    # Render the score as a horizontal bar: a grey trace spanning the full
    # 0-100 range acts as the track, and a colored trace shows the score.
    fig = go.Figure()

    # Placeholder grey bar
    fig.add_trace(go.Bar(
        x=[100],
        y=['Well-formedness Factor'],
        orientation='h',
        marker=dict(
            color='lightgrey'
        ),
        width=0.8
    ))

    # Colored bar based on confidence
    fig.add_trace(go.Bar(
        x=[confidence],
        y=['Well-formedness Factor'],
        orientation='h',
        marker=dict(
            color=get_color(confidence)
        ),
        width=0.8
    ))

    fig.update_layout(
        xaxis=dict(range=[0, 100], title='Well-formedness Factor'),
        yaxis=dict(showticklabels=False),
        width=600,
        height=250,  # Increase height for better visibility
        title_text='Well-formedness Factor',
        plot_bgcolor='rgba(0,0,0,0)',
        showlegend=False
    )

    st.plotly_chart(fig)

    # Verdict line. Both branches report the same quantity (the
    # well-formedness score drawn in the chart) so the number in the message
    # always matches the bar.
    if confidence >= 50:
        st.success(f"Query Score: {confidence:.2f}% Most likely doesn't require query expansion.")
    else:
        st.error(f"The query is likely not well-formed with a score of {confidence:.2f}% and most likely requires query expansion.")

    # The call to action is the same for either verdict, so emit it once.
    st.subheader(":sparkles: What's next?", divider="gray")
    st.write("Connect with Google Search Console, Semrush, Ahrefs or any other search query source API and detect all queries which could benefit from expansion.")
    st.write("[Engage our team](https://dejan.ai/call/) if you'd like us to do this for you.")
118
+
119
with tab2:
    st.write("Paste multiple queries line-separated (no headers or extra data):")
    bulk_input = st.text_area("Bulk Queries:", height=200)

    if bulk_input:
        # One query per line. Blank and whitespace-only lines (e.g. a trailing
        # newline from a paste) are dropped so they are not sent to the
        # classifier or rendered as empty results.
        bulk_queries = [
            line.strip() for line in bulk_input.splitlines() if line.strip()
        ]
        st.write("Processing queries...")

        # Classify each query in bulk input
        results = [(query, classify_query(query)) for query in bulk_queries]

        # Display one score line plus a verdict box per query
        for query, confidence in results:
            st.write(f"Query: {query} - Score: {confidence:.2f}%")
            if confidence >= 50:
                st.success("Well-formed")
            else:
                st.error("Not well-formed")

        # NOTE(review): the original nesting of this call-to-action section
        # could not be confirmed; it is shown only after results here, which
        # mirrors the single-query tab. Confirm against the real file.
        st.subheader(":sparkles: What's next?", divider="gray")
        st.write("Connect with Google Search Console, Semrush, Ahrefs or any other search query source API and detect all queries which could benefit from expansion.")
        st.write("[Engage our team](https://dejan.ai/call/) if you'd like us to do this for you.")
+ st.write("[Engage our team](https://dejan.ai/call/) if you'd like us to do this for you.")