import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import torch
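# Assumed runtime dependencies (not pinned in this file): gradio, torch, and a
# recent transformers release (the ModernBERT checkpoints need a version with
# ModernBERT support, roughly 4.48+).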

# Define model names
models = {
    "gte base (gender v3.1)": "breadlicker45/gte-gender-v3.1-test",
    "ModernBERT Large (gender v3)": "breadlicker45/modernbert-gender-v3-test",
    "ModernBERT Large (gender v2)": "breadlicker45/modernbert-gender-v2",
    "ModernBERT Base (gender)": "breadlicker45/ModernBERT-base-gender",
    "ModernBERT Large (gender)": "breadlicker45/ModernBERT-large-gender"
}

# Define the mapping for user-friendly labels
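# Both the default Hugging Face names ("LABEL_0"/"LABEL_1") and bare "0"/"1" are
# covered, since the listed checkpoints may differ in how id2label is configured.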
label_map = {
    "LABEL_0": "Male (0)",
    "0": "Male (0)",
    "LABEL_1": "Female (1)",
    "1": "Female (1)"
}

# A cache to store loaded models/pipelines to speed up subsequent requests
model_cache = {}

# Determine the device to run on (GPU if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


# The main classification function, now handles both model types
def classify_text(model_name, text):
    try:
        processed_results = {}
        model_id = models[model_name]

        # --- SPECIAL HANDLING FOR THE GTE MODEL ---
        if "gte-gender" in model_id:
            # Check if model/tokenizer is already in our cache
            if model_id not in model_cache:
                print(f"Loading GTE model and tokenizer manually: {model_id}...")
                tokenizer = AutoTokenizer.from_pretrained(model_id)
                model = AutoModelForSequenceClassification.from_pretrained(model_id, trust_remote_code=True).to(device)
                model_cache[model_id] = (model, tokenizer) # Cache both
            
            model, tokenizer = model_cache[model_id]

            # Tokenize the input text and move to the correct device
            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)

            # Get model predictions
            with torch.no_grad():
                logits = model(**inputs).logits
            
            # Convert logits to probabilities using softmax
            probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
            
            # Format results to match the pipeline's output style
            processed_results[label_map["LABEL_0"]] = probabilities[0].item()
            processed_results[label_map["LABEL_1"]] = probabilities[1].item()

        # --- STANDARD HANDLING FOR PIPELINE-COMPATIBLE MODELS ---
        else:
            # Check if the pipeline is already in our cache
            if model_id not in model_cache:
                print(f"Loading pipeline for model: {model_id}...")
                # Load and cache the pipeline
                model_cache[model_id] = pipeline(
                    "text-classification", 
                    model=model_id, 
                    top_k=None,
                    device=device # Use the determined device
                )
            
            classifier = model_cache[model_id]
            predictions = classifier(text)

            # Process predictions to use friendly labels
            if predictions and isinstance(predictions, list) and predictions[0]:
                for pred in predictions[0]:
                    raw_label = pred["label"]
                    score = pred["score"]
                    friendly_label = label_map.get(raw_label, raw_label)
                    processed_results[friendly_label] = score
        
        return processed_results

    except Exception as e:
        print(f"Error: {e}")
        # Return an error message suitable for gr.Label or gr.JSON
        return {"Error": f"Failed to process: {e}"}


# Create the Gradio interface
interface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Dropdown(
            list(models.keys()),
            label="Select Model",
            value="gte base (gender v3.1)" # Default model
        ),
        gr.Textbox(
            lines=2,
            placeholder="Enter text to classify for perceived gender...",
            value="This is an example sentence."
        )
    ],
    # Since we now consistently return a dictionary of {label: score},
    # we can go back to using the nicer-looking gr.Label component!
    outputs=gr.Label(num_top_classes=2, label="Classification Results"),
    title="ModernBERT & GTE Gender Classifier",
    description="Select a model and enter a sentence to see the perceived gender classification (Male=0, Female=1) and confidence scores. Note: Text-based gender classification can be unreliable and reflect societal biases.",
    allow_flagging="never",
)

# Launch the app
if __name__ == "__main__":
    interface.launch()
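
# A minimal sketch of calling the classifier directly, e.g. for a quick smoke
# test without the UI; the scores shown are illustrative, not real model output:
#
#   print(classify_text("ModernBERT Base (gender)", "My sister loves hiking."))
#   # -> {"Female (1)": 0.97, "Male (0)": 0.03}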