zstanjj commited on
Commit
0eeb7ec
1 Parent(s): 953999f

Upload modeling_phi3.py

Browse files
Files changed (1) hide show
  1. modeling_phi3.py +106 -4
modeling_phi3.py CHANGED
@@ -17,13 +17,10 @@
17
 
18
  import inspect
19
 
20
- import bs4
21
- import loguru
22
  import math
23
  import warnings
24
  from typing import List, Optional, Tuple, Union
25
 
26
- import numpy as np
27
  import torch
28
  import torch.nn.functional as F
29
  import torch.utils.checkpoint
@@ -50,7 +47,112 @@ from transformers.utils import (
50
  replace_return_docstrings,
51
  )
52
  from .configuration_phi3 import Phi3Config
53
- from .tree_gen_utils import split_tree, TokenIdNode, TokenDotExporter, nodenamefunc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
 
56
  logger = logging.get_logger(__name__)
 
17
 
18
  import inspect
19
 
 
 
20
  import math
21
  import warnings
22
  from typing import List, Optional, Tuple, Union
23
 
 
24
  import torch
25
  import torch.nn.functional as F
26
  import torch.utils.checkpoint
 
47
  replace_return_docstrings,
48
  )
49
  from .configuration_phi3 import Phi3Config
50
+
51
+ from collections import defaultdict
52
+ from typing import List, Tuple
53
+
54
+ import numpy as np
55
+ from anytree import Node
56
+ import bs4
57
+ from anytree import PreOrderIter
58
+ from anytree.exporter import DotExporter
59
+
60
+
61
+ def nodenamefunc(node):
62
+ return f"{node.name}|{node.prob}|{node.input_ids}"
63
+
64
+
65
+ class TokenDotExporter(DotExporter):
66
+ def __init__(self, node, **kwargs):
67
+ super().__init__(node, **kwargs)
68
+
69
+ def __iter__(self):
70
+ # prepare
71
+ indent = " " * self.indent
72
+ nodenamefunc = self.nodenamefunc or self._default_nodenamefunc
73
+ nodeattrfunc = self.nodeattrfunc or self._default_nodeattrfunc
74
+ edgeattrfunc = self.edgeattrfunc or self._default_edgeattrfunc
75
+ edgetypefunc = self.edgetypefunc or self._default_edgetypefunc
76
+ filter_ = self.filter_ or self._default_filter
77
+ return self.__iter(indent, nodenamefunc, nodeattrfunc, edgeattrfunc, edgetypefunc, filter_)
78
+
79
+ def __iter_nodes(self, indent, nodenamefunc, nodeattrfunc, filter_):
80
+ for node in PreOrderIter(self.node, filter_=filter_, stop=self.stop, maxlevel=self.maxlevel):
81
+ nodename = nodenamefunc(node)
82
+ nodeattr = nodeattrfunc(node)
83
+ nodeattr = " {%s}" % nodeattr if nodeattr is not None else ""
84
+ yield '%s%s' % (DotExporter.esc(nodename), nodeattr)
85
+
86
+ def __iter(self, indent, nodenamefunc, nodeattrfunc, edgeattrfunc, edgetypefunc, filter_):
87
+ for node in self.__iter_nodes(indent, nodenamefunc, nodeattrfunc, filter_):
88
+ yield node
89
+
90
+
91
+ class TokenIdNode(Node):
92
+ def __init__(self, name, parent=None, children=None, **kwargs):
93
+ super().__init__(name, parent, children, **kwargs)
94
+ self.input_ids = kwargs.get('input_ids', [])
95
+ self.prob = kwargs.get('prob', np.float32(0.0))
96
+
97
+
98
+ def split_tree(soup: bs4.BeautifulSoup, max_node_words=0) -> List[Tuple[bs4.element.Tag, List[str], bool]]:
99
+ word_count = len(soup.get_text().split())
100
+ if word_count > max_node_words:
101
+ possible_trees = [(soup, [])]
102
+ target_trees = [] # [(tag, path, is_leaf)]
103
+ # split the entire dom tee into subtrees, until the length of the subtree is less than max_node_words words
104
+ # find all possible trees
105
+ while True:
106
+ if len(possible_trees) == 0:
107
+ break
108
+ tree = possible_trees.pop(0)
109
+ tag_children = defaultdict(int)
110
+ bare_word_count = 0
111
+ # count child tags
112
+ for child in tree[0].contents:
113
+ if isinstance(child, bs4.element.Tag):
114
+ tag_children[child.name] += 1
115
+ _tag_children = {k: 0 for k in tag_children.keys()}
116
+
117
+ # check if the tree can be split
118
+ for child in tree[0].contents:
119
+ if isinstance(child, bs4.element.Tag):
120
+ # change child tag with duplicate names
121
+ if tag_children[child.name] > 1:
122
+ new_name = f"{child.name}{_tag_children[child.name]}"
123
+ new_tree = (child, tree[1] + [new_name])
124
+ _tag_children[child.name] += 1
125
+ child.name = new_name
126
+ else:
127
+ new_tree = (child, tree[1] + [child.name])
128
+ word_count = len(child.get_text().split())
129
+ # add node with more than max_node_words words, and recursion depth is less than 64
130
+ if word_count > max_node_words and len(new_tree[1]) < 64:
131
+ possible_trees.append(new_tree)
132
+ else:
133
+ target_trees.append((new_tree[0], new_tree[1], True))
134
+ else:
135
+ bare_word_count += len(str(child).split())
136
+
137
+ # add leaf node
138
+ if len(tag_children) == 0:
139
+ target_trees.append((tree[0], tree[1], True))
140
+ # add node with more than max_node_words bare words
141
+ elif bare_word_count > max_node_words:
142
+ target_trees.append((tree[0], tree[1], False))
143
+ else:
144
+ soup_children = [c for c in soup.contents if isinstance(c, bs4.element.Tag)]
145
+ if len(soup_children) == 1:
146
+ target_trees = [(soup_children[0], [soup_children[0].name], True)]
147
+ else:
148
+ # add an html tag to wrap all children
149
+ new_soup = bs4.BeautifulSoup("", 'html.parser')
150
+ new_tag = new_soup.new_tag("html")
151
+ new_soup.append(new_tag)
152
+ for child in soup_children:
153
+ new_tag.append(child)
154
+ target_trees = [(new_tag, ["html"], True)]
155
+ return target_trees
156
 
157
 
158
  logger = logging.get_logger(__name__)