dejanseo committed (verified)
Commit: a92f9e3 · Parent: df3962f

Update src/streamlit_app.py

Files changed (1)
  src/streamlit_app.py: +39 -76
src/streamlit_app.py CHANGED
@@ -8,28 +8,21 @@ import trafilatura
 # Streamlit config
 st.set_page_config(layout="wide", page_title="LinkBERT")
 
-# Load tokenizer & model
+# Model setup (load fully to avoid meta tensors)
 MODEL_ID = "dejanseo/LinkBERT-XL"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
-
-# Determine the device
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Load the model directly to the determined device
-# Avoid device_map="auto" if it's causing meta tensor issues with certain torch/transformers versions.
-# Load to CPU first, then move to GPU if available.
-model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)
-
-# Explicitly move model to the determined device and dtype
-if device == "cuda":
-    model.half().to(device) # Use .half() for float16 on GPU
-else:
-    model.to(device) # For CPU, typically stick to float32 unless model was specifically trained with bfloat16 for CPU
-
+# Force materialized tensors on CPU, then move — avoids meta tensors
+model = AutoModelForTokenClassification.from_pretrained(
+    MODEL_ID,
+    low_cpu_mem_usage=False, # important: materialize weights
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+)
+model.to(device)
 model.eval()
 
-# Functions (rest of your functions remain mostly the same)
+# Functions
 def tokenize_with_indices(text: str):
     encoded = tokenizer.encode_plus(
         text,
@@ -48,14 +41,15 @@ def fetch_and_extract_content(url: str):
     return None
 
 def process_text(inputs: str, confidence_threshold: float):
-    max_chunk_length = 512 - 2 # safe window for special tokens
+    max_chunk_length = 512 - 2 # leave room for specials
     words = inputs.split()
     chunk_texts = []
     current_chunk, current_length = [], 0
     for word in words:
         tok_len = len(tokenizer.tokenize(word))
         if tok_len + current_length > max_chunk_length:
-            chunk_texts.append(" ".join(current_chunk))
+            if current_chunk:
+                chunk_texts.append(" ".join(current_chunk))
             current_chunk = [word]
             current_length = tok_len
         else:
@@ -71,85 +65,61 @@ def process_text(inputs: str, confidence_threshold: float):
     with torch.no_grad():
         for chunk in chunk_texts:
             input_ids, token_offsets = tokenize_with_indices(chunk)
-            # Ensure input_ids_tensor is on the same device as the model
-            input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).to(model.device)
+            # Build tensors on correct device; no meta usage
+            input_ids_tensor = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
 
             outputs = model(input_ids_tensor)
             logits = outputs.logits # [1, seq_len, num_labels]
             predictions = torch.argmax(logits, dim=-1).squeeze(0).tolist()
             softmax_scores = F.softmax(logits, dim=-1).squeeze(0).tolist()
 
-            # The rest of your processing logic
             word_info = {}
             for idx, (start, end) in enumerate(token_offsets):
                 if idx == 0 or idx == len(token_offsets) - 1:
-                    continue # skip specials
+                    continue # skip special tokens
 
                 word_start = start
-                # Find the actual start of the word corresponding to this token
-                # This logic assumes space-separated words for the most part
-                while word_start > 0 and chunk[word_start - 1] not in [' ', '\n', '\t']:
+                while word_start > 0 and chunk[word_start - 1] != ' ':
                     word_start -= 1
-                # If a word_start maps to multiple tokens (e.g., "don't" -> ["don", "'", "t"])
-                # ensure we pick the earliest start for that conceptual word
-                while word_start > 0 and (chunk[word_start-1:word_start] == ' ' or tokenizer.decode(tokenizer.encode(chunk[word_start-1:end], add_special_tokens=False))[0] == chunk[word_start-1]):
-                    word_start -= 1
 
-                # Use a tuple (word_start, actual_word_text_from_chunk) as key for more robust aggregation
-                # For simplicity here, we stick to word_start
                 if word_start not in word_info:
-                    # Initialize with default for "not link"
                     word_info[word_start] = {"prediction": 0, "confidence": 0.0, "subtokens": []}
 
                 conf_pct = softmax_scores[idx][predictions[idx]] * 100.0
-
-                # Only mark as 1 if the current token's prediction is 1 AND confidence meets threshold
                 if predictions[idx] == 1 and conf_pct >= confidence_threshold:
-                    word_info[word_start]["prediction"] = 1 # Mark the whole 'word' as a link
-
-                # Keep the max confidence for any token within the 'word'
+                    word_info[word_start]["prediction"] = 1
                 word_info[word_start]["confidence"] = max(word_info[word_start]["confidence"], conf_pct)
                 word_info[word_start]["subtokens"].append((start, end, chunk[start:end]))
 
             last_end = 0
-            # Sort by word_start to maintain order
             for word_start in sorted(word_info.keys()):
                 word_data = word_info[word_start]
-                # Sort subtokens to ensure they are processed in order within a word
-                for subtoken_start, subtoken_end, subtoken_text in sorted(word_data["subtokens"], key=lambda x: x[0]):
+                for subtoken_start, subtoken_end, subtoken_text in word_data["subtokens"]:
                     escaped = subtoken_text.replace("$", "\\$")
-                    # Add any text between the last processed token and the current one
                     if last_end < subtoken_start:
                         reconstructed_text += chunk[last_end:subtoken_start]
-
                     if word_data["prediction"] == 1:
-                        # Apply highlight to the subtoken
                         reconstructed_text += (
-                            f"<span style='background-color: rgba(0, 255, 0, 0.5); display: inline;'>{escaped}</span>" # Added alpha for better readability
+                            f"<span style='background-color: rgba(0, 255, 0); display: inline;'>{escaped}</span>"
                         )
                     else:
-                        reconstructed_text += escaped # No highlight
-
+                        reconstructed_text += escaped
                     last_end = subtoken_end
 
-                    # For DataFrame, append the info for each *subtoken*
                     df_data["Word"].append(escaped)
-                    df_data["Prediction"].append(word_data["prediction"]) # Prediction applies to the whole conceptual word
-                    df_data["Confidence"].append(word_data["confidence"]) # Confidence applies to the whole conceptual word
+                    df_data["Prediction"].append(word_data["prediction"])
+                    df_data["Confidence"].append(word_info[word_start]["confidence"])
                     df_data["Start"].append(subtoken_start + original_position_offset)
                     df_data["End"].append(subtoken_end + original_position_offset)
 
-            # Add any remaining text from the current chunk after the last token
-            if last_end < len(chunk):
-                reconstructed_text += chunk[last_end:].replace("$", "\\$")
-
-            # Update offset for the next chunk. Add 1 for the space that was implicitly there.
             original_position_offset += len(chunk) + 1
+
+            reconstructed_text += chunk[last_end:].replace("$", "\\$")
 
     df_tokens = pd.DataFrame(df_data)
     return reconstructed_text, df_tokens
 
-# UI (remains the same)
+# UI
 st.title("LinkBERT")
 st.markdown("""
 LinkBERT predicts natural link placement within web content. Enter text or a URL for extraction. Increase the threshold to reduce link predictions.
@@ -160,29 +130,22 @@ confidence_threshold = st.slider("Confidence Threshold", 50, 100, 50)
 tab1, tab2 = st.tabs(["Text Input", "URL Input"])
 
 with tab1:
-    user_input = st.text_area("Enter text to process:", height=200) # Added height for better UX
+    user_input = st.text_area("Enter text to process:")
     if st.button("Process Text"):
-        if user_input: # Ensure input is not empty
-            highlighted_text, df_tokens = process_text(user_input, confidence_threshold)
-            st.markdown(highlighted_text, unsafe_allow_html=True)
-            st.dataframe(df_tokens)
-        else:
-            st.warning("Please enter some text to process.")
+        highlighted_text, df_tokens = process_text(user_input, confidence_threshold)
+        st.markdown(highlighted_text, unsafe_allow_html=True)
+        st.dataframe(df_tokens)
 
 with tab2:
-    url_input = st.text_input("Enter URL to process:", value="https://dejan.ai/blog/gpt-5-made-seo-irreplaceable/") # Pre-fill with example
+    url_input = st.text_input("Enter URL to process:")
    if st.button("Fetch and Process"):
-        if url_input: # Ensure URL input is not empty
-            with st.spinner("Fetching and processing content..."):
-                content = fetch_and_extract_content(url_input)
-                if content:
-                    highlighted_text, df_tokens = process_text(content, confidence_threshold)
-                    st.markdown(highlighted_text, unsafe_allow_html=True)
-                    st.dataframe(df_tokens)
-                else:
-                    st.error("Could not fetch content from the URL. Please check the URL and try again.")
+        content = fetch_and_extract_content(url_input)
+        if content:
+            highlighted_text, df_tokens = process_text(content, confidence_threshold)
+            st.markdown(highlighted_text, unsafe_allow_html=True)
+            st.dataframe(df_tokens)
         else:
-            st.warning("Please enter a URL to process.")
+            st.error("Could not fetch content from the URL. Please check the URL and try again.")
 
 st.divider()
 st.markdown("""
@@ -202,4 +165,4 @@ LinkBERT was fine-tuned on a dataset of organic web content and editorial links.
 Interested in using this in an automated pipeline for bulk link prediction?
 
 Please [book an appointment](https://dejanmarketing.com/conference/).
-""")
+""")