Ransaka commited on
Commit
cb4cfa0
·
1 Parent(s): f601486

Added app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Streamlit app for zero-shot classification."""
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import altair as alt
5
+ from transformers import pipeline
6
+ from transformers import AutoTokenizer
7
+ import warnings
8
+ warnings.filterwarnings('ignore')
9
+
10
+ # set up altair theme
11
+ font = 'NotoSansSinhala.ttf'
12
+ font_color = '#858991'
13
+ font_title = '#858991'
14
+ font_axis = '#858991'
15
+
16
+ TARGETS = ['දේශපාලන', 'ආර්ථික', 'සෞඛ්ය', 'අපරාධ', 'තාක්ෂණ', 'ක්රීඩා', 'විනෝද', 'සමාජ']
17
+ SIN_2_ENG = {
18
+ 'දේශපාලන':'Political',
19
+ 'ආර්ථික':'Economic',
20
+ 'සෞඛ්ය':'Health',
21
+ 'අපරාධ':'Crime',
22
+ 'තාක්ෂණ':'Technology',
23
+ 'ක්රීඩා':'Sports',
24
+ 'විනෝද':'Entertainment',
25
+ 'සමාජ':'Social'
26
+ }
27
+
28
+ st.set_page_config(page_title="Sinhala zero-shot classification demo", page_icon=":bar_chart:")
29
+ st.title("Sinhala zero-shot classification demo")
30
+ st.markdown("This is a demo of the zero-shot classification pipeline from the [HuggingFace Transformers library](https://huggingface.co/transformers/).")
31
+ st.markdown("The model used is [Ransaka/sinhala-bert-small](https://huggingface.co/Ransaka/sinhala-bert-small). However, you can use any model from the [HuggingFace model hub](https://huggingface.co/models).")
32
+
33
+ # select model
34
+ def get_model_id():
35
+ st.subheader("Select a model to use")
36
+ model_list = ["Ransaka/sinhala-bert-small","Ransaka/SinhalaRoberta","keshan/SinhalaBERTo"]
37
+ selected_model = st.selectbox("Select Model", model_list)
38
+ st.write(f"Selected model: {selected_model}")
39
+ tokenizer = AutoTokenizer.from_pretrained(selected_model)
40
+ mask_token = tokenizer.mask_token
41
+ return selected_model,mask_token
42
+
43
+ # get input text
44
+ def get_input_text():
45
+ st.subheader("Input a sentence to classify")
46
+ st.write("Remember: Longer sentences may produce better results and take longer to classify😊")
47
+ sentence = st.text_area("Input text", height=300)
48
+ return sentence
49
+
50
+ def show_example():
51
+ examples = [
52
+ """ශ්‍රී ලංකාවේ චීන සංස්කෘතික මධ්‍යස්ථානය සහ නැන්ජින් සංචාරක හා සංස්කෘතික මණ්ඩලය විසින් “ගිම්හාන දිනය” සැමරීම සඳහා පවත්වන ලද සංස්කෘතික උත්සව මාලාවක් පසුගියදා කොළඹ සහ මහනුවර නගරවලදී පැවත්විණි. “ගිම්හාන දිනය” යනු චීන සංස්කෘතිය තුළ “චීන නව වසර” තරමටම වැදගත් සහ ඉතා ඉහළින් සමරනු ලබන වැදගත් දිනයකි. මෙම උත්සව මාලාව සැප්තැම්බර් 22 වැනි දින සිට 25 වැනිදා දක්වා පැවත්විණි.
53
+ කොළඹ චීන සංස්කෘතික මධ්‍යස්ථානයේදී පැවත්වුනු ප්‍රධාන උත්සවය සාම්ප්‍රදායයික චීන සහ ශ්‍රී ලාංකික සංස්කෘතික සංදර්ශන සහ කලා සහ ඡායාරූප ප්‍රදර්ශන, සාම්ප්‍රදායයික චීන තේ පානෝත්සව සමඟින් ඉතා වර්ණවත් අයුරින් පැවත්විණි. එහිදී චීන සංස්කෘතිය තුළ ‘ගිම්හාන දිනයේ’ ඇති වැදගත්කම සහ ඓතිහාසික චීන ශ්‍රී ලාංකික සබඳතාවයන් ගැන හරබර දේශන රැසක්ද ප්‍රකට කථිකයින් විසින් සිදු කරන ලදි.
54
+ """,
55
+ """මාලදිවයිනේ පැවති ජනාධිපතිවරණයෙන් චීන හිතවාදී අපේක්ෂක 45 හැවිරිදි මොහොමඩ් මුයිසු ජනාධිපති ධුරයට පත් වී තිබේ.
56
+
57
+ ඉන්දියාව සමඟ සබඳතා ශක්තිමත් කළ වත්මන් ජනාධිපති ඊබ්‍රාහිම් මොහොමඩ් සෝලිහ් පරාජයට පත් කරමින් මොහොමඩ් මුයිසු ජනාධිපතිවරණය ජයග්‍රහණය කර ඇත.
58
+
59
+ මොහොමඩ් මුයිසු 54%ක ඡන්ද ප්‍රතිශතයකින් ජනාධිපතිවරණය ජයග්‍රහණය කර තිබේ.
60
+
61
+ 'ඉන්දියාව ඉවතට' යන සටන් පාඨය ඔස්සේ මොහොමඩ් මුයිසු සිය ජනාධිපතිවරණ ව්‍යාපාරය සිදු කළේය.
62
+ """,
63
+ """ආසියානු ක්‍රීඩා උළෙලේ කාන්තා ක්‍රිකට් අවසාන තරගයේ කාසියේ වාසිය දිනාගැනීමට ඉන්දීය නායිකාව සමත්වුණි.
64
+
65
+ ඒ අනුව ඇය පළමුවෙන් පන්දුවට පහරදීමට තීරණය කළාය.
66
+
67
+ තරගය මෙරට වේලාවෙන් පෙරවරු 11.30ට චීනයේ හැන්ග්ෂු හිදී ආරම්භ වීමට නියමිතය.
68
+
69
+ ඊයේ (24) පැවති දෙවන අවසන් පූර්ව තරගයෙන් පාකිස්තාන කණ්ඩායම පරදා කඩුලු 06ක ජයක් හිමිකරගනිමින් රන් පදක්කම සඳහා වූ අවසන් තරගයට සුදුසුකම් ලබාගැනීමට ශ්‍රී ලංකා කාන්තා කණ්ඩායම සමත්වුණි.
70
+
71
+ අද පැවැත්වෙන තරගයෙන් ජයගතහොත් ශ්‍රී ලංකා කණ්ඩායමට රන් පදක්කම හිමිවන අතර පරාජය වුවහොත් තරග ඉසව්වේ රිදී පදක්කම හිමි වේ.
72
+
73
+ ඒ අනුව 2014 වසරට පසුව එනම් වසර 9කට පසුව ආසියානු ක්‍රීඩා උළෙලකදී ශ්‍රී ලංකාවට පදක්කමක් හිමිවීමට නියමිතය."""
74
+ ]
75
+ st.subheader("Examples")
76
+ st.table(pd.DataFrame(examples, columns=['Example']))
77
+
78
+ # get prompt
79
+ def get_prompt(mask_token):
80
+ st.subheader("Input a prompt")
81
+ # user can toggle between default prompt and custom prompt
82
+ default_prompt = st.checkbox("Use default prompt",value=True)
83
+ if default_prompt:
84
+ prompt = f"මෙය {mask_token} ඝණයේ තොරතුරක්."
85
+ else:
86
+ prompt = st.text_input("Prompt", f"මෙය {mask_token} ඝණයේ තොරතුරක්.")
87
+ return prompt
88
+
89
+ if __name__ == "__main__":
90
+ model_id,mask_token = get_model_id()
91
+ if st.checkbox("Show example"):
92
+ show_example()
93
+ sentence = get_input_text()
94
+ # submit button
95
+ if sentence:
96
+ prompt = get_prompt(mask_token)
97
+ if prompt and st.button("Classify"):
98
+ pipe = pipeline("fill-mask", model=model_id)
99
+ output = pipe(sentence + prompt, targets=TARGETS, top_k =len(TARGETS))
100
+ output = pd.DataFrame(output)
101
+ output['score'] = output['score'].apply(lambda x:x/sum(output['score']))
102
+ output.rename(columns={'token_str':'label'}, inplace=True)
103
+ # plot altair bar chart
104
+ bar_chart = alt.Chart(output).mark_bar().encode(
105
+ x='label:N',
106
+ y='score:Q',
107
+ # increase blue gradient as score increases
108
+ color=alt.Color('score:Q', scale=alt.Scale(scheme='blues')),
109
+ tooltip=['label:N', 'score:Q']
110
+ ).properties(
111
+ title='Zeroshot Classification Results',
112
+ width=800,
113
+ height=400
114
+ )
115
+ bar_chart.configure_axis(grid=False, labelFont=font, labelColor=font_color, titleColor= font_title).configure_view(strokeOpacity=0)
116
+ bar_chart.configure_title(anchor='start')
117
+ predicted_class = output.loc[output['score'].idxmax()]['label']
118
+ predicted_class_en = SIN_2_ENG[predicted_class]
119
+ st.altair_chart(bar_chart, use_container_width=True)
120
+ st.markdown(
121
+ "It seems this sentence belongs to the :green[{}]({}) category.".format(predicted_class,predicted_class_en)
122
+ )
123
+ # st.markdown(
124
+ # "This demo was created by [Ransaka Ravihara](https://www.linkedin.com/in/ransaka/)."
125
+ # )