sileod committed on
Commit
66a71cd
·
verified ·
1 Parent(s): 9e973e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -45
app.py CHANGED
@@ -2,8 +2,8 @@ import gradio as gr
2
  from transformers import pipeline
3
  import re
4
 
 
5
  def sent_tokenize(text):
6
- # Regular expression to split sentences
7
  sentence_endings = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)(\s|$)')
8
  sentences = sentence_endings.split(text)
9
  return [s.strip() for s in sentences if s.strip()]
@@ -34,6 +34,70 @@ long_context_examples = [
34
  "The cafe is experiencing a slow, quiet morning"]
35
  ]
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def process_input(text_input, labels_or_premise, mode):
38
  if mode == "Zero-Shot Classification":
39
  labels = [label.strip() for label in labels_or_premise.split(',')]
@@ -48,6 +112,7 @@ def process_input(text_input, labels_or_premise, mode):
48
  # Global prediction
49
  global_pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0]
50
  global_results = {pred['label']: pred['score'] for pred in global_pred}
 
51
 
52
  # Sentence-level analysis
53
  sentences = sent_tokenize(text_input)
@@ -63,50 +128,10 @@ def process_input(text_input, labels_or_premise, mode):
63
  'scores': sent_scores
64
  })
65
 
66
- # Create markdown analysis
67
- analysis_md = "## Global Prediction\n"
68
- max_global_label = max(global_results.items(), key=lambda x: x[1])[0]
69
- analysis_md += f"Overall prediction: **{max_global_label}**\n\n"
70
- analysis_md += "## Sentence-Level Analysis\n"
71
-
72
- for i, result in enumerate(sentence_results, 1):
73
- analysis_md += f"\n### Sentence {i}\n"
74
- analysis_md += f"*{result['sentence']}*\n"
75
- analysis_md += f"Prediction: **{result['prediction']}**\n"
76
- scores_str = ", ".join([f"{label}: {score:.2f}" for label, score in result['scores'].items()])
77
- analysis_md += f"Scores: {scores_str}\n"
78
-
79
- return global_results, analysis_md
80
-
81
- def update_interface(mode):
82
- if mode == "Zero-Shot Classification":
83
- return (
84
- gr.update(
85
- label="🏷️ Categories",
86
- placeholder="Enter comma-separated categories...",
87
- value=zero_shot_examples[0][1]
88
- ),
89
- gr.update(value=zero_shot_examples[0][0])
90
- )
91
- elif mode == "Natural Language Inference":
92
- return (
93
- gr.update(
94
- label="πŸ”Ž Hypothesis",
95
- placeholder="Enter a hypothesis to compare with the premise...",
96
- value=nli_examples[0][1]
97
- ),
98
- gr.update(value=nli_examples[0][0])
99
- )
100
- else: # Long Context NLI
101
- return (
102
- gr.update(
103
- label="πŸ”Ž Global Hypothesis",
104
- placeholder="Enter a hypothesis to test against the full context...",
105
- value=long_context_examples[0][1]
106
- ),
107
- gr.update(value=long_context_examples[0][0])
108
- )
109
 
 
110
  with gr.Blocks() as demo:
111
  gr.Markdown("""
112
  # tasksource/ModernBERT-nli demonstration
@@ -142,7 +167,7 @@ with gr.Blocks() as demo:
142
 
143
  outputs = [
144
  gr.Label(label="πŸ“Š Results"),
145
- gr.Markdown(label="πŸ“ˆ Sentence Analysis", visible=True)
146
  ]
147
 
148
  with gr.Column(variant="panel") as zero_shot_examples_panel:
 
2
  from transformers import pipeline
3
  import re
4
 
5
+ # Custom sentence tokenizer
6
  def sent_tokenize(text):
 
7
  sentence_endings = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)(\s|$)')
8
  sentences = sentence_endings.split(text)
9
  return [s.strip() for s in sentences if s.strip()]
 
34
  "The cafe is experiencing a slow, quiet morning"]
35
  ]
36
 
37
def get_label_color(label):
    """Return the background color (hex string) for an NLI label.

    String labels are matched case-insensitively and with surrounding
    whitespace ignored, so ``'entailment'`` and ``'ENTAILMENT'`` map to the
    same color. Any label that is not one of the three known NLI classes
    yields white (``'#FFFFFF'``).

    Args:
        label: The predicted label, normally a string.

    Returns:
        A ``#RRGGBB`` color string.
    """
    colors = {
        'ENTAILMENT': '#90EE90',     # Light green
        'NEUTRAL': '#FFE5B4',        # Peach
        'CONTRADICTION': '#FFB6C1',  # Light pink
    }
    # Label casing depends on the model configuration, so normalise before
    # the lookup instead of silently falling back to white.
    key = label.strip().upper() if isinstance(label, str) else label
    return colors.get(key, '#FFFFFF')
45
+
46
def create_analysis_html(sentence_results, global_label):
    """Create the HTML report for sentence-level analysis with color coding.

    The report consists of an inline stylesheet, a banner showing the global
    prediction, and a two-column table (sentence, prediction). The banner
    and every table row are tinted with the color that ``get_label_color``
    returns for the corresponding label.

    Args:
        sentence_results: Iterable of mappings, each providing the keys
            ``'sentence'`` and ``'prediction'``.
        global_label: Label displayed in the "Global Prediction" banner.

    Returns:
        The complete report as one HTML string.
    """
    # Function-scope import keeps this block self-contained; only ``escape``
    # is bound, so no local name shadows the ``html`` module.
    from html import escape

    def as_text(value):
        # Sentences originate from free-form user input and are rendered
        # as raw HTML, so they must be escaped to avoid markup injection.
        return escape(str(value))

    parts = [
        """
    <style>
        .analysis-table {
            width: 100%;
            border-collapse: collapse;
            margin: 20px 0;
            font-family: Arial, sans-serif;
        }
        .analysis-table th, .analysis-table td {
            padding: 12px;
            border: 1px solid #ddd;
            text-align: left;
        }
        .analysis-table th {
            background-color: #f5f5f5;
        }
        .global-prediction {
            padding: 15px;
            margin: 20px 0;
            border-radius: 5px;
            font-weight: bold;
        }
    </style>
    """
    ]

    # Global prediction box.
    parts.append(f"""
    <div class="global-prediction" style="background-color: {get_label_color(global_label)}">
        Global Prediction: {as_text(global_label)}
    </div>
    """)

    # Table header.
    parts.append("""
    <table class="analysis-table">
        <tr>
            <th>Sentence</th>
            <th>Prediction</th>
        </tr>
    """)

    # One row per analysed sentence, tinted by its predicted label.
    for result in sentence_results:
        prediction = result['prediction']
        parts.append(f"""
        <tr style="background-color: {get_label_color(prediction)}">
            <td>{as_text(result['sentence'])}</td>
            <td>{as_text(prediction)}</td>
        </tr>
        """)

    parts.append("</table>")
    return "".join(parts)
100
+
101
  def process_input(text_input, labels_or_premise, mode):
102
  if mode == "Zero-Shot Classification":
103
  labels = [label.strip() for label in labels_or_premise.split(',')]
 
112
  # Global prediction
113
  global_pred = nli_classifier([{"text": text_input, "text_pair": labels_or_premise}], return_all_scores=True)[0]
114
  global_results = {pred['label']: pred['score'] for pred in global_pred}
115
+ global_label = max(global_results.items(), key=lambda x: x[1])[0]
116
 
117
  # Sentence-level analysis
118
  sentences = sent_tokenize(text_input)
 
128
  'scores': sent_scores
129
  })
130
 
131
+ analysis_html = create_analysis_html(sentence_results, global_label)
132
+ return global_results, analysis_html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
+ # [Previous interface code remains the same until the outputs definition]
135
  with gr.Blocks() as demo:
136
  gr.Markdown("""
137
  # tasksource/ModernBERT-nli demonstration
 
167
 
168
  outputs = [
169
  gr.Label(label="πŸ“Š Results"),
170
+ gr.HTML(label="πŸ“ˆ Sentence Analysis") # Changed from Markdown to HTML
171
  ]
172
 
173
  with gr.Column(variant="panel") as zero_shot_examples_panel: