belyakoff commited on
Commit
41abbd4
·
verified ·
1 Parent(s): f1c1728

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -5,11 +5,10 @@ from transformers import pipeline
5
 
6
  classifier = pipeline("zero-shot-classification", model=os.getenv('MODEL'))
7
 
8
- @spaces.GPU(duration=120)
9
  def classify(text, labels):
10
- labels = labels.split(',')
11
  if not text or not labels:
12
  return []
 
13
  result = classifier(text, candidate_labels=labels)
14
  return list(zip(result['labels'], map(lambda x: round(x, 4), result['scores'])))
15
 
@@ -17,12 +16,12 @@ with gr.Blocks() as demo:
17
  gr.Markdown("## Zero-Shot классификация")
18
 
19
  with gr.Row():
20
- with gr.Column(scale=2):
21
  text_input = gr.Textbox(label="Текст для классификации", placeholder="Введите текст...")
22
  labels_input = gr.Textbox(label="Классы (через запятую)", placeholder="Опишите классы через запятую...")
23
  button = gr.Button("Classify")
24
  with gr.Column(scale=1):
25
- output = gr.Dataframe(headers=[" Класс ", "Вероятность"], label="Результаты классификации")
26
 
27
  button.click(classify, inputs=[text_input, labels_input], outputs=output)
28
 
 
5
 
6
  classifier = pipeline("zero-shot-classification", model=os.getenv('MODEL'))
7
 
 
8
  def classify(text, labels):
 
9
  if not text or not labels:
10
  return []
11
+ labels = list(map(lambda x: x.strip(), labels.split(',')))
12
  result = classifier(text, candidate_labels=labels)
13
  return list(zip(result['labels'], map(lambda x: round(x, 4), result['scores'])))
14
 
 
16
  gr.Markdown("## Zero-Shot классификация")
17
 
18
  with gr.Row():
19
+ with gr.Column(scale=1.5):
20
  text_input = gr.Textbox(label="Текст для классификации", placeholder="Введите текст...")
21
  labels_input = gr.Textbox(label="Классы (через запятую)", placeholder="Опишите классы через запятую...")
22
  button = gr.Button("Classify")
23
  with gr.Column(scale=1):
24
+ output = gr.Dataframe(headers=["Класс", "Вероятность"], label="Результаты классификации")
25
 
26
  button.click(classify, inputs=[text_input, labels_input], outputs=output)
27