Spaces:
Running
Running
File size: 4,676 Bytes
db0f499 ceabca1 db0f499 8ed0852 ceabca1 c31a852 15a7058 0f21896 8ed0852 db0f499 6a0b78f ceabca1 4dfc57d 6a0b78f ceabca1 6a0b78f ceabca1 6a0b78f ceabca1 6a0b78f 09e7c03 db0f499 8ed0852 ceabca1 8ed0852 ceabca1 8ed0852 ceabca1 db0f499 6a0b78f ceabca1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import torch
# Human-readable display name -> Hugging Face Hub model id.
models = {
    "gte base (gender v3.1)": "breadlicker45/gte-gender-v3.1-test",
    "ModernBERT Large (gender v3)": "breadlicker45/modernbert-gender-v3-test",
    "ModernBERT Large (gender v2)": "breadlicker45/modernbert-gender-v2",
    "ModernBERT Base (gender)": "breadlicker45/ModernBERT-base-gender",
    "ModernBERT Large (gender)": "breadlicker45/ModernBERT-large-gender",
}

# Raw classifier output labels -> user-friendly display labels.
label_map = {
    "LABEL_0": "Male (0)",
    "0": "Male (0)",
    "LABEL_1": "Female (1)",
    "1": "Female (1)",
}

# Cache of loaded pipelines / (model, tokenizer) pairs, keyed by model id,
# so repeated requests skip the expensive load step.
model_cache = {}

# Prefer the GPU when one is present; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# The main classification function, now handles both model types
def classify_text(model_name, text):
    """Classify *text* for perceived gender with the selected model.

    Args:
        model_name: Display name of the model — must be a key of the
            module-level ``models`` dict.
        text: The sentence to classify.

    Returns:
        Dict mapping friendly label names (via ``label_map``) to confidence
        scores, suitable for ``gr.Label``. On any failure, returns a
        single-entry dict ``{"Error": <message>}`` instead of raising so the
        UI can display the problem.
    """
    try:
        processed_results = {}
        model_id = models[model_name]

        # --- SPECIAL HANDLING FOR THE GTE MODEL ---
        # The GTE checkpoint is not pipeline-compatible, so it is loaded
        # manually with trust_remote_code and run through raw forward passes.
        if "gte-gender" in model_id:
            if model_id not in model_cache:
                print(f"Loading GTE model and tokenizer manually: {model_id}...")
                tokenizer = AutoTokenizer.from_pretrained(model_id)
                model = AutoModelForSequenceClassification.from_pretrained(
                    model_id, trust_remote_code=True
                ).to(device)
                # Be explicit about inference mode (disables dropout etc.).
                model.eval()
                model_cache[model_id] = (model, tokenizer)  # Cache both
            model, tokenizer = model_cache[model_id]

            # Tokenize the input text and move it to the model's device.
            inputs = tokenizer(
                text, return_tensors="pt", truncation=True, padding=True
            ).to(device)
            with torch.no_grad():
                logits = model(**inputs).logits
            # Convert logits to probabilities using softmax.
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

            # Generalized over the number of classes instead of hard-coding
            # indices 0 and 1; identical output for the current 2-class models.
            for idx, prob in enumerate(probabilities):
                raw_label = f"LABEL_{idx}"
                friendly_label = label_map.get(raw_label, raw_label)
                processed_results[friendly_label] = prob.item()

        # --- STANDARD HANDLING FOR PIPELINE-COMPATIBLE MODELS ---
        else:
            if model_id not in model_cache:
                print(f"Loading pipeline for model: {model_id}...")
                # Load and cache the pipeline; top_k=None returns every label.
                model_cache[model_id] = pipeline(
                    "text-classification",
                    model=model_id,
                    top_k=None,
                    device=device,  # Use the determined device
                )
            classifier = model_cache[model_id]

            predictions = classifier(text)
            # top_k=None yields [[{"label": ..., "score": ...}, ...]];
            # translate raw labels to friendly ones where known.
            if predictions and isinstance(predictions, list) and predictions[0]:
                for pred in predictions[0]:
                    friendly_label = label_map.get(pred["label"], pred["label"])
                    processed_results[friendly_label] = pred["score"]

        return processed_results
    except Exception as e:
        print(f"Error: {e}")
        # Return an error message suitable for gr.Label or gr.JSON
        return {"Error": f"Failed to process: {e}"}
# Build the Gradio UI: a model picker, a free-text input, and a label output.
model_picker = gr.Dropdown(
    list(models.keys()),
    label="Select Model",
    value="gte base (gender v3.1)",  # Default model
)
text_input = gr.Textbox(
    lines=2,
    placeholder="Enter text to classify for perceived gender...",
    value="This is an example sentence.",
)

# classify_text consistently returns {label: score}, so the nicer-looking
# gr.Label component can render the results directly.
interface = gr.Interface(
    fn=classify_text,
    inputs=[model_picker, text_input],
    outputs=gr.Label(num_top_classes=2, label="Classification Results"),
    title="ModernBERT & GTE Gender Classifier",
    description="Select a model and enter a sentence to see the perceived gender classification (Male=0, Female=1) and confidence scores. Note: Text-based gender classification can be unreliable and reflect societal biases.",
    allow_flagging="never",
)

# Start the web server only when executed as a script (not on import).
if __name__ == "__main__":
    interface.launch()