dejanseo committed on
Commit
fdfad5d
·
verified ·
1 Parent(s): 8cbe5a9

Delete src

Browse files
src/.streamlit/config.toml DELETED
@@ -1,24 +0,0 @@
1
- [server]
2
- enableStaticServing = true
3
- toolbarMode = "viewer"
4
-
5
- [[theme.fontFaces]]
6
- family="montserrat-sans"
7
- url="app/static/Montserrat-Italic-VariableFont_wght.ttf"
8
- style="italic"
9
- weight=500
10
- [[theme.fontFaces]]
11
- family="montserrat-sans"
12
- url="app/static/Montserrat-VariableFont_wght.ttf"
13
- style="normal"
14
- weight=500
15
- [[theme.fontFaces]]
16
- family="noto-mono"
17
- url="app/static/NotoSansMono-VariableFont_wdth,wght.ttf"
18
-
19
- [theme]
20
- font="montserrat-sans, noto-sans, sans-serif"
21
- codeFont="noto-mono, monospace"
22
- baseFontSize=16
23
- primaryColor="#28a745"
24
- backgroundColor="#FFFFFF"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/app.py DELETED
@@ -1,298 +0,0 @@
1
- import streamlit as st
2
- import torch
3
- # Use AutoModel and AutoTokenizer for easier loading from Hub
4
- from transformers import AutoModelForTokenClassification, AutoTokenizer
5
- import numpy as np
6
- import logging
7
- from dataclasses import dataclass
8
- from typing import Optional, Dict, List, Tuple
9
-
10
- # --- HIDE STREAMLIT MENU ---
11
- st.set_page_config(
12
- initial_sidebar_state="collapsed"
13
- )
14
-
15
- hide_streamlit_style = """
16
- <style>
17
- #MainMenu {visibility: hidden;}
18
- </style>
19
- """
20
- st.markdown(hide_streamlit_style, unsafe_allow_html=True)
21
-
22
- st.logo(
23
- image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
24
- link="https://dejan.ai/",
25
- )
26
-
27
- # ----------------------------------
28
- # Logging
29
- # ----------------------------------
30
- logging.basicConfig(level=logging.INFO)
31
- logger = logging.getLogger(__name__)
32
-
33
- # ----------------------------------
34
- # Config
35
- # ----------------------------------
36
-
37
@dataclass
class AppConfig:
    """Configuration for the LinkBERT application.

    All fields have defaults, so ``AppConfig()`` yields the settings used by
    the app; individual values can be overridden by keyword.
    """

    # Hugging Face Hub repository id passed to ``from_pretrained``.
    model_name: str = "dejanseo/link-prediction"
    # Length of every inference window in tokens, special tokens and padding included.
    max_length: int = 512
    # Number of content tokens shared by two consecutive windows.
    doc_stride: int = 128
    # Evaluated once, when the class body is executed: "cuda" if a GPU is visible.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
45
-
46
- # ----------------------------------
47
- # Load model/tokenizer from Hugging Face Hub
48
- # ----------------------------------
49
-
50
@st.cache_resource
def load_model_from_hub():
    """Load the fine-tuned model and tokenizer from the Hugging Face Hub.

    The function is wrapped in ``st.cache_resource``. The model is moved to
    the configured device and switched to eval mode before being returned.

    Returns:
        A 5-tuple ``(model, tokenizer, device, max_length, doc_stride)`` where
        the last three values are copied from a default ``AppConfig``.
    """
    config = AppConfig()

    logger.info(
        "Loading model and tokenizer from Hugging Face Hub: %s...", config.model_name
    )

    model = AutoModelForTokenClassification.from_pretrained(config.model_name)
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    model.to(config.device)
    model.eval()

    logger.info("Model and tokenizer loaded successfully.")
    return model, tokenizer, config.device, config.max_length, config.doc_stride
