NativeVex commited on
Commit
e043bee
·
1 Parent(s): fb22d07

get ready for branching off for hf

Browse files
Files changed (1) hide show
  1. language_models_project/app.py +33 -0
language_models_project/app.py CHANGED
@@ -22,6 +22,7 @@ model_name = st.selectbox(
22
  "finiteautomata/bertweet-base-sentiment-analysis",
23
  "ahmedrachid/FinancialBERT-Sentiment-Analysis",
24
  "finiteautomata/beto-sentiment-analysis",
 
25
  ],
26
  )
27
 
@@ -29,6 +30,38 @@ input_sentences = st.text_area("Sentences", value=demo_phrases, height=200)
29
 
30
  data = input_sentences.split("\n")
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  if st.button("Classify"):
33
  st.write("Please allow a few minutes for the model to run/download")
34
  for i in range(len(data)):
 
22
  "finiteautomata/bertweet-base-sentiment-analysis",
23
  "ahmedrachid/FinancialBERT-Sentiment-Analysis",
24
  "finiteautomata/beto-sentiment-analysis",
25
+ "NativeVex/custom-fine-tuned"
26
  ],
27
  )
28
 
 
30
 
31
  data = input_sentences.split("\n")
32
 
33
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
34
+ model_path = "bin/model4"
35
+ model = AutoModelForSequenceClassification.from_pretrained(model_path)
36
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
37
+
38
+ from typing import List
39
+ import torch
40
+ import numpy as np
41
+ import pandas as pd
42
+
43
+ def infer(text: str) -> List[float]:
44
+ encoding = tokenizer(text, return_tensors="pt")
45
+ encoding = {k: v.to(model.device) for k,v in encoding.items()}
46
+ outputs = model(**encoding)
47
+ logits = outputs.logits
48
+ sigmoid = torch.nn.Sigmoid()
49
+ probs = sigmoid(logits.squeeze().cpu())
50
+ predictions = np.zeros(probs.shape)
51
+ predictions[np.where(probs >= 0.5)] = 1
52
+ predictions = pd.Series(predictions == 1)
53
+ predictions.index = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
54
+ return predictions
55
+
56
+
57
+ def wrapper(*args, **kwargs):
58
+ if args[0] != "NativeVex/custom-fine-tuned":
59
+ return classify(*args, **kwargs)
60
+ else:
61
+ return infer(text=args[1])
62
+
63
+
64
+
65
  if st.button("Classify"):
66
  st.write("Please allow a few minutes for the model to run/download")
67
  for i in range(len(data)):