belyakoff commited on
Commit
243a0e8
·
verified ·
1 Parent(s): f7503c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -24
app.py CHANGED
@@ -4,41 +4,25 @@ from transformers import pipeline
4
  classifier = pipeline("zero-shot-classification", model="MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7")
5
 
6
  def classify(text, *labels):
7
- labels = [label for label in labels if label]
8
  if not text or not labels:
9
  return []
10
  result = classifier(text, candidate_labels=labels)
11
  return list(zip(result['labels'], result['scores']))
12
 
13
- def dynamic_ui(labels):
14
- inputs = [gr.Textbox(label=f"Class {i+1}", value=label) for i, label in enumerate(labels)]
15
- return inputs
16
-
17
  with gr.Blocks() as demo:
18
  gr.Markdown("## Zero-Shot Text Classification")
19
-
20
- text_input = gr.Textbox(label="Text for classification")
21
-
22
- labels = [""]
23
- classes_container = gr.Column(dynamic_ui(labels))
24
-
25
- def add_class():
26
- labels.append("")
27
- return dynamic_ui(labels)
28
 
29
- def remove_class():
30
- if labels:
31
- labels.pop()
32
- return dynamic_ui(labels)
33
-
34
- add_button = gr.Button("Add Class")
35
- remove_button = gr.Button("Remove Class")
36
- add_button.click(add_class, [], classes_container)
37
- remove_button.click(remove_class, [], classes_container)
38
 
39
  output = gr.Dataframe(headers=["Class", "Probability"], label="Classification Results")
 
 
 
 
40
  button = gr.Button("Classify")
41
 
42
- button.click(classify, inputs=[text_input] + classes_container.children, outputs=output)
43
 
44
  demo.launch()
 
4
  classifier = pipeline("zero-shot-classification", model="MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7")
5
 
6
  def classify(text, *labels):
7
+ labels = list(labels)
8
  if not text or not labels:
9
  return []
10
  result = classifier(text, candidate_labels=labels)
11
  return list(zip(result['labels'], result['scores']))
12
 
 
 
 
 
13
  with gr.Blocks() as demo:
14
  gr.Markdown("## Zero-Shot Text Classification")
 
 
 
 
 
 
 
 
 
15
 
16
+ text_input = gr.Textbox(label="Text for classification")
17
+ labels_input = gr.Textbox(label="Classes (comma separated)")
 
 
 
 
 
 
 
18
 
19
  output = gr.Dataframe(headers=["Class", "Probability"], label="Classification Results")
20
+
21
+ def update_classes(classes):
22
+ return classes.split(',')
23
+
24
  button = gr.Button("Classify")
25
 
26
+ button.click(classify, inputs=[text_input, labels_input], outputs=output)
27
 
28
  demo.launch()