66
-
67
- # <<< CHANGE 3: Call the new loading function >>>
68
- model, tokenizer, device, MAX_LENGTH, DOC_STRIDE = load_model_from_hub()
69
-
70
-
71
- # ----------------------------------
72
- # Inference helpers
73
- # ----------------------------------
74
-
75
def windowize_inference(
    plain_text: str, tokenizer: AutoTokenizer, max_length: int, doc_stride: int
) -> List[Dict]:
    """Slice long text into overlapping, fixed-size windows for inference.

    The text is tokenized once without special tokens; each window then takes
    up to ``max_length - specials`` content tokens, is wrapped with the
    tokenizer's special tokens and padded to ``max_length``.

    Args:
        plain_text: Text to tokenize.
        tokenizer: Tokenizer providing ``word_ids()`` and offset mappings
            (i.e. a "fast" tokenizer).
        max_length: Length of every produced window, specials and padding included.
        doc_stride: Number of content tokens two consecutive windows share.

    Returns:
        One dict per window with keys ``input_ids`` and ``attention_mask``
        (``torch.long`` tensors of length ``max_length``), ``word_ids``
        (word index per position, ``None`` for special/padding positions) and
        ``offset_mapping`` (``(start, end)`` character span per position,
        ``(0, 0)`` for special/padding positions). Empty list if the tokenizer
        yields no tokens.

    Raises:
        ValueError: If ``max_length`` leaves no room for content tokens.
    """
    specials = tokenizer.num_special_tokens_to_add(pair=False)
    cap = max_length - specials
    if cap <= 0:
        raise ValueError(
            f"max_length ({max_length}) must be larger than the "
            f"{specials} special tokens added to every window."
        )

    # A single tokenization without special tokens gives ids, offsets and
    # word ids that are aligned index-for-index, so they can be sliced together.
    full_encoding = tokenizer(
        plain_text, add_special_tokens=False, return_offsets_mapping=True, truncation=False
    )
    all_ids = full_encoding["input_ids"]
    all_offsets = full_encoding["offset_mapping"]
    all_word_ids = full_encoding.word_ids(batch_index=0)
    total_tokens = len(all_ids)

    if total_tokens == 0:
        if plain_text:
            logger.warning("Tokenizer produced 0 tokens for a non-empty string.")
        return []

    # max() keeps the window advancing even if doc_stride >= cap.
    step = max(cap - doc_stride, 1)
    windows_data = []

    for start_token_idx in range(0, total_tokens, step):
        end_token_idx = min(start_token_idx + cap, total_tokens)
        ids_slice = all_ids[start_token_idx:end_token_idx]

        input_ids = tokenizer.build_inputs_with_special_tokens(ids_slice)
        padding_length = max_length - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * padding_length
        input_ids = input_ids + [tokenizer.pad_token_id] * padding_length

        # Special and padding positions carry no text: give them an empty
        # (0, 0) span and a None word id so consumers can skip them.
        # NOTE(review): assumes exactly one leading special token (e.g. [CLS]);
        # confirm if a tokenizer with a different layout is used.
        trailing = max_length - 1 - len(ids_slice)
        window_offset_mapping = (
            [(0, 0)]
            + [tuple(span) for span in all_offsets[start_token_idx:end_token_idx]]
            + [(0, 0)] * trailing
        )
        window_word_ids = (
            [None] + list(all_word_ids[start_token_idx:end_token_idx]) + [None] * trailing
        )

        windows_data.append({
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "word_ids": window_word_ids[:max_length],
            "offset_mapping": window_offset_mapping[:max_length],
        })

        if end_token_idx >= total_tokens:
            break

    return windows_data
