emrecan commited on
Commit
7ae2fd5
·
1 Parent(s): 7e9d867

create a draft of the app

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import streamlit as st
3
+ import plotly.express as px
4
+ from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
5
+
6
+ st.title("Zero-shot Turkish Text Classification")
7
+
8
+ method_selection = st.radio(
9
+ "Select a zero-shot classification method.",
10
+ [
11
+ METHOD_OPTIONS["nli"],
12
+ METHOD_OPTIONS["nsp"],
13
+ ],
14
+ )
15
+
16
+ if method_selection == METHOD_OPTIONS["nli"]:
17
+ model = st.selectbox(
18
+ "Select a natural language inference model.", NLI_MODEL_OPTIONS
19
+ )
20
+ if method_selection == METHOD_OPTIONS["nsp"]:
21
+ model = st.selectbox(
22
+ "Select a BERT model for next sentence prediction.", NSP_MODEL_OPTIONS
23
+ )
24
+
25
+ st.header("Configure prompts and labels")
26
+ col1, col2 = st.columns(2)
27
+ col1.subheader("Candidate labels")
28
+ labels = col1.text_area(
29
+ label="These are the labels that the model will try to predict for the given text input. Your input labels should be comma separated and meaningful.",
30
+ value="spor,dünya,siyaset,ekonomi,kültür ve sanat",
31
+ height=10,
32
+ placeholder="Enter a set of comma separated labels. (eg. spor,dünya,siyaset,ekonomi,kültür ve sanat)",
33
+ )
34
+ col2.subheader("Prompt template")
35
+ prompt_template = col2.text_area(
36
+ label="Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given.",
37
+ value="Bu metin {} kategorisine aittir",
38
+ height=10,
39
+ )
40
+
41
+ col1.header("Make predictions")
42
+ col2.header("")
43
+ col1.text_area("", value="", placeholder="Enter some text to classify.")
44
+ col1.button("Predict")
45
+
46
+ probs = [0.86, 0.10, 0.01, 0.02, 0.01]
47
+ data = pd.DataFrame({"labels": labels.split(","), "probability": probs}).sort_values(
48
+ by="probability", ascending=False
49
+ )
50
+ chart = px.bar(
51
+ data,
52
+ x="probability",
53
+ y="labels",
54
+ color="labels",
55
+ orientation="h",
56
+ height=290,
57
+ width=500,
58
+ ).update_layout(
59
+ {
60
+ "xaxis": {"title": "probability", "visible": True, "showticklabels": True},
61
+ "yaxis": {"title": None, "visible": True, "showticklabels": True},
62
+ "margin": dict(
63
+ l=10, # left
64
+ r=10, # right
65
+ t=50, # top
66
+ b=10, # bottom
67
+ ),
68
+ "showlegend": False,
69
+ }
70
+ )
71
+ col2.plotly_chart(chart)