kazalbrur commited on
Commit
6ac85e1
·
verified ·
1 Parent(s): 1bec043

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -33
app.py CHANGED
@@ -4,35 +4,16 @@ from transformers import pipeline
4
  from typing import List, Dict, Any
5
  import torch
6
 
7
- # Merging BIO-tagged tokens
8
- def merge_tokens(tokens: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
9
  merged_tokens = []
10
- current_entity = None
11
-
12
  for token in tokens:
13
- token_tag = token['entity']
14
-
15
- # If it's a beginning of a new entity (B- tag)
16
- if token_tag.startswith('B-'):
17
- current_entity = {
18
- 'word': token['word'],
19
- 'entity': token_tag[2:], # Removing the B- prefix
20
- 'start': token['start'],
21
- 'end': token['end'],
22
- 'score': token['score']
23
- }
24
- merged_tokens.append(current_entity)
25
-
26
- # If it's inside the current entity (I- tag) and the entity matches
27
- elif token_tag.startswith('I-') and current_entity and current_entity['entity'] == token_tag[2:]:
28
- current_entity['word'] += token['word'].replace('##', '')
29
- current_entity['end'] = token['end']
30
- current_entity['score'] = (current_entity['score'] + token['score']) / 2
31
-
32
- # In case of O or mismatched entities, we skip merging and handle separately
33
  else:
34
- current_entity = None
35
-
36
  return merged_tokens
37
 
38
  # Determine device
@@ -44,13 +25,8 @@ get_completion = pipeline("ner", model="kazalbrur/BanglaLegalNER", device=device
44
  @spaces.GPU(duration=120)
45
  def ner(input: str) -> Dict[str, Any]:
46
  try:
47
- # Get raw output from the NER model
48
  output = get_completion(input)
49
-
50
- # Merge tokens
51
  merged_tokens = merge_tokens(output)
52
-
53
- # Return the input text along with the merged entities
54
  return {"text": input, "entities": merged_tokens}
55
  except Exception as e:
56
  return {"text": input, "entities": [], "error": str(e)}
@@ -76,9 +52,9 @@ with demo:
76
  gr.Markdown(description)
77
  gr.Interface(
78
  fn=ner,
79
- inputs=[gr.Textbox(label="Enter Your Text to Find the Legal Entities", lines=30)],
80
  outputs=[gr.HighlightedText(label="Text with entities")],
81
  allow_flagging="never"
82
  )
83
 
84
- demo.launch()
 
4
  from typing import List, Dict, Any
5
  import torch
6
 
7
+ def merge_tokens(tokens: List[Dict[str, any]]) -> List[Dict[str, any]]:
 
8
  merged_tokens = []
 
 
9
  for token in tokens:
10
+ if merged_tokens and token['entity'].startswith('I-') and merged_tokens[-1]['entity'].endswith(token['entity'][2:]):
11
+ last_token = merged_tokens[-1]
12
+ last_token['word'] += token['word'].replace('##', '')
13
+ last_token['end'] = token['end']
14
+ last_token['score'] = (last_token['score'] + token['score']) / 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  else:
16
+ merged_tokens.append(token)
 
17
  return merged_tokens
18
 
19
  # Determine device
 
25
  @spaces.GPU(duration=120)
26
  def ner(input: str) -> Dict[str, Any]:
27
  try:
 
28
  output = get_completion(input)
 
 
29
  merged_tokens = merge_tokens(output)
 
 
30
  return {"text": input, "entities": merged_tokens}
31
  except Exception as e:
32
  return {"text": input, "entities": [], "error": str(e)}
 
52
  gr.Markdown(description)
53
  gr.Interface(
54
  fn=ner,
55
+ inputs=[gr.Textbox(label="Enter Your Text to Find the Legal Entities", lines=20)],
56
  outputs=[gr.HighlightedText(label="Text with entities")],
57
  allow_flagging="never"
58
  )
59
 
60
+ demo.launch()