132
-
133
def classify_text(
    text: str, prediction_threshold_percent: float
) -> Tuple[str, Optional[str]]:
    """Classify link tokens with windowing and render the result as HTML.

    Each window is run through the module-level ``model``; the probability of
    class index 1 is projected onto the characters each token covers (maximum
    over overlapping windows), then aggregated per word.

    Args:
        text: Plain text to analyse.
        prediction_threshold_percent: Minimum probability, in percent, for a
            word to be highlighted.

    Returns:
        ``(html, warning)``. ``html`` is the input text with highlighted words
        wrapped in styled ``<span>`` elements; ``warning`` is ``None`` on
        success, otherwise a message (and ``html`` is empty).
    """
    # Local import: stdlib only, keeps this block independent of the file header.
    from html import escape

    if not text.strip():
        return "", "Input text is empty."

    windows = windowize_inference(text, tokenizer, MAX_LENGTH, DOC_STRIDE)
    if not windows:
        return "", "Could not generate any windows for processing."

    char_link_probabilities = np.zeros(len(text), dtype=np.float32)

    with torch.no_grad():
        for window in windows:
            inputs = {
                "input_ids": window["input_ids"].unsqueeze(0).to(device),
                "attention_mask": window["attention_mask"].unsqueeze(0).to(device),
            }
            outputs = model(**inputs)
            probabilities = torch.softmax(outputs.logits, dim=-1).squeeze(0)
            # NOTE(review): assumes label index 1 is the "link" class — confirm
            # against the model's id2label mapping.
            link_probs = probabilities[:, 1].cpu().numpy()

            for i, span in enumerate(window["offset_mapping"]):
                # Check the word id BEFORE unpacking: special/padding positions
                # have no word id and their offset entry may not be a 2-tuple.
                if window["word_ids"][i] is None:
                    continue
                start, end = span
                if start < end:
                    char_link_probabilities[start:end] = np.maximum(
                        char_link_probabilities[start:end], link_probs[i]
                    )

    final_threshold = prediction_threshold_percent / 100.0

    # Tokenize the whole text once more to map tokens to words and characters.
    full_encoding = tokenizer(text, return_offsets_mapping=True, truncation=False)
    word_ids = full_encoding.word_ids(batch_index=0)
    offsets = full_encoding["offset_mapping"]

    word_max_prob_map: Dict[int, float] = {}
    word_char_spans: Dict[int, List[int]] = {}

    for word_id, (start_char, end_char) in zip(word_ids, offsets):
        if word_id is None or start_char >= end_char:
            continue
        # The slice is non-empty here, so np.max is always defined.
        token_max_prob = np.max(char_link_probabilities[start_char:end_char])
        if word_id in word_max_prob_map:
            word_max_prob_map[word_id] = max(word_max_prob_map[word_id], token_max_prob)
            word_char_spans[word_id][1] = end_char  # extend the word's span
        else:
            word_max_prob_map[word_id] = token_max_prob
            word_char_spans[word_id] = [start_char, end_char]

    highlight_candidates: Dict[int, float] = {
        word_id: max_prob
        for word_id, max_prob in word_max_prob_map.items()
        if max_prob >= final_threshold
    }
    max_highlight_prob = max(highlight_candidates.values(), default=0.0)

    base_bg_color = "#D4EDDA"
    base_text_color = "#155724"

    html_parts: List[str] = []
    current_char = 0
    sorted_word_ids = sorted(word_char_spans, key=lambda k: word_char_spans[k][0])

    for word_id in sorted_word_ids:
        start_char, end_char = word_char_spans[word_id]

        # Text between the previous word and this one (whitespace, punctuation).
        # All user text is escaped because the result is rendered as raw HTML.
        if start_char > current_char:
            html_parts.append(escape(text[current_char:start_char], quote=False))

        word_text = escape(text[start_char:end_char], quote=False)

        if word_id in highlight_candidates:
            word_prob = highlight_candidates[word_id]
            # Scale opacity into 0.1-1.0 relative to the strongest candidate.
            normalized_opacity = 1.0
            if max_highlight_prob > 0:
                normalized_opacity = (word_prob / max_highlight_prob) * 0.9 + 0.1

            html_parts.append(f"<span style='background-color: {base_bg_color}; color: {base_text_color}; "
                              f"padding: 0.1em 0.2em; border-radius: 0.2em; opacity: {normalized_opacity:.2f};'>"
                              f"{word_text}</span>")
        else:
            html_parts.append(word_text)
        current_char = end_char

    if current_char < len(text):
        html_parts.append(escape(text[current_char:], quote=False))

    return "".join(html_parts), None
