muzamilhxmi commited on
Commit
bf265e0
·
verified ·
1 Parent(s): 45366b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -11
app.py CHANGED
@@ -1,20 +1,35 @@
1
  import streamlit as st
2
- from transformers import pipeline
 
 
 
 
 
 
 
3
 
4
  # Create a pipeline for text classification
5
- pipe = pipeline("text-classification", model="blaikhole/distilbert-review-bug-classifier")
6
 
7
- # Streamlit app structure
8
- st.title("Review bug classification demo.")
9
- st.write("Enter some text and the model will predict bug category:")
10
 
11
- # Input text from the user
12
- user_input = st.text_input("Input Text:")
13
 
14
- # When the button is clicked, classify the input
15
  if st.button("Classify"):
16
  if user_input:
17
- result = pipe(user_input)
18
- st.write(result)
 
 
 
 
 
 
 
19
  else:
20
- st.write("Please enter some text.")
 
 
1
  import streamlit as st
2
+ from transformers import pipeline, AutoConfig
3
+
4
+ # Model Name (Replace with actual model)
5
+ MODEL_NAME = "blaikhole/distilbert-review-bug-classifier"
6
+
7
+ # Load model config to get label mapping
8
+ config = AutoConfig.from_pretrained(MODEL_NAME)
9
+ id2label = config.id2label
10
 
11
  # Create a pipeline for text classification
12
+ pipe = pipeline("text-classification", model=MODEL_NAME)
13
 
14
+ # Streamlit app UI
15
+ st.title("Review Bug Classification Demo 🐞")
16
+ st.write("Enter some text and the model will predict the bug category.")
17
 
18
+ # User Input
19
+ user_input = st.text_area("Input Text:", height=150)
20
 
21
+ # Prediction
22
  if st.button("Classify"):
23
  if user_input:
24
+ result = pipe(user_input, return_all_scores=True)[0] # Get all scores
25
+
26
+ # Convert "LABEL_n" to actual class names
27
+ predictions = {id2label[int(res['label'].replace('LABEL_', ''))]: res['score'] for res in result}
28
+ top_label = max(predictions, key=predictions.get)
29
+
30
+ # Display results
31
+ st.write(f"### 🏆 Predicted Category: `{top_label}`")
32
+ st.json(predictions) # Show confidence scores
33
  else:
34
+ st.warning("⚠️ Please enter some text.")
35
+