import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import nltk
from nltk.tokenize import sent_tokenize

# Set page config at the very beginning
st.set_page_config(page_title="LLM Detector", layout="centered")


# Download the punkt tokenizer for sentence splitting (with caching)
@st.cache_resource
def download_nltk_punkt():
    # Newer NLTK releases resolve sent_tokenize via the "punkt_tab" resource;
    # fetching both keeps this working across NLTK versions.
    nltk.download("punkt", quiet=True)
    nltk.download("punkt_tab", quiet=True)


download_nltk_punkt()


# Load the model and tokenizer (with caching)
@st.cache_resource
def load_model_and_tokenizer():
    model_name = "CoolSpring/creative-writing-llm-detector-deberta-v3-xsmall"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return tokenizer, model


tokenizer, model = load_model_and_tokenizer()


def classify_text(text):
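    # Tokenize and truncate to the model's 512-token limit, so very long inputs
    # are scored from their first ~512 tokens only.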
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = torch.softmax(logits, dim=1)
    return probabilities[0][1].item()  # Class index 1 is assumed to be the AI-generated label for this model


def highlight_suspicious_sentences(text):
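    # Split the input into sentences with NLTK and score each one independently;
    # runtime grows roughly linearly with the number of sentences.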
    sentences = sent_tokenize(text)
    scores = [classify_text(sentence) for sentence in sentences]
    return sentences, scores


def get_color(score):
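    # Map a probability to a translucent highlight color: green below 0.33,
    # yellow below 0.66, red otherwise.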
    if score < 0.33:
        return "rgba(144, 238, 144, 0.3)"  # Light green
    elif score < 0.66:
        return "rgba(255, 255, 0, 0.3)"  # Light yellow
    else:
        return "rgba(255, 99, 71, 0.3)"  # Light red


st.title("🤖 LLM Detector")
st.write("Enter text to detect if it's written by an AI language model.")

# Use session state to store the input text
if "text_input" not in st.session_state:
    st.session_state.text_input = ""

text_input = st.text_area(
    "Enter your text here:", value=st.session_state.text_input, height=200
)

# Update session state when input changes
if text_input != st.session_state.text_input:
    st.session_state.text_input = text_input

if st.button("Analyze and Highlight"):
    if text_input:
        overall_probability = classify_text(text_input)
        st.markdown(
            f"<h3>Overall probability of being AI-generated: <span style='color: {'red' if overall_probability > 0.5 else 'green'};'>{overall_probability:.2%}</span></h3>",
            unsafe_allow_html=True,
        )

        st.markdown("### Sentence-level analysis:")
        sentences, scores = highlight_suspicious_sentences(text_input)

        for sentence, score in zip(sentences, scores):
            color = get_color(score)
            st.markdown(
                f"<div style='background-color: {color}; padding: 10px; margin: 5px 0; border-radius: 5px;'><strong>{score:.2%}</strong> - {sentence}</div>",
                unsafe_allow_html=True,
            )
    else:
        st.warning("Please enter some text to analyze.")

how_it_works_text = """This LLM Detector uses [CoolSpring/creative-writing-llm-detector-deberta-v3-xsmall](https://huggingface.co/CoolSpring/creative-writing-llm-detector-deberta-v3-xsmall), a DeBERTa-v3-xsmall model fine-tuned for text classification.

It analyzes the input text and estimates the probability that it was generated by an AI language model.

The sentence-level analysis breaks down the input into individual sentences and analyzes each one separately, allowing you to see which parts of the text are more likely to be AI-generated.

Please note that this is not 100% accurate and should be used as a guide rather than a definitive measure."""

if st.button("Fill with Sample Text"):
    st.session_state.text_input = "\n".join(how_it_works_text.splitlines()[2:])
    st.rerun()

st.markdown(
    f"""### How it works
{how_it_works_text}"""
)