223
-
224
- # ----------------------------------
225
- # Streamlit UI (No changes needed from here down)
226
- # ----------------------------------
227
- st.set_page_config(layout="wide", page_title="LinkBERT by DEJAN AI")
228
- st.title("LinkBERT")
229
-
230
- DEFAULT_THRESHOLD = 70.0
231
- THRESHOLD_STEP = 10.0
232
- THRESHOLD_BOUNDARY_PERCENT = 10.0
233
-
234
- if 'current_threshold' not in st.session_state:
235
- st.session_state.current_threshold = DEFAULT_THRESHOLD
236
- if 'output_html' not in st.session_state:
237
- st.session_state.output_html = ""
238
- if 'user_input' not in st.session_state:
239
- st.session_state.user_input = "DEJAN AI is the world's leading AI SEO agency. This tool showcases the capability of our latest link prediction model called LinkBERT. This model is trained on the highest quality organic link data and can predict natural link placement in plain text."
240
-
241
- user_input = st.text_area(
242
- "Paste your text here:",
243
- st.session_state.user_input,
244
- height=200,
245
- key="text_area"
246
- )
247
-
248
- with st.expander('Settings'):
249
- slider_threshold = st.slider(
250
- "Link Probability Threshold (%)",
251
- min_value=0, max_value=100, value=int(st.session_state.current_threshold), step=1,
252
- help="The minimum probability for a word to be considered a link candidate."
253
- )
254
-
255
def run_classification(new_threshold: float):
    """Classify the current text-area content at ``new_threshold`` percent.

    Stores the threshold, the input text and the resulting HTML in
    ``st.session_state``. On success the script is rerun so the display picks
    up the new state; when there is a warning to show, the rerun is skipped
    because rerunning immediately would discard the message.

    Args:
        new_threshold: Link-probability threshold in percent.
    """
    st.session_state.current_threshold = float(new_threshold)
    st.session_state.user_input = user_input

    if not st.session_state.user_input.strip():
        st.session_state.output_html = ""
        st.warning("Please enter some text to classify.")
        return

    with st.spinner("Processing..."):
        html, warning = classify_text(
            st.session_state.user_input, st.session_state.current_threshold
        )
        st.session_state.output_html = html

    if warning:
        st.warning(warning)
        return

    st.rerun()
267
-
268
- if st.button("Classify Text", type="primary"):
269
- run_classification(slider_threshold)
270
-
271
- if st.session_state.output_html:
272
- st.markdown("---")
273
- st.subheader(f"Results (Threshold: {st.session_state.current_threshold:.1f}%)")
274
- st.markdown(st.session_state.output_html, unsafe_allow_html=True)
275
-
276
- col1, col2, col3 = st.columns(3)
277
-
278
- with col1:
279
- if st.button("Less", icon="➖", use_container_width=True, disabled=not st.session_state.output_html):
280
- current_thr = st.session_state.current_threshold
281
- if current_thr >= (100.0 - THRESHOLD_BOUNDARY_PERCENT):
282
- new_threshold = current_thr + (100.0 - current_thr) / 2.0
283
- else:
284
- new_threshold = current_thr + THRESHOLD_STEP
285
- run_classification(min(100.0, new_threshold))
286
-
287
- with col2:
288
- if st.button("Default", icon="🔄", use_container_width=True, disabled=not st.session_state.output_html):
289
- run_classification(DEFAULT_THRESHOLD)
290
-
291
- with col3:
292
- if st.button("More", icon="➕", use_container_width=True, disabled=not st.session_state.output_html):
293
- current_thr = st.session_state.current_threshold
294
- if current_thr <= THRESHOLD_BOUNDARY_PERCENT:
295
- new_threshold = current_thr / 2.0
296
- else:
297
- new_threshold = current_thr - THRESHOLD_STEP
298
- run_classification(max(0.0, new_threshold))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/packages.txt DELETED
@@ -1,3 +0,0 @@
1
- build-essential
2
- curl
3
- git
 
 
 
 
src/static/Montserrat-Italic-VariableFont_wght.ttf DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:acaff344a059669be7699d869c923bd5bb194973dc23748074f3f21deb1452dd
3
- size 701156
 
 
 
 
src/static/Montserrat-VariableFont_wght.ttf DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e12288c03e4fa3721aca7ca984f25c042089dc3590e207c43a57199d7b4a5cdb
3
- size 688600
 
 
 
 
src/static/NotoSans-Italic-VariableFont_wdth,wght.ttf DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4bf7b366af79c434984d67eae3967e9cd7a2f51c196101c43f21a7e21e608844
3
- size 2300468
 
 
 
 
src/static/NotoSans-VariableFont_wdth,wght.ttf DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7b8cac46a1c86d2533a616b1fcf4e1926b8e39bda69034508b0df96791f56d97
3
- size 2044548
 
 
 
 
src/static/NotoSansMono-VariableFont_wdth,wght.ttf DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:47d856141683450ee297592b27be447eb5141c68516e5e0e748c66b6e0a54afe
3
- size 1707908
 
 
 
 
src/streamlit_app.py DELETED
@@ -1,328 +0,0 @@
1
- import streamlit as st
2
- import torch
3
- import torch.nn as nn
4
- # Use AutoModel and AutoTokenizer for easier loading from Hub
5
- from transformers import AutoModelForTokenClassification, AutoTokenizer
6
- from pathlib import Path
7
- import numpy as np
8
- import logging
9
- from dataclasses import dataclass
10
- from typing import Optional, Dict, List, Tuple
11
-
12
- # --- HIDE STREAMLIT MENU ---
13
- st.set_page_config(
14
- initial_sidebar_state="collapsed"
15
- )
16
-
17
- hide_streamlit_style = """
18
- <style>
19
- #MainMenu {visibility: hidden;}
20
- </style>
21
- """
22
- st.markdown(hide_streamlit_style, unsafe_allow_html=True)
23
-
24
- st.logo(
25
- image="https://dejan.ai/wp-content/uploads/2024/02/dejan-300x103.png",
26
- link="https://dejan.ai/",
27
- )
28
-
29
- # ----------------------------------
30
- # Logging
31
- # ----------------------------------
32
- logging.basicConfig(level=logging.INFO)
33
- logger = logging.getLogger(__name__)
34
-
35
- # ----------------------------------
36
- # Config
37
- # ----------------------------------
38
-
39
@dataclass
class AppConfig:
    """Configuration for the LinkBERT application.

    All fields have defaults, so ``AppConfig()`` yields the settings used by
    the app; individual values can be overridden by keyword.
    """

    # Hugging Face Hub repository id passed to ``from_pretrained``.
    model_name: str = "dejanseo/link-prediction"
    # Length of every inference window in tokens, special tokens and padding included.
    max_length: int = 512
    # Number of content tokens shared by two consecutive windows.
    doc_stride: int = 128
    # Evaluated once, when the class body is executed: "cuda" if a GPU is visible.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
