# LLM_Detector / app.py
# CoolSpring's picture
# Introduce MarkupSafe
# 7fa06e2 unverified
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import nltk
from nltk.tokenize import sent_tokenize
from markupsafe import escape
# Configure the page before any other Streamlit command runs.
st.set_page_config(
    page_title="LLM Detector",
    layout="centered",
)
# Fetch the NLTK sentence-splitting data (cached by Streamlit so it runs once).
@st.cache_resource
def download_nltk_punkt():
    """Download the Punkt resources that ``sent_tokenize`` relies on.

    Two resource names are requested: ``punkt`` (the pickled models used by
    older NLTK releases) and ``punkt_tab`` (the tabular data that newer NLTK
    releases, 3.9 and later, load instead). Requesting both keeps sentence
    splitting working regardless of which NLTK version is installed.

    Returns:
        None. ``nltk.download`` reports failure through its return value,
        which is deliberately ignored here so a missing optional resource
        does not stop the app from starting.
    """
    for resource in ("punkt", "punkt_tab"):
        nltk.download(resource, quiet=True)


download_nltk_punkt()
# Load the detector checkpoint once; Streamlit caches the returned objects.
@st.cache_resource
def load_model_and_tokenizer():
    """Return ``(tokenizer, model)`` for the detector checkpoint."""
    checkpoint = "CoolSpring/creative-writing-llm-detector-deberta-v3-xsmall"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSequenceClassification.from_pretrained(checkpoint),
    )


tokenizer, model = load_model_and_tokenizer()
def classify_text(text):
    """Return the probability that *text* as a whole is AI-generated.

    The text is tokenized as one sequence (truncated to 512 tokens), run
    through the model without gradient tracking, and the softmax value of
    class index 1 is returned as a plain float.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        output = model(**encoded)
    class_probs = output.logits.softmax(dim=1)
    # Row 0 is the single input sequence; column 1 is the AI-generated class.
    return class_probs[0, 1].item()
def highlight_suspicious_sentences(text):
    """Split *text* into sentences and score each one.

    Args:
        text: The raw input text.

    Returns:
        A ``(sentences, scores)`` tuple of two equally long lists: the
        sentences produced by ``sent_tokenize`` and, for each, the softmax
        probability of class index 1 (AI-generated). Empty or
        whitespace-only input yields ``([], [])``.
    """
    # Nothing to split: return early rather than building an empty batch.
    if not text or not text.strip():
        return [], []

    sentences = sent_tokenize(text)
    # Guard again in case the splitter produced no sentences; an empty batch
    # must not reach the tokenizer/model.
    if not sentences:
        return [], []

    # Score all sentences in one padded batch; each is truncated to 512 tokens.
    inputs = tokenizer(
        sentences, return_tensors="pt", truncation=True, max_length=512, padding=True
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    probabilities = torch.softmax(logits, dim=1)
    # Column 1 holds the probability of being AI-generated for each sentence.
    scores = probabilities[:, 1].tolist()
    return sentences, scores
def get_color(score):
    """Map a probability *score* to a translucent background colour.

    Scores below 0.33 are green, scores below 0.66 are yellow, and
    everything else is red.
    """
    bands = (
        (0.33, "rgba(144, 238, 144, 0.3)"),  # Light green
        (0.66, "rgba(255, 255, 0, 0.3)"),  # Light yellow
    )
    for upper_bound, rgba in bands:
        if score < upper_bound:
            return rgba
    return "rgba(255, 99, 71, 0.3)"  # Light red
# Page heading and a one-line description of what the app does.
# The title's leading emoji was mis-decoded (UTF-8 bytes read in a legacy
# code page); it is restored to the intended robot-face character.
st.title("🤖 LLM Detector")
st.write("Enter text to detect if it's written by an AI language model.")
# The input text is kept in session state; seed it with an empty string.
if "text_input" not in st.session_state:
    st.session_state["text_input"] = ""

text_input = st.text_area(
    "Enter your text here:",
    value=st.session_state["text_input"],
    height=200,
)

# Store the widget's current value whenever it differs from the saved copy.
if st.session_state["text_input"] != text_input:
    st.session_state["text_input"] = text_input
# Run the detector on demand and render the overall and per-sentence results.
if st.button("Analyze and Highlight"):
    # Whitespace-only input has no sentences to score, so it is treated the
    # same as empty input instead of being sent to the model.
    if text_input and text_input.strip():
        # Do all model work under the spinner, then render the results.
        with st.spinner("Analyzing text..."):
            overall_probability = classify_text(text_input)
            sentences, scores = highlight_suspicious_sentences(text_input)

        st.html(
            f"<h3>Overall probability of being AI-generated: <span style='color: {'red' if overall_probability > 0.5 else 'green'};'>{overall_probability:.2%}</span></h3>",
        )
        st.markdown("### Sentence-level analysis:")
        for sentence, score in zip(sentences, scores):
            color = get_color(score)
            # The sentence is user-supplied text, so it is HTML-escaped
            # before being interpolated into the markup.
            st.html(
                f"<div style='background-color: {color}; padding: 10px; margin: 5px 0; border-radius: 5px;'><strong>{score:.2%}</strong> - {escape(sentence)}</div>",
            )
    else:
        st.warning("Please enter some text to analyze.")
# Markdown body of the "How it works" section. The "Fill with Sample Text"
# button reuses every line after the first two as sample input, so the line
# layout of this literal matters.
how_it_works_text = """This LLM Detector uses [CoolSpring/creative-writing-llm-detector-deberta-v3-xsmall](https://huggingface.co/CoolSpring/creative-writing-llm-detector-deberta-v3-xsmall), a DeBERTa-v3-xsmall model fine-tuned for text classification.
It analyzes the input text and estimates the probability of it being generated by an AI language model.
The sentence-level analysis breaks down the input into individual sentences and analyzes each one separately, allowing you to see which parts of the text are more likely to be AI-generated.
Please note that this is not 100% accurate and should be used as a guide rather than a definitive measure."""
# One-click sample: everything after the first two lines of the explanation
# is placed in the stored input, then the script is rerun.
if st.button("Fill with Sample Text"):
    sample_lines = how_it_works_text.splitlines()[2:]
    st.session_state.text_input = "\n".join(sample_lines)
    st.rerun()
# Render the explanation under its own heading.
st.markdown("### How it works\n" + how_it_works_text)