Spaces:
Runtime error
Runtime error
shubh2014shiv
committed on
Commit
•
18f8de6
1
Parent(s):
62cffe7
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import plotly.express as px
|
4 |
+
import plotly.graph_objects as go
|
5 |
+
from st_aggrid import AgGrid
|
6 |
+
from st_aggrid.grid_options_builder import GridOptionsBuilder
|
7 |
+
from st_aggrid.shared import JsCode
|
8 |
+
from st_aggrid.shared import GridUpdateMode
|
9 |
+
from transformers import T5Tokenizer, BertForSequenceClassification
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
# Page chrome: wide layout, main title, and the sidebar topic selector.
st.set_page_config(layout="wide")
st.title(
    "Project - Japanese Natural Language Processing (自然言語処理) using Transformers"
)
st.sidebar.subheader("自然言語処理 トピック")
topic = st.sidebar.radio(
    label="Select the NLP project topics",
    options=["Sentiment Analysis"],
)

st.write("-----")

# Review text to analyse; stays None until the user selects or types a review.
jp_review_text = None
# JAPANESE_SENTIMENT_PROJECT_PATH = './Japanese Amazon reviews sentiments/'
|
21 |
+
|
22 |
+
def _render_heading(text, *, level=3, align="center", color="#F63366", size_px=18):
    """Render ``text`` as a bold, coloured HTML heading through ``st.markdown``.

    Args:
        text: Heading text, inserted verbatim into the markup.
        level: HTML heading level (``2`` emits ``<h2>``, ``3`` emits ``<h3>``).
        align: CSS ``text-align`` value.
        color: CSS colour of the heading text.
        size_px: Font size in pixels.
    """
    st.markdown(
        f"<h{level} style='text-align: {align}; color:{color}; font-size:{size_px}px;'>"
        f"<b>{text}</b></h{level}>",
        unsafe_allow_html=True,
    )


if topic == "Sentiment Analysis":
    _render_heading(
        "Transfer Learning based Japanese Sentiments Analysis using BERT",
        level=2,
        align="left",
        color="#EE82EE",
        size_px=25,
    )
    _render_heading("Japanese Amazon Reviews Data (日本のAmazonレビューデータ)")

    # Shuffle with a fixed seed and keep at most the first 16000 rows for the grid.
    amazon_jp_reviews = (
        pd.read_csv("review_val.csv").sample(frac=1, random_state=10).iloc[:16000]
    )

    # Cell styler handed to the grid: green background when the cell value
    # contains 'positive', light red otherwise.
    cellstyle_jscode = JsCode(
        """
        function(params) {
            if (params.value.includes('positive')) {
                return {
                    'color': 'black',
                    'backgroundColor': '#32CD32'
                }
            } else {
                return {
                    'color': 'black',
                    'backgroundColor': '#FF7F7F'
                }
            }
        };
        """
    )

    # CSS overrides injected into the page for the radio widget layout.
    st.write('<style>div.row-widget.stRadio > div{flex-direction:row;justify-content: center;} </style>',
             unsafe_allow_html=True)
    st.write('<style>div.st-bf{flex-direction:column;} div.st-ag{font-weight:bold;padding-left:2px;}</style>',
             unsafe_allow_html=True)

    # Define the option labels once so the widget and the branches below
    # cannot drift apart.
    SELECT_ONE_REVIEW = "Choose a review from the dataframe below"
    WRITE_REVIEW = "Manually write review"
    choose = st.radio("", (SELECT_ONE_REVIEW, WRITE_REVIEW))

    gb = GridOptionsBuilder.from_dataframe(amazon_jp_reviews)
    gb.configure_column("sentiment", cellStyle=cellstyle_jscode)
    gb.configure_pagination()
    if choose == SELECT_ONE_REVIEW:
        # Row selection is only needed when the review is picked from the grid.
        gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
    gridOptions = gb.build()

    if choose == SELECT_ONE_REVIEW:
        jp_review_choice = AgGrid(amazon_jp_reviews, gridOptions=gridOptions, theme='material',
                                  enable_enterprise_modules=True,
                                  allow_unsafe_jscode=True,
                                  update_mode=GridUpdateMode.SELECTION_CHANGED)
        st.info("Select any one the Japanese Reviews by clicking the checkbox. Reviews can be navigated from each page.")
        selected_rows = jp_review_choice['selected_rows']
        # Guard against "nothing selected" being reported as None or as an
        # empty collection.
        if selected_rows is not None and len(selected_rows) != 0:
            jp_review_text = selected_rows[0]['review']
            _render_heading("Selected Review in JSON (JSONで選択されたレビュー)")
            st.write(selected_rows)

    if choose == WRITE_REVIEW:
        AgGrid(amazon_jp_reviews, gridOptions=gridOptions, theme='material',
               enable_enterprise_modules=True,
               allow_unsafe_jscode=True)
        with open("test_reviews_jp.csv", "rb") as file:
            st.download_button(label="Download Additional Japanese Reviews", data=file,
                               file_name="Additional Japanese Reviews.csv")
        st.info("Additional subset of Japanese Reviews can be downloaded and any review can be copied & pasted in text area.")
        sample_japanese_review_input = "子供のレッスンバッグ用に購入。 思ったより大きく、ピアノ教本を入れるには充分でした。中は汚れてました。 何より驚いたのは、商品の梱包。 2つ折は許せるが、透明ビニール袋の底思いっきり空いてますけど? 何これ?包むっていうか挟んで終わり?底が全開している。 引っ張れば誰でも中身の注文書も、商品も見れる状態って何なの? 個人情報が晒されて、商品も粗末な扱いで嫌な気持ちでした。 郵送で中身が無事のが奇跡じゃないでしょうか? ありえない"
        jp_review_text = st.text_area(label="Press 'Ctrl+Enter' after writing review in below text area",
                                      value=sample_japanese_review_input)
        if not jp_review_text:
            st.error("Input text cannot empty. Either write the japanese review in text area manually or select the review from the grid.")

    if jp_review_text:
        _render_heading("Sentence-Piece based Japanese Tokenizer using RoBERTA")
        tokens_column, tokenID_column = st.columns(2)
        # NOTE(review): the tokenizer is loaded again each time this branch
        # executes; consider caching it if load time becomes a problem.
        tokenizer = T5Tokenizer.from_pretrained('rinna/japanese-roberta-base')
        tokens = tokenizer.tokenize(jp_review_text)
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        with tokens_column:
            with st.expander("Expand to see the tokens", expanded=False):
                st.write(tokens)
        with tokenID_column:
            with st.expander("Expand to see the token IDs", expanded=False):
                st.write(token_ids)

        _render_heading(
            "Encoded Japanese Review Text to get Input IDs and attention masks as PyTorch Tensor"
        )
        # Encode the single review as a batch of one, truncated to 200 tokens.
        encoded_data = tokenizer.batch_encode_plus([jp_review_text],
                                                   add_special_tokens=True,
                                                   return_attention_mask=True,
                                                   padding=True,
                                                   max_length=200,
                                                   return_tensors='pt',
                                                   truncation=True)
        input_ids = encoded_data['input_ids']
        attention_masks = encoded_data['attention_mask']
        input_ids_column, attention_masks_column = st.columns(2)
        with input_ids_column:
            with st.expander("Expand to see the input IDs tensor"):
                st.write(input_ids)
        with attention_masks_column:
            with st.expander("Expand to see the attention mask tensor"):
                st.write(attention_masks)

        _render_heading("Predict Sentiment of review using Fine-Tuned Japanese BERT")

        # Class name -> index of that class in the model's logits.
        label_dict = {'positive': 1, 'negative': 0}
        if st.button("Predict Sentiment"):
            with st.spinner("Wait.."):
                model = BertForSequenceClassification.from_pretrained("shubh2014shiv/jp_review_sentiments_amzn",
                                                                      num_labels=len(label_dict),
                                                                      output_attentions=False,
                                                                      output_hidden_states=False)
                # Disabled alternative (local checkpoint instead of the Hub weights):
                # model.load_state_dict(
                #     torch.load(JAPANESE_SENTIMENT_PROJECT_PATH + 'FineTuneJapaneseBert_AmazonReviewSentiments.pt',
                #                map_location=torch.device('cpu')))

                with torch.no_grad():
                    outputs = model(input_ids=input_ids, attention_mask=attention_masks)

                # Softmax over the class axis so the two scores form a proper
                # probability distribution (they sum to 1); an element-wise
                # sigmoid would score each class independently.
                probabilities = torch.softmax(outputs.logits, dim=-1)[0].tolist()

                result = {"TEXT": jp_review_text,
                          'NEGATIVE': probabilities[label_dict['negative']],
                          'POSITIVE': probabilities[label_dict['positive']]}

                result_col, graph_col = st.columns(2)
                with result_col:
                    st.write(result)
                with graph_col:
                    fig = px.bar(x=['NEGATIVE', 'POSITIVE'], y=[result['NEGATIVE'], result['POSITIVE']])
                    fig.update_layout(title="Probability distribution of Sentiment for the given text",
                                      yaxis_title="Probability")
                    fig.update_traces(marker_color=['#FF7F7F', '#32CD32'])
                    st.plotly_chart(fig)
|
172 |
+
|