47
-
48
- # ----------------------------------
49
- # Load model/tokenizer from Hugging Face Hub
50
- # ----------------------------------
51
-
52
@st.cache_resource
def load_model_from_hub():
    """Load the fine-tuned model and tokenizer from the Hugging Face Hub.

    The function is wrapped in ``st.cache_resource``. Both objects are loaded
    through the ``Auto*`` classes from the repository named in ``AppConfig``;
    the model is moved to the configured device and switched to eval mode.

    Returns:
        A 5-tuple ``(model, tokenizer, device, max_length, doc_stride)`` where
        the last three values are copied from a default ``AppConfig``.
    """
    config = AppConfig()

    logger.info(
        "Loading model and tokenizer from Hugging Face Hub: %s...", config.model_name
    )

    model = AutoModelForTokenClassification.from_pretrained(config.model_name)
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    model.to(config.device)
    model.eval()

    logger.info("Model and tokenizer loaded successfully.")
    return model, tokenizer, config.device, config.max_length, config.doc_stride
71
-
72
- # <<< CHANGE 3: Call the new loading function >>>
73
- model, tokenizer, device, MAX_LENGTH, DOC_STRIDE = load_model_from_hub()
74
-
75
-
76
- # ----------------------------------
77
- # Inference helpers (No changes needed here, but forward pass output is slightly different)
78
- # ----------------------------------
79
-
80
def windowize_inference(
    plain_text: str, tokenizer: AutoTokenizer, max_length: int, doc_stride: int
) -> List[Dict]:
    """Slice long text into overlapping, fixed-size windows for inference.

    The text is tokenized once without special tokens; each window then takes
    up to ``max_length - specials`` content tokens, is wrapped with the
    tokenizer's special tokens and padded to ``max_length``.

    Args:
        plain_text: Text to tokenize.
        tokenizer: Tokenizer providing ``word_ids()`` and offset mappings
            (i.e. a "fast" tokenizer).
        max_length: Length of every produced window, specials and padding included.
        doc_stride: Number of content tokens two consecutive windows share.

    Returns:
        One dict per window with keys ``input_ids`` and ``attention_mask``
        (``torch.long`` tensors of length ``max_length``), ``word_ids``
        (word index per position, ``None`` for special/padding positions) and
        ``offset_mapping`` (``(start, end)`` character span per position,
        ``(0, 0)`` for special/padding positions). Empty list if the tokenizer
        yields no tokens.

    Raises:
        ValueError: If ``max_length`` leaves no room for content tokens.
    """
    specials = tokenizer.num_special_tokens_to_add(pair=False)
    cap = max_length - specials
    if cap <= 0:
        raise ValueError(
            f"max_length ({max_length}) must be larger than the "
            f"{specials} special tokens added to every window."
        )

    # word_ids() is a method of the encoding returned by the tokenizer call, so
    # the same special-token-free encoding supplies ids, offsets and word ids,
    # all aligned index-for-index.
    full_encoding = tokenizer(
        plain_text, add_special_tokens=False, return_offsets_mapping=True, truncation=False
    )
    all_ids = full_encoding["input_ids"]
    all_offsets = full_encoding["offset_mapping"]
    all_word_ids = full_encoding.word_ids(batch_index=0)
    total_tokens = len(all_ids)

    if total_tokens == 0:
        if plain_text:
            logger.warning("Tokenizer produced 0 tokens for a non-empty string.")
        return []

    # max() prevents a zero or negative step if doc_stride >= cap.
    step = max(cap - doc_stride, 1)
    windows_data = []

    for start_token_idx in range(0, total_tokens, step):
        end_token_idx = min(start_token_idx + cap, total_tokens)
        ids_slice = all_ids[start_token_idx:end_token_idx]

        input_ids = tokenizer.build_inputs_with_special_tokens(ids_slice)
        padding_length = max_length - len(input_ids)
        attention_mask = [1] * len(input_ids) + [0] * padding_length
        input_ids = input_ids + [tokenizer.pad_token_id] * padding_length

        # build_inputs_with_special_tokens works on token ids, so it must not be
        # fed offsets; special and padding positions get an empty (0, 0) span
        # and a None word id instead.
        # NOTE(review): assumes exactly one leading special token (e.g. [CLS]);
        # confirm if a tokenizer with a different layout is used.
        trailing = max_length - 1 - len(ids_slice)
        window_offset_mapping = (
            [(0, 0)]
            + [tuple(span) for span in all_offsets[start_token_idx:end_token_idx]]
            + [(0, 0)] * trailing
        )
        window_word_ids = (
            [None] + list(all_word_ids[start_token_idx:end_token_idx]) + [None] * trailing
        )

        windows_data.append({
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "word_ids": window_word_ids[:max_length],
            "offset_mapping": window_offset_mapping[:max_length],
        })

        if end_token_idx >= total_tokens:
            break

    return windows_data
137
-
138
def classify_text(
    text: str, prediction_threshold_percent: float
) -> Tuple[str, Optional[str]]:
    """Classify link tokens with windowing and render the result as HTML.

    Each window is run through the module-level ``model``; the probability of
    class index 1 is projected onto the characters each token covers (maximum
    over overlapping windows), then aggregated per word.

    Args:
        text: Plain text to analyse.
        prediction_threshold_percent: Minimum probability, in percent, for a
            word to be highlighted.

    Returns:
        ``(html, warning)``. ``html`` is the input text with highlighted words
        wrapped in styled ``<span>`` elements; ``warning`` is ``None`` on
        success, otherwise a message (and ``html`` is empty).
    """
    # Local import: stdlib only, keeps this block independent of the file header.
    from html import escape

    if not text.strip():
        return "", "Input text is empty."

    windows = windowize_inference(text, tokenizer, MAX_LENGTH, DOC_STRIDE)
    if not windows:
        return "", "Could not generate any windows for processing."

    char_link_probabilities = np.zeros(len(text), dtype=np.float32)

    with torch.no_grad():
        for window in windows:
            inputs = {
                "input_ids": window["input_ids"].unsqueeze(0).to(device),
                "attention_mask": window["attention_mask"].unsqueeze(0).to(device),
            }
            outputs = model(**inputs)
            probabilities = torch.softmax(outputs.logits, dim=-1).squeeze(0)
            # Per the original comment, label 1 means "is a link".
            # NOTE(review): confirm against the model's id2label mapping.
            link_probs = probabilities[:, 1].cpu().numpy()

            for i, span in enumerate(window["offset_mapping"]):
                # Check the word id BEFORE unpacking: special/padding positions
                # have no word id and their offset entry may not be a 2-tuple.
                if window["word_ids"][i] is None:
                    continue
                start, end = span
                if start < end:
                    # Aggregate per-character probabilities using the maximum.
                    char_link_probabilities[start:end] = np.maximum(
                        char_link_probabilities[start:end], link_probs[i]
                    )

    final_threshold = prediction_threshold_percent / 100.0

    # Tokenize once to get word ids and offsets for the full text.
    full_encoding = tokenizer(text, return_offsets_mapping=True, truncation=False)
    word_ids = full_encoding.word_ids(batch_index=0)
    offsets = full_encoding["offset_mapping"]

    word_max_prob_map: Dict[int, float] = {}
    word_char_spans: Dict[int, List[int]] = {}

    for word_id, (start_char, end_char) in zip(word_ids, offsets):
        if word_id is None or start_char >= end_char:
            continue
        # Max probability over the characters spanned by this token.
        token_max_prob = np.max(char_link_probabilities[start_char:end_char])
        if word_id in word_max_prob_map:
            word_max_prob_map[word_id] = max(word_max_prob_map[word_id], token_max_prob)
            word_char_spans[word_id][1] = end_char  # extend the word's span
        else:
            word_max_prob_map[word_id] = token_max_prob
            word_char_spans[word_id] = [start_char, end_char]

    # Words that meet the threshold, and the strongest one for normalization.
    highlight_candidates: Dict[int, float] = {
        word_id: max_prob
        for word_id, max_prob in word_max_prob_map.items()
        if max_prob >= final_threshold
    }
    max_highlight_prob = max(highlight_candidates.values(), default=0.0)

    base_bg_color = "#D4EDDA"  # light green
    base_text_color = "#155724"  # dark green

    html_parts: List[str] = []
    current_char = 0
    # Rebuild the text in reading order, word by word.
    sorted_word_ids = sorted(word_char_spans, key=lambda k: word_char_spans[k][0])

    for word_id in sorted_word_ids:
        start_char, end_char = word_char_spans[word_id]

        # Text between the previous word and this one (whitespace, punctuation).
        # All user text is escaped because the result is rendered as raw HTML.
        if start_char > current_char:
            html_parts.append(escape(text[current_char:start_char], quote=False))

        word_text = escape(text[start_char:end_char], quote=False)

        if word_id in highlight_candidates:
            word_prob = highlight_candidates[word_id]
            # Scale opacity into 0.1-1.0 relative to the strongest candidate.
            normalized_opacity = 1.0
            if max_highlight_prob > 0:
                normalized_opacity = (word_prob / max_highlight_prob) * 0.9 + 0.1

            html_parts.append(f"<span style='background-color: {base_bg_color}; color: {base_text_color}; "
                              f"padding: 0.1em 0.2em; border-radius: 0.2em; opacity: {normalized_opacity:.2f};'>"
                              f"{word_text}</span>")
        else:
            html_parts.append(word_text)
        current_char = end_char

    # Any remaining text after the last word.
    if current_char < len(text):
        html_parts.append(escape(text[current_char:], quote=False))

    return "".join(html_parts), None
