File size: 1,866 Bytes
fc85928
 
 
 
 
 
 
126c386
508ada6
126c386
 
fc85928
 
126c386
508ada6
126c386
508ada6
fc85928
 
 
 
 
126c386
fc85928
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508ada6
fc85928
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508ada6
fc85928
 
508ada6
fc85928
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import gradio as gr
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from typing import Dict
import os
from custom_label import CustomLabel
from huggingface_hub import login
# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login(token=None) falls back to an interactive prompt, which cannot
# be answered in a headless deployment.
_HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if _HF_TOKEN:
    login(token=_HF_TOKEN)

# Moral foundations scored by the app; one classifier checkpoint per entry.
FOUNDATIONS = ["authority", "care", "fairness", "loyalty", "sanctity"]

# A single tokenizer, loaded from the authority checkpoint, is shared by every
# model below. This assumes all mformer-* checkpoints use the same tokenizer
# (TODO confirm).
# NOTE(review): newer transformers releases deprecate `use_auth_token` in
# favour of `token`; switch once the pinned transformers version is known.
tokenizer = AutoTokenizer.from_pretrained(
    "joshnguyen/mformer-authority",
    use_auth_token=True
)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load each per-foundation classifier once at startup and move it to DEVICE.
MODELS = {}
for foundation in FOUNDATIONS:
    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=f"joshnguyen/mformer-{foundation}",
        use_auth_token=True
    )
    MODELS[foundation] = model.to(DEVICE)


def classify_text(text: str) -> Dict[str, float]:
    """Score *text* against every moral foundation classifier.

    The text is tokenized once, and the same encoded batch is fed to each
    model in ``MODELS``. For each model, the logits are softmaxed along dim 1
    and the probability in column 1 becomes that foundation's score.

    Args:
        text: The input string to classify.

    Returns:
        A dict mapping each foundation name, capitalized (e.g. ``"Care"``),
        to its score as a plain Python ``float``.
    """
    # Encode the prompt once; the same tensors are reused for every model.
    inputs = tokenizer(
        [text],
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)
    scores = {}
    # Inference only: no_grad avoids building an autograd graph on every
    # request, which the original code built and then discarded.
    with torch.no_grad():
        for foundation in FOUNDATIONS:
            logits = MODELS[foundation](**inputs).logits
            probs = torch.softmax(logits, dim=1)
            # Row 0 is the single input text. Column 1 is presumably the
            # "foundation present" label; confirm against the model card.
            # .item() returns a Python float rather than numpy.float32,
            # matching the declared return type.
            scores[foundation.capitalize()] = probs[0, 1].item()
    return scores


# Free-text box the user types the passage to score into.
_text_input = gr.Textbox(
    label="Input text",
    placeholder="Enter some text...",
    lines=10,
    scale=10,
    container=False,
    show_label=True,
)

# Display component for the per-foundation scores (project-local CustomLabel).
_scores_output = CustomLabel(
    label="Moral foundations scores",
    scale=10,
    lines=10,
    container=False,
    show_label=False,
)

# Wire classify_text between the two components.
demo = gr.Interface(
    fn=classify_text,
    inputs=[_text_input],
    outputs=[_scores_output],
)

# Enable the request queue (max_size=20) and start the server; note this runs
# at import time, not only when executed as a script.
demo.queue(max_size=20).launch()