joshnguyen commited on
Commit
37dac63
·
1 Parent(s): ee4b3bb

Change label to bar plot

Browse files
Files changed (2) hide show
  1. app.py +11 -9
  2. requirements.txt +1 -0
app.py CHANGED
@@ -6,7 +6,7 @@ from transformers import (
6
  )
7
  from typing import Dict
8
  import os
9
- from custom_label import CustomLabel
10
  from huggingface_hub import login
11
  login(token=os.getenv("HUGGINGFACE_TOKEN"))
12
 
@@ -31,14 +31,15 @@ def classify_text(text: str) -> Dict[str, float]:
31
  padding=True,
32
  truncation=True,
33
  return_tensors='pt').to(DEVICE)
34
- scores = {}
35
  for foundation in FOUNDATIONS:
36
  model = MODELS[foundation]
37
  outputs = model(**inputs)
38
  outputs = torch.softmax(outputs.logits, dim=1)
39
  outputs = outputs[:, 1]
40
  score = outputs.detach().cpu().numpy()[0]
41
- scores[foundation.capitalize()] = score
 
42
  return scores
43
 
44
 
@@ -56,12 +57,13 @@ demo = gr.Interface(
56
  ),
57
  ],
58
  outputs=[
59
- CustomLabel(
60
- label="Moral foundations scores",
61
- container=False,
62
- show_label=False,
63
- scale=10,
64
- lines=10,
 
65
  )
66
  ],
67
  )
 
6
  )
7
  from typing import Dict
8
  import os
9
+ import pandas as pd
10
  from huggingface_hub import login
11
  login(token=os.getenv("HUGGINGFACE_TOKEN"))
12
 
 
31
  padding=True,
32
  truncation=True,
33
  return_tensors='pt').to(DEVICE)
34
+ scores = []
35
  for foundation in FOUNDATIONS:
36
  model = MODELS[foundation]
37
  outputs = model(**inputs)
38
  outputs = torch.softmax(outputs.logits, dim=1)
39
  outputs = outputs[:, 1]
40
  score = outputs.detach().cpu().numpy()[0]
41
+ scores.append([foundation.capitalize, score])
42
+ scores = pd.DataFrame(scores, columns=["foundation", "score"])
43
  return scores
44
 
45
 
 
57
  ),
58
  ],
59
  outputs=[
60
+ gr.BarPlot(
61
+ x="foundation",
62
+ y="score",
63
+ title="Moral foundations scores",
64
+ x_title="Moral foundation",
65
+ y_title="Score",
66
+ vertical=True,
67
  )
68
  ],
69
  )
requirements.txt CHANGED
@@ -4,3 +4,4 @@ scipy==1.11.1
4
  sentencepiece==0.1.99
5
  torch==2.0.1
6
  transformers==4.31.0
 
 
4
  sentencepiece==0.1.99
5
  torch==2.0.1
6
  transformers==4.31.0
7
+ pandas==1.5.3