joshnguyen commited on
Commit
fc85928
·
1 Parent(s): d6df692

First commit

Browse files
Files changed (2) hide show
  1. app.py +61 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import (
4
+ AutoModelForSequenceClassification,
5
+ AutoTokenizer,
6
+ )
7
+ from typing import Dict
8
+
9
+ FOUNDATIONS = ["authority", "care", "fairness", "loyalty", "sanctity"]
10
+ tokenizer = AutoTokenizer.from_pretrained("joshnguyen/mformer-authority")
11
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ MODELS = {}
13
+ for foundation in FOUNDATIONS:
14
+ model = AutoModelForSequenceClassification.from_pretrained(
15
+ pretrained_model_name_or_path=f"joshnguyen/mformer-{foundation}",
16
+ )
17
+ MODELS[foundation] = model.to(DEVICE)
18
+
19
+
20
+ def classify_text(text: str) -> Dict[str, float]:
21
+ # Encode the prompt
22
+ inputs = tokenizer([text],
23
+ padding=True,
24
+ truncation=True,
25
+ return_tensors='pt').to(DEVICE)
26
+ scores = {}
27
+ for foundation in FOUNDATIONS:
28
+ model = MODELS[foundation]
29
+ outputs = model(**inputs)
30
+ outputs = torch.softmax(outputs.logits, dim=1)
31
+ outputs = outputs[:, 1]
32
+ score = outputs.detach().cpu().numpy()[0]
33
+ scores[foundation] = score
34
+ return scores
35
+
36
+
37
+ demo = gr.Interface(
38
+ fn=classify_text,
39
+ inputs=[
40
+ # Prompt
41
+ gr.Textbox(
42
+ label="Input text",
43
+ container=False,
44
+ show_label=True,
45
+ placeholder="Enter some text...",
46
+ lines=10,
47
+ scale=10,
48
+ ),
49
+ ],
50
+ outputs=[
51
+ gr.Label(
52
+ label="Moral foundations scores",
53
+ container=False,
54
+ show_label=True,
55
+ scale=10,
56
+ lines=10,
57
+ )
58
+ ],
59
+ )
60
+
61
+ demo.queue(max_size=20).launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==3.37.0
2
+ protobuf==3.20.3
3
+ scipy==1.11.1
4
+ sentencepiece==0.1.99
5
+ torch==2.0.1
6
+ transformers==4.31.0