SJTUCL committed on
Commit
9c53030
·
1 Parent(s): 1dec133

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -38
app.py CHANGED
@@ -6,56 +6,90 @@ import gradio as gr
6
 
7
  from nltk import sent_tokenize
8
  from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- detector = pipeline(task='text-classification', model='yikang0131/argugpt-detector-sent')
11
 
12
  def predict_doc(doc):
13
  sents = sent_tokenize(doc)
14
  data = {'sentence': [], 'label': [], 'score': []}
15
  res = []
16
  for sent in sents:
17
- label, score = predict_one_sent(sent)
18
  data['sentence'].append(sent)
19
- data['score'].append(score)
20
- if label == 'LABEL_0':
21
- res.append((sent, 'Human'))
 
 
 
 
 
 
 
 
 
 
 
22
  data['label'].append('Human')
23
- else:
24
- res.append((sent, 'Machine'))
25
- data['label'].append('Machine')
26
  df = pd.DataFrame(data)
27
  df.to_csv('result.csv')
28
- return res, df, 'result.csv'
 
 
 
 
 
 
29
 
30
 
31
  def predict_one_sent(sent):
 
 
 
 
 
32
  res = detector(sent)[0]
33
- return res['label'], res['score']
34
-
35
-
36
- iface = gr.Interface(
37
- fn=predict_doc,
38
- inputs=[
39
- gr.Textbox(
40
- label='Essay input',
41
- info="Please enter essay in the textbox",
42
- lines=5
43
- )
44
- ],
45
- outputs=[
46
- gr.HighlightedText(
47
- label='Labeled Result',
48
- show_legend=True
49
- ).style(color_map={'Machine': 'red', 'Human': 'green'}),
50
- gr.DataFrame(
51
- label='Table with Probability Score',
52
- max_rows=10
53
- ),
54
- gr.File(
55
- label='CSV file storing data with all sentences'
56
- )
57
- ],
58
- theme=gr.themes.Base()
59
- )
60
-
61
- iface.launch()
 
6
 
7
  from nltk import sent_tokenize
8
  from transformers import pipeline
9
+ from gradio.themes.utils import red, green
10
+
11
# Sentence-level classifier used by predict_one_sent.
detector = pipeline(
    task='text-classification',
    model='SJTU-CL/RoBERTa-large-ArguGPT-sent',
)

# Highlight colour for each probability bucket label ('0%', '10%', ..., '100%'):
# green shades below 50%, no colour at 50%, red shades above.
color_map = dict(zip(
    (f'{pct}%' for pct in range(0, 101, 10)),
    (
        green.c400, green.c300, green.c200, green.c100, green.c50,
        None,
        red.c50, red.c100, red.c200, red.c300, red.c400,
    ),
))
26
 
 
27
 
28
  def predict_doc(doc):
29
  sents = sent_tokenize(doc)
30
  data = {'sentence': [], 'label': [], 'score': []}
31
  res = []
32
  for sent in sents:
33
+ prob = predict_one_sent(sent)
34
  data['sentence'].append(sent)
35
+ data['score'].append(round(prob, 4))
36
+ if prob < 0.1: label = '0%'
37
+ elif prob < 0.2: label = '10%'
38
+ elif prob < 0.3: label = '20%',
39
+ elif prob < 0.4: label = '30%',
40
+ elif prob < 0.5: label = '40%',
41
+ elif prob < 0.6: label = '50%',
42
+ elif prob < 0.7: label = '60%',
43
+ elif prob < 0.8: label = '70%',
44
+ elif prob < 0.9: label = '80%',
45
+ elif prob < 1: label = '90%'
46
+ else: label = '100%'
47
+ res.append((sent, label))
48
+ if prob <= 0.5:
49
  data['label'].append('Human')
50
+ else: data['label'].append('Machine')
 
 
51
  df = pd.DataFrame(data)
52
  df.to_csv('result.csv')
53
+ overall_score = df.score.mean()
54
+ sum_str = ''
55
+ if overall_score <= 0.5: overall_label = 'Human'
56
+ else: overall_label = 'Machine'
57
+ sum_str = f'The essay is probably written by {overall_label}. The probability of being generated by AI is {overall_score}'
58
+ # print(sum_str)
59
+ return sum_str, res, df, 'result.csv'
60
 
61
 
62
  def predict_one_sent(sent):
63
+ '''
64
+ convert to prob
65
+ LABEL_1, 0.66 -> 0.66
66
+ LABEL_0, 0.66 -> 0.34
67
+ '''
68
  res = detector(sent)[0]
69
+ org_label, prob = res['label'], res['score']
70
+ if org_label == 'LABEL_0': prob = 1 - prob
71
+ return prob
72
+
73
+
74
with gr.Blocks() as demo:
    # Input area and trigger.
    text_in = gr.Textbox(
        label='Essay input',
        info='Please enter the essay in the textbox',
        lines=5,
    )
    button = gr.Button('Predict who writes this essay!')

    # Result widgets, one per value returned by predict_doc.
    summary = gr.Textbox(label='Result summary', lines=1)
    sent_res = gr.HighlightedText(label='Labeled Result').style(color_map=color_map)
    tab = gr.DataFrame(max_rows=10, label='Table with Probability Score')
    csv_f = gr.File(label='CSV file storing data with all sentences.')

    # predict_doc returns (summary, highlights, table, csv path) in this order.
    button.click(
        fn=predict_doc,
        inputs=[text_in],
        outputs=[summary, sent_res, tab, csv_f],
        api_name='predict_doc',
    )

demo.launch()
95
+