jer233 committed
Commit a236161 · verified · 1 Parent(s): b290274

Update demo.py

Files changed (1)
  1. demo.py +3 -70
demo.py CHANGED
@@ -1,56 +1,9 @@
  import gradio as gr
- from transformers import AutoTokenizer, AutoModel
- from utils_MMD import extract_features # Adjust the import path
- from MMD_calculate import mmd_two_sample_baseline # Adjust the import path
-
- MINIMUM_TOKENS = 64
- THRESHOLD = 0.5 # Threshold for classification
-
- def count_tokens(text, tokenizer):
-     """
-     Counts the number of tokens in the text using the provided tokenizer.
-     """
-     return len(tokenizer(text).input_ids)
 
+ # Mock function for testing layout
  def run_test_power(model_name, real_text, generated_text, N=10):
-     """
-     Runs the test power calculation for provided real and generated texts.
-
-     Args:
-         model_name (str): Hugging Face model name.
-         real_text (str): Example real text for comparison.
-         generated_text (str): The input text to classify.
-         N (int): Number of repetitions for MMD calculation.
-
-     Returns:
-         str: "Prediction: Human" or "Prediction: AI".
-     """
-     # Load tokenizer and model
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     model = AutoModel.from_pretrained(model_name).cuda()
-     model.eval()
-
-     # Ensure minimum token length
-     if count_tokens(real_text, tokenizer) < MINIMUM_TOKENS or count_tokens(generated_text, tokenizer) < MINIMUM_TOKENS:
-         return "Too short length. Need a minimum of 64 tokens to calculate Test Power."
-
-     # Extract features
-     fea_real_ls = extract_features([real_text], tokenizer, model)
-     fea_generated_ls = extract_features([generated_text], tokenizer, model)
-
-     # Calculate test power list
-     test_power_ls = mmd_two_sample_baseline(fea_real_ls, fea_generated_ls, N=N)
-
-     # Compute the average test power value
-     power_test_value = sum(test_power_ls) / len(test_power_ls)
-
-     # Classify the text
-     if power_test_value < THRESHOLD:
-         return "Prediction: Human"
-     else:
-         return "Prediction: AI"
+     return "Prediction: Human (Mocked)"
 
- # CSS for custom styling
  css = """
  #header { text-align: center; font-size: 1.5em; margin-bottom: 20px; }
  #output-text { font-weight: bold; font-size: 1.2em; }
@@ -95,28 +48,8 @@ with gr.Blocks(css=css) as app:
          placeholder="Prediction: Human or AI",
          elem_id="output-text",
      )
-     with gr.Accordion("Disclaimer", open=False):
-         gr.Markdown(
-             """
-             - **Disclaimer**: This tool is for demonstration purposes only. It is not a foolproof AI detector.
-             - **Accuracy**: Results may vary based on input length and quality.
-             """
-         )
-     with gr.Accordion("Citations", open=False):
-         gr.Markdown(
-             """
-             ```
-             @inproceedings{zhangs2024MMDMP,
-             title={Detecting Machine-Generated Texts by Multi-Population Aware Optimization for Maximum Mean Discrepancy},
-             author={Zhang, Shuhai and Song, Yiliao and Yang, Jiahao and Li, Yuanqing and Han, Bo and Tan, Mingkui},
-             booktitle = {International Conference on Learning Representations (ICLR)},
-             year={2024}
-             }
-             ```
-             """
-         )
      submit_button.click(
-         run_test_power, inputs=[model_name, "The cat sat on the mat.", input_text], outputs=output
+         run_test_power, inputs=[model_name, "Example real text", input_text], outputs=output
      )
      clear_button.click(lambda: ("", ""), inputs=[], outputs=[input_text, output])
 
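For context, the deleted path depended on `extract_features` (from `utils_MMD`) and `mmd_two_sample_baseline` (from `MMD_calculate`), neither of which appears in this diff. The sketch below only illustrates the kind of statistic such a helper is built on: an unbiased RBF-kernel estimate of squared MMD between two batches of features. The function names, fixed bandwidth, and use of NumPy are assumptions, not the repository's actual implementation, which additionally estimates test power over `N` repetitions.

```python
# Hypothetical illustration; utils_MMD / MMD_calculate are not shown in this diff.
import numpy as np


def rbf_kernel(X, Y, bandwidth=1.0):
    """Gaussian (RBF) kernel matrix between the rows of X and the rows of Y."""
    sq_dists = (
        np.sum(X ** 2, axis=1)[:, None]
        + np.sum(Y ** 2, axis=1)[None, :]
        - 2.0 * X @ Y.T
    )
    return np.exp(-sq_dists / (2.0 * bandwidth ** 2))


def mmd2_unbiased(feats_real, feats_generated, bandwidth=1.0):
    """Unbiased estimate of squared MMD between two feature samples (>= 2 rows each)."""
    X = np.asarray(feats_real, dtype=float)
    Y = np.asarray(feats_generated, dtype=float)
    m, n = len(X), len(Y)
    k_xx = rbf_kernel(X, X, bandwidth)
    k_yy = rbf_kernel(Y, Y, bandwidth)
    k_xy = rbf_kernel(X, Y, bandwidth)
    # Drop the diagonal terms so the within-sample averages are unbiased.
    term_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
    term_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
    term_xy = 2.0 * k_xy.mean()
    return term_xx + term_yy - term_xy
```

In the removed demo, a statistic of this kind was evaluated `N` times to build a test-power list, and the average was compared against `THRESHOLD = 0.5` to label the input as Human or AI.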
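Putting the pieces that survive this commit together, a self-contained version of the mocked demo might look like the sketch below. Only the mock `run_test_power`, the `css` string, and the component names visible in the hunks (`model_name`, `input_text`, `output`, `submit_button`, `clear_button`) come from the diff; the widget types and labels are assumptions, and a `gr.State` carries the fixed reference text here, since the diff passes a bare string inside `inputs`, which Gradio expects to be components.

```python
# Sketch only: widgets and labels not shown in the diff are assumptions.
import gradio as gr


def run_test_power(model_name, real_text, generated_text, N=10):
    # Mocked prediction so the layout can be exercised without loading a model.
    return "Prediction: Human (Mocked)"


css = """
#header { text-align: center; font-size: 1.5em; margin-bottom: 20px; }
#output-text { font-weight: bold; font-size: 1.2em; }
"""

with gr.Blocks(css=css) as app:
    gr.Markdown("AI-generated text detector (mocked)", elem_id="header")
    model_name = gr.Textbox(label="Model name", value="gpt2")   # assumed widget
    input_text = gr.Textbox(label="Text to classify", lines=8)  # assumed widget
    output = gr.Textbox(
        label="Result",
        placeholder="Prediction: Human or AI",
        elem_id="output-text",
    )
    submit_button = gr.Button("Submit")  # assumed label
    clear_button = gr.Button("Clear")    # assumed label

    # The diff puts the reference text directly in `inputs`; a State holds it here.
    real_text = gr.State("Example real text")

    submit_button.click(
        run_test_power, inputs=[model_name, real_text, input_text], outputs=output
    )
    clear_button.click(lambda: ("", ""), inputs=[], outputs=[input_text, output])

if __name__ == "__main__":
    app.launch()
```

Running `python demo.py` then exercises the layout end to end; the mock always returns the same string, so no GPU or model download is needed.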