240
-
241
-
242
- # ----------------------------------
243
- # Streamlit UI (No changes needed from here down)
244
- # ----------------------------------
245
- st.set_page_config(layout="wide", page_title="LinkBERT by DEJAN AI")
246
- st.title("LinkBERT")
247
-
248
- DEFAULT_THRESHOLD = 70.0
249
- THRESHOLD_STEP = 10.0
250
- THRESHOLD_BOUNDARY_PERCENT = 10.0 # Top/Bottom 10% for half-way logic
251
-
252
- # Initialize session state for threshold and output
253
- if 'current_threshold' not in st.session_state:
254
- st.session_state.current_threshold = DEFAULT_THRESHOLD
255
- if 'output_html' not in st.session_state:
256
- st.session_state.output_html = ""
257
- if 'user_input' not in st.session_state:
258
- st.session_state.user_input = "DEJAN AI is the world's leading AI SEO agency. This tool showcases the capability of our latest link prediction model called LinkBERT. This model is trained on the highest quality organic link data and can predict natural link placement in plain text."
259
-
260
- # --- UI Controls ---
261
- user_input = st.text_area(
262
- "Paste your text here:",
263
- st.session_state.user_input,
264
- height=200,
265
- key="text_area"
266
- )
267
-
268
- with st.expander('Settings'):
269
- slider_threshold = st.slider(
270
- "Link Probability Threshold (%)",
271
- min_value=0, max_value=100, value=int(st.session_state.current_threshold), step=1,
272
- help="The minimum probability for a word to be considered a link candidate."
273
- )
274
-
275
- # --- Classification Function (re-run logic) ---
276
def run_classification(new_threshold: float):
    """Classify the current text-area content at ``new_threshold`` percent.

    Stores the threshold, the input text and the resulting HTML in
    ``st.session_state``. On success the script is rerun so the display picks
    up the new state; when there is a warning to show, the rerun is skipped
    because rerunning immediately would discard the message.

    Args:
        new_threshold: Link-probability threshold in percent.
    """
    st.session_state.current_threshold = float(new_threshold)
    st.session_state.user_input = user_input  # ensure the latest input is used

    if not st.session_state.user_input.strip():
        st.session_state.output_html = ""
        st.warning("Please enter some text to classify.")
        return

    with st.spinner("Processing..."):
        html, warning = classify_text(
            st.session_state.user_input, st.session_state.current_threshold
        )
        st.session_state.output_html = html

    if warning:
        st.warning(warning)
        return

    st.rerun()  # rerun to update the display immediately
288
-
289
-
290
- # --- Main Classify Button ---
291
- if st.button("Classify Text", type="primary"):
292
- run_classification(slider_threshold)
293
-
294
- # --- Display Output ---
295
- if st.session_state.output_html:
296
- st.markdown("---")
297
- st.subheader(f"Results (Threshold: {st.session_state.current_threshold:.1f}%)")
298
- st.markdown(st.session_state.output_html, unsafe_allow_html=True)
299
-
300
- # --- Adjustment Buttons ---
301
- col1, col2, col3 = st.columns(3)
302
-
303
- with col1:
304
- if st.button("Less", icon=":material/playlist_remove:", use_container_width=True, disabled=not st.session_state.output_html):
305
- current_thr = st.session_state.current_threshold
306
- if current_thr >= (100.0 - THRESHOLD_BOUNDARY_PERCENT): # In top 10% (90-100)
307
- new_threshold = current_thr + (100.0 - current_thr) / 2.0
308
- new_threshold = min(100.0, new_threshold) # Ensure not more than 100
309
- else:
310
- new_threshold = current_thr + THRESHOLD_STEP
311
- new_threshold = min(100.0 - THRESHOLD_BOUNDARY_PERCENT, new_threshold) # Don't step into deep half-way zone too soon
312
- run_classification(new_threshold)
313
-
314
-
315
- with col2:
316
- if st.button("Default", icon=":material/notes:", use_container_width=True, disabled=not st.session_state.output_html):
317
- run_classification(DEFAULT_THRESHOLD)
318
-
319
- with col3:
320
- if st.button("More", icon=":material/docs_add_on:", use_container_width=True, disabled=not st.session_state.output_html):
321
- current_thr = st.session_state.current_threshold
322
- if current_thr <= THRESHOLD_BOUNDARY_PERCENT: # In bottom 10% (0-10)
323
- new_threshold = current_thr / 2.0
324
- new_threshold = max(0.0, new_threshold) # Ensure not less than 0
325
- else:
326
- new_threshold = current_thr - THRESHOLD_STEP
327
- new_threshold = max(THRESHOLD_BOUNDARY_PERCENT, new_threshold) # Don't step into deep half-way zone too soon
328
- run_classification(new_threshold)