"""Gradio demo for R-Detect: classify a text as human-written or AI-generated.

Runs an MMD (Maximum Mean Discrepancy) two-sample test between features of a
reference human text and the input text; a high average test power indicates
the two samples come from different distributions (i.e. the input is AI).
"""

import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer

from utils_MMD import extract_features  # Adjust the import path
from MMD_calculate import mmd_two_sample_baseline  # Adjust the import path

MINIMUM_TOKENS = 64  # Both texts need at least this many tokens for a stable estimate.
THRESHOLD = 0.5  # Test-power values below this are classified as "Human".

# Reference human-written sample compared against the user's input text.
# NOTE(review): this reference is far shorter than MINIMUM_TOKENS, so every
# run will hit the length guard below — supply a >=64-token human reference
# text for the detector to produce real predictions. TODO confirm intent.
HUMAN_REFERENCE_TEXT = "The cat sat on the mat."

# Cache of loaded (tokenizer, model) pairs so repeated detections with the
# same model name do not re-download / re-initialize the weights every click.
_MODEL_CACHE = {}


def _load_model(model_name):
    """Return a cached ``(tokenizer, model)`` pair for *model_name*.

    The model is moved to GPU only when CUDA is available (the original
    unconditional ``.cuda()`` crashed on CPU-only machines) and is put in
    eval mode once at load time.
    """
    if model_name not in _MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = AutoModel.from_pretrained(model_name).to(device)
        model.eval()
        _MODEL_CACHE[model_name] = (tokenizer, model)
    return _MODEL_CACHE[model_name]


def count_tokens(text, tokenizer):
    """Count the number of tokens in *text* using the provided tokenizer."""
    return len(tokenizer(text).input_ids)


def run_test_power(model_name, real_text, generated_text, N=10):
    """Run the MMD test-power calculation for real vs. generated texts.

    Args:
        model_name (str): Hugging Face model name used for feature extraction.
        real_text (str): Example human-written text for comparison.
        generated_text (str): The input text to classify.
        N (int): Number of repetitions for the MMD calculation.

    Returns:
        str: "Prediction: Human" or "Prediction: AI", or an error message
        when either text is shorter than MINIMUM_TOKENS tokens.
    """
    tokenizer, model = _load_model(model_name)

    # Guard: texts shorter than the minimum give unreliable MMD estimates.
    if (
        count_tokens(real_text, tokenizer) < MINIMUM_TOKENS
        or count_tokens(generated_text, tokenizer) < MINIMUM_TOKENS
    ):
        return "Too short length. Need a minimum of 64 tokens to calculate Test Power."

    # Extract features for both samples.
    fea_real_ls = extract_features([real_text], tokenizer, model)
    fea_generated_ls = extract_features([generated_text], tokenizer, model)

    # Average the test power over N repetitions of the MMD baseline.
    test_power_ls = mmd_two_sample_baseline(fea_real_ls, fea_generated_ls, N=N)
    power_test_value = sum(test_power_ls) / len(test_power_ls)

    # Low test power => input looks like the same distribution as the
    # human reference; high test power => distributions differ (AI).
    if power_test_value < THRESHOLD:
        return "Prediction: Human"
    else:
        return "Prediction: AI"


# CSS for custom styling
css = """
#header { text-align: center; font-size: 1.5em; margin-bottom: 20px; }
#output-text { font-weight: bold; font-size: 1.2em; }
"""

# Gradio App
with gr.Blocks(css=css) as app:
    with gr.Row():
        gr.HTML('')
    with gr.Row():
        gr.Markdown(
            """
            [Paper](https://openreview.net/forum?id=z9j7wctoGV) | [Code](https://github.com/xLearn-AU/R-Detect) | [Contact](mailto:1730421718@qq.com)
            """
        )
    with gr.Row():
        input_text = gr.Textbox(
            label="Input Text",
            placeholder="Enter the text to check",
            lines=8,
        )
    with gr.Row():
        model_name = gr.Dropdown(
            [
                "gpt2-medium",
                "gpt2-large",
                "t5-large",
                "t5-small",
                "roberta-base",
                "roberta-base-openai-detector",
                "falcon-rw-1b",
            ],
            label="Select Model",
            value="gpt2-medium",
        )
    with gr.Row():
        submit_button = gr.Button("Run Detection", variant="primary")
        clear_button = gr.Button("Clear", variant="secondary")
    with gr.Row():
        output = gr.Textbox(
            label="Prediction",
            placeholder="Prediction: Human or AI",
            elem_id="output-text",
        )
    with gr.Accordion("Disclaimer", open=False):
        gr.Markdown(
            """
            - **Disclaimer**: This tool is for demonstration purposes only. It is not a foolproof AI detector.
            - **Accuracy**: Results may vary based on input length and quality.
            """
        )
    with gr.Accordion("Citations", open=False):
        gr.Markdown(
            """
            ```
            @inproceedings{zhangs2024MMDMP,
                title={Detecting Machine-Generated Texts by Multi-Population Aware Optimization for Maximum Mean Discrepancy},
                author={Zhang, Shuhai and Song, Yiliao and Yang, Jiahao and Li, Yuanqing and Han, Bo and Tan, Mingkui},
                booktitle = {International Conference on Learning Representations (ICLR)},
                year={2024}
            }
            ```
            """
        )

    # BUG FIX: the original passed a raw string literal inside `inputs`,
    # which is invalid — Gradio inputs must be components. Bind the fixed
    # human-reference text via a lambda instead.
    submit_button.click(
        lambda model, text: run_test_power(model, HUMAN_REFERENCE_TEXT, text),
        inputs=[model_name, input_text],
        outputs=output,
    )
    clear_button.click(lambda: ("", ""), inputs=[], outputs=[input_text, output])

if __name__ == "__main__":
    app.launch()