sitammeur commited on
Commit
b27f0b5
1 Parent(s): fa6d28a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -9
app.py CHANGED
@@ -2,6 +2,7 @@
2
  import gradio as gr
3
  from transformers import pipeline
4
 
 
5
  # Load the zero-shot classification model
6
  classifier = pipeline(
7
  "zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0"
@@ -24,18 +25,13 @@ def ZeroShotTextClassification(text_input, candidate_labels):
24
  # Split the candidate labels
25
  labels = [label.strip(" ") for label in candidate_labels.split(",")]
26
 
27
- # Output dictionary to store the predicted labels and their scores
28
- output = {}
29
-
30
  # Perform zero-shot classification
31
  prediction = classifier(text_input, labels)
32
 
33
- # Create a dictionary with the predicted labels and their corresponding scores
34
- for i in range(len(prediction["labels"])):
35
- output[prediction["labels"][i]] = prediction["scores"][i]
36
-
37
- # Return the output
38
- return output
39
 
40
 
41
  # Examples to display in the interface
 
2
  import gradio as gr
3
  from transformers import pipeline
4
 
5
+
6
  # Load the zero-shot classification model
7
  classifier = pipeline(
8
  "zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0"
 
25
  # Split the candidate labels
26
  labels = [label.strip(" ") for label in candidate_labels.split(",")]
27
 
 
 
 
28
  # Perform zero-shot classification
29
  prediction = classifier(text_input, labels)
30
 
31
+ return {
32
+ prediction["labels"][i]: prediction["scores"][i]
33
+ for i in range(len(prediction["labels"]))
34
+ }
 
 
35
 
36
 
37
  # Examples to display in the interface