saeedabc committed

Commit 13ecc63 · Parent: dd4b76a

Fixed downloading nltk.punkt

Files changed (2):
  1. app.py +11 -22
  2. util.py +2 -5
app.py CHANGED
@@ -13,21 +13,20 @@ import ruptures as rpt
 from util import sent_tokenize
 
 
-# _OPENAI_MODELS = ['text-embedding-ada-002', 'text-embedding-3-small', 'text-embedding-3-large']
 _ST_MODELS = ['all-mpnet-base-v2', 'multi-qa-mpnet-base-dot-v1', 'all-MiniLM-L12-v2']
 
 CACHE_DIR = '.cache'
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 plt.rcParams.update({
-    'font.family': 'Times New Roman', #'Arial', # or 'Helvetica', 'Times New Roman'
-    'font.size': 12,              # General font size
-    'axes.titlesize': 13,         # Font size for titles
-    'axes.labelsize': 12,         # Font size for axis labels
-    'xtick.labelsize': 11,        # Font size for x-tick labels
-    'ytick.labelsize': 11,        # Font size for y-tick labels
-    'legend.fontsize': 11,        # Font size for legend
-    'legend.title_fontsize': 11   # Font size for legend title
+    'font.family': 'Times New Roman', #'Arial', 'Helvetica'
+    'font.size': 12,
+    'axes.titlesize': 13,
+    'axes.labelsize': 12,
+    'xtick.labelsize': 11,
+    'ytick.labelsize': 11,
+    'legend.fontsize': 11,
+    'legend.title_fontsize': 11
 })
 
 
@@ -129,11 +128,6 @@ def output_segments(sents, preds, probs):
     preds = preds + [1]
     bkps = get_bkps_from_labels(preds)
 
-    # print(f'signal(#{len(signal)}): {signal}')
-    # print(f'bkps(#{len(bkps)}): {bkps}')
-    # if not bkps or bkps[-1] != len(signal):
-    #     print('Note: last segment is incomplete!')
-
     fig, [ax] = rpt.display(np.array(signal), bkps, figsize=(10, 5), dpi=250)
     y_min = max(0.0, min(signal) - 0.1)
     y_max = min(1.0, max(signal) + 0.1)
@@ -170,16 +164,11 @@ def text_segmentation(input_text, model_name, k, pool, threshold):
     return output_segments(sents, preds, probs)
 
 
-# with gr.Blocks(css=".custom-tab { padding: 20px; margin: 20px; }") as app:
 with gr.Blocks() as app:
     gr.Markdown("""
 # LLM TextTiling Demo
 
-An **extended** approach to text segmentation that combines **TextTiling** with **LLM embeddings**.
-Simply provide your text, choose an embedding model, and adjust segmentation parameters (window size, threshold, pooling).
-The demo will split your text into coherent segments based on **semantic shifts**.
-
-[**View the code on GitHub**](https://github.com/saeedabc/llm-text-tiling/demo)
+An **extended** approach to text segmentation that combines **TextTiling** with **LLM embeddings**. Simply provide your text, choose an embedding model, and adjust segmentation parameters (window size, pooling, threshold). The demo will split your text into coherent segments based on **semantic shifts**.
     """)
 
     with gr.Row():
@@ -210,7 +199,7 @@ The demo will split your text into coherent segments based on **semantic shifts*
             output_text = gr.Textbox(label="Output Text", placeholder="Chunks will appear here...", lines=22)
         with gr.Tab("Output Json"):
             output_json = gr.Json(label="Output Json", open=False, max_height=500)
-        with gr.Tab("Output Visualization"): #, elem_classes="custom-tab"):
+        with gr.Tab("Output Visualization"):
             output_fig = gr.Plot(label="Output Visualization")
 
     submit_button.click(text_segmentation, inputs=[input_text, model_name, k, pool, threshold], outputs=[output_text, output_json, output_fig])
@@ -233,4 +222,4 @@ if __name__ == '__main__':
     Path(CACHE_DIR).mkdir(exist_ok=True)
 
     # Launch the app
-    app.launch() # share=True)
+    app.launch()
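The demo description in the hunk above mentions the segmentation parameters exposed in the UI (window size k, pooling, threshold). As an illustration only, here is a minimal sketch of the general TextTiling-over-embeddings idea those parameters describe; the function name, defaults, and pooling logic below are assumptions made for this sketch, not the actual code in app.py.

# Illustrative sketch only (not from app.py): a TextTiling-style segmenter
# over sentence embeddings, parameterized like the demo UI.
import numpy as np
from sentence_transformers import SentenceTransformer

def tile_sentences(sents, model_name='all-MiniLM-L12-v2', k=2, pool='mean', threshold=0.5):
    """Return 0/1 labels per sentence; 1 marks the last sentence of a segment."""
    model = SentenceTransformer(model_name)
    emb = model.encode(sents, normalize_embeddings=True)  # shape (n_sents, dim)

    def pooled(window):
        # Pool a window of sentence embeddings into a single vector.
        return window.max(axis=0) if pool == 'max' else window.mean(axis=0)

    labels = []
    for i in range(len(sents) - 1):
        left = pooled(emb[max(0, i - k + 1): i + 1])   # window ending at sentence i
        right = pooled(emb[i + 1: i + 1 + k])          # window starting after it
        sim = float(np.dot(left, right) /
                    (np.linalg.norm(left) * np.linalg.norm(right) + 1e-8))
        labels.append(1 if sim < threshold else 0)     # low similarity => semantic shift
    labels.append(1)  # the final sentence always closes the last segment
    return labels

The trailing label of 1 mirrors the `preds = preds + [1]` line visible in output_segments above, which forces the last sentence to close a segment.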
util.py CHANGED
@@ -1,10 +1,7 @@
-import os
-
-
 ### NLTK ###
 import nltk
-if not os.path.exists(os.path.join(nltk.data.find('tokenizers'), 'punkt')):
-    nltk.download('punkt')
+
+nltk.download('punkt')
 
 def nltk_sent_tokenize(texts: list[str]):
     return (sent for text in texts for sent in nltk.sent_tokenize(text))
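The change above replaces the os.path based existence check with an unconditional nltk.download('punkt') at import time. A common alternative guard, shown here only as a hedged sketch and not part of this commit, probes for the resource first and downloads only when NLTK reports it missing:

# Hypothetical alternative to the unconditional download (not in this commit):
# probe for the punkt tokenizer and fetch it only if NLTK cannot find it.
import nltk

try:
    nltk.data.find('tokenizers/punkt')  # raises LookupError when punkt is absent
except LookupError:
    nltk.download('punkt')

Either way, nltk.download('punkt') skips re-downloading a package that is already present and up to date, so the unconditional call in the commit is effectively idempotent.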