umarigan commited on
Commit
1c9f94a
·
verified ·
1 Parent(s): 8ab4c34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -121
app.py CHANGED
@@ -5,93 +5,9 @@ from transformers import pipeline, AutoModelForTokenClassification, AutoTokenize
5
  import PyPDF2
6
  import docx
7
  import io
 
8
 
9
- def chunk_text(text, chunk_size=128):
10
- words = text.split()
11
- chunks = []
12
- current_chunk = []
13
- current_length = 0
14
-
15
- for word in words:
16
- if current_length + len(word) + 1 > chunk_size:
17
- chunks.append(' '.join(current_chunk))
18
- current_chunk = [word]
19
- current_length = len(word)
20
- else:
21
- current_chunk.append(word)
22
- current_length += len(word) + 1
23
-
24
- if current_chunk:
25
- chunks.append(' '.join(current_chunk))
26
-
27
- return chunks
28
-
29
- st.set_page_config(layout="wide")
30
-
31
- # Function to read text from uploaded file
32
- def read_file(file):
33
- if file.type == "text/plain":
34
- return file.getvalue().decode("utf-8")
35
- elif file.type == "application/pdf":
36
- pdf_reader = PyPDF2.PdfReader(io.BytesIO(file.getvalue()))
37
- return " ".join(page.extract_text() for page in pdf_reader.pages)
38
- elif file.type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document":
39
- doc = docx.Document(io.BytesIO(file.getvalue()))
40
- return " ".join(paragraph.text for paragraph in doc.paragraphs)
41
- else:
42
- st.error("Unsupported file type")
43
- return None
44
-
45
- st.title("Turkish NER Models Testing")
46
-
47
- model_list = [
48
- 'girayyagmur/bert-base-turkish-ner-cased',
49
- 'savasy/bert-base-turkish-ner-cased',
50
- 'xlm-roberta-large-finetuned-conll03-english',
51
- 'asahi417/tner-xlm-roberta-base-ontonotes5'
52
- ]
53
-
54
- st.sidebar.header("Select NER Model")
55
- model_checkpoint = st.sidebar.radio("", model_list)
56
-
57
- st.sidebar.write("For details of models: 'https://huggingface.co/akdeniz27/")
58
- st.sidebar.write("Only PDF, DOCX, and TXT files are supported.")
59
-
60
- # Determine aggregation strategy
61
- aggregation = "simple" if model_checkpoint in ["akdeniz27/xlm-roberta-base-turkish-ner", "xlm-roberta-large-finetuned-conll03-english", "asahi417/tner-xlm-roberta-base-ontonotes5"] else "first"
62
-
63
- st.subheader("Select Text Input Method")
64
- input_method = st.radio("", ('Write or Paste New Text', 'Upload File'))
65
-
66
- if input_method == "Write or Paste New Text":
67
- input_text = st.text_area('Write or Paste Text Below', value="", height=128)
68
- else:
69
- uploaded_file = st.file_uploader("Choose a file", type=["txt", "pdf", "docx"])
70
- if uploaded_file is not None:
71
- input_text = read_file(uploaded_file)
72
- if input_text:
73
- st.text_area("Extracted Text", input_text, height=128)
74
- else:
75
- input_text = ""
76
-
77
- @st.cache_resource
78
- def setModel(model_checkpoint, aggregation):
79
- model = AutoModelForTokenClassification.from_pretrained(model_checkpoint)
80
- tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
81
- return pipeline('ner', model=model, tokenizer=tokenizer, aggregation_strategy=aggregation)
82
-
83
- @st.cache_resource
84
- def entity_comb(output):
85
- output_comb = []
86
- for ind, entity in enumerate(output):
87
- if ind == 0:
88
- output_comb.append(entity)
89
- elif output[ind]["start"] == output[ind-1]["end"] and output[ind]["entity_group"] == output[ind-1]["entity_group"]:
90
- output_comb[-1]["word"] += output[ind]["word"]
91
- output_comb[-1]["end"] = output[ind]["end"]
92
- else:
93
- output_comb.append(entity)
94
- return output_comb
95
 
96
  def create_mask_dict(entities):
97
  mask_dict = {}
@@ -105,14 +21,13 @@ def create_mask_dict(entities):
105
  entity_counters[entity['entity_group']] += 1
106
  mask_dict[entity['word']] = f"{entity['entity_group']}_{entity_counters[entity['entity_group']]}"
107
  return mask_dict
108
- def create_masked_text(input_text, entities, mask_dict):
 
109
  masked_text = input_text
110
- for entity in sorted(entities, key=lambda x: x['start'], reverse=True):
111
- if entity['entity_group'] not in ['CARDINAL', 'EVENT']:
112
- masked_text = masked_text[:entity['start']] + mask_dict[entity['word']] + masked_text[entity['end']:]
113
  return masked_text
114
 
115
-
116
  Run_Button = st.button("Run")
117
 
118
  if Run_Button and input_text:
@@ -134,47 +49,34 @@ if Run_Button and input_text:
134
  entity['end'] += offset
135
 
136
  all_outputs.extend(output)
137
-
138
 
139
  # Combine entities
140
-
141
  output_comb = entity_comb(all_outputs)
142
 
143
  # Create mask dictionary
144
  mask_dict = create_mask_dict(output_comb)
145
 
146
- masked_text = create_masked_text(input_text, output_comb, mask_dict)
147
-
148
- # Apply masking and add masked_word column
149
- for entity in output_comb:
150
- if entity['entity_group'] not in ['CARDINAL', 'EVENT']:
151
- entity['masked_word'] = mask_dict.get(entity['word'], entity['word'])
152
- else:
153
- entity['masked_word'] = entity['word']
154
- print("output_comb", output_comb)
155
- #df = pd.DataFrame.from_dict(output_comb)
156
- #cols_to_keep = ['word', 'entity_group', 'score', 'start', 'end']
157
- #df_final = df[cols_to_keep].loc[:,~df.columns.duplicated()].copy()
158
 
159
- #st.subheader("Recognized Entities")
160
- #st.dataframe(df_final)
161
 
162
-
163
-
164
- # Spacy display logic with entity numbering
 
 
 
 
 
 
 
165
  spacy_display = {"ents": [], "text": input_text, "title": None}
166
  for entity in output_comb:
167
  if entity['entity_group'] not in ['CARDINAL', 'EVENT']:
168
- label = f"{entity['entity_group']}_{mask_dict[entity['word']].split('_')[1]}"
169
- else:
170
- label = entity['entity_group']
171
- spacy_display["ents"].append({"start": entity["start"], "end": entity["end"], "label": label})
172
 
173
  html = spacy.displacy.render(spacy_display, style="ent", minify=True, manual=True)
174
- st.write(html, unsafe_allow_html=True)
175
-
176
- st.subheader("Masking Dictionary")
177
- st.json(mask_dict)
178
-
179
- st.subheader("Masked Text Preview")
180
- st.text(masked_text)
 
5
  import PyPDF2
6
  import docx
7
  import io
8
+ import re
9
 
10
+ # ... [Previous functions remain unchanged] ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def create_mask_dict(entities):
13
  mask_dict = {}
 
21
  entity_counters[entity['entity_group']] += 1
22
  mask_dict[entity['word']] = f"{entity['entity_group']}_{entity_counters[entity['entity_group']]}"
23
  return mask_dict
24
+
25
+ def create_masked_text(input_text, mask_dict):
26
  masked_text = input_text
27
+ for word, mask in sorted(mask_dict.items(), key=lambda x: len(x[0]), reverse=True):
28
+ masked_text = re.sub(r'\b' + re.escape(word) + r'\b', mask, masked_text)
 
29
  return masked_text
30
 
 
31
  Run_Button = st.button("Run")
32
 
33
  if Run_Button and input_text:
 
49
  entity['end'] += offset
50
 
51
  all_outputs.extend(output)
 
52
 
53
  # Combine entities
 
54
  output_comb = entity_comb(all_outputs)
55
 
56
  # Create mask dictionary
57
  mask_dict = create_mask_dict(output_comb)
58
 
59
+ # Create masked text
60
+ masked_text = create_masked_text(input_text, mask_dict)
 
 
 
 
 
 
 
 
 
 
61
 
62
+ st.subheader("Masked Text")
63
+ st.text(masked_text)
64
 
65
+ st.subheader("Masking Dictionary")
66
+ st.json(mask_dict)
67
+
68
+ # Create a DataFrame for display
69
+ df = pd.DataFrame([(word, mask) for word, mask in mask_dict.items()], columns=['Original', 'Masked'])
70
+ st.subheader("Masking Table")
71
+ st.dataframe(df)
72
+
73
+ # Optional: Display original text with highlights
74
+ st.subheader("Original Text with Highlights")
75
  spacy_display = {"ents": [], "text": input_text, "title": None}
76
  for entity in output_comb:
77
  if entity['entity_group'] not in ['CARDINAL', 'EVENT']:
78
+ label = mask_dict[entity['word']]
79
+ spacy_display["ents"].append({"start": entity["start"], "end": entity["end"], "label": label})
 
 
80
 
81
  html = spacy.displacy.render(spacy_display, style="ent", minify=True, manual=True)
82
+ st.write(html, unsafe_allow_html=True)