Oliver Li committed
Commit c193e45 · 1 Parent(s): 58e1783
Files changed (1): app.py (+8, -6)
app.py CHANGED

@@ -3,11 +3,17 @@ import pandas as pd
 from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
 
 # Function to load the pre-trained model
-def load_model(model_name):
+def load_finetune_model(model_name):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = AutoModelForSequenceClassification.from_pretrained(model_name)
     return tokenizer, model
 
+def load_model(model_name):
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForSequenceClassification.from_pretrained(model_name)
+    sentiment_pipeline = pipeline("sentiment-analysis", tokenizer=tokenizer, model=model)
+    return sentiment_pipeline
+
 # Streamlit app
 st.title("Multi-label Toxicity Detection App")
 st.write("Enter a text and select the fine-tuned model to get the toxicity analysis.")

@@ -49,13 +55,9 @@ if st.button("Analyze"):
 else:
     with st.spinner("Analyzing toxicity..."):
         if selected_model == "Olivernyu/finetuned_bert_base_uncased":
-            tokenizer, model = load_model(selected_model)
-            toxicity_detector = pipeline("text-classification", tokenizer=tokenizer, model=model)
+            toxicity_detector = load_finetune_model(selected_model)
             outputs = toxicity_detector(text, top_k=2)
-
             category_names = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
-
-
             results = []
             for item in outputs:
                 results.append((category[item['label']], item['score']))
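For reference, a minimal standalone sketch of the inference path this change is aiming at: load the fine-tuned checkpoint the same way load_finetune_model does, wrap it in a text-classification pipeline (as the pre-change code did inline), and query it with top_k=2. This is a sketch under those assumptions, not the Space's exact code; sample_text is a made-up input.

from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

model_name = "Olivernyu/finetuned_bert_base_uncased"

# Same loading steps as load_finetune_model in app.py
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# The pre-change code constructed this pipeline inline inside the Streamlit branch;
# here it is built explicitly from the loaded tokenizer and model.
toxicity_detector = pipeline("text-classification", tokenizer=tokenizer, model=model)

sample_text = "Example input text"  # hypothetical input, not from the Space
outputs = toxicity_detector(sample_text, top_k=2)  # two highest-scoring labels

# app.py maps each returned label to one of these categories before displaying it
category_names = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

for item in outputs:
    # each item is a dict such as {"label": "LABEL_0", "score": 0.97}
    print(item["label"], round(item["score"], 4))

One possible refinement would be to build the pipeline once and reuse it across button clicks (for example via Streamlit's st.cache_resource) rather than reloading the checkpoint on every analysis.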