KevinHuSh committed on
Commit
62e78ef
·
1 Parent(s): 058cd84

rename vision, add layout and tsr recognizer (#70)

Browse files

* rename vision, add layout and tsr recognizer

* trivial fixing

api/apps/conversation_app.py CHANGED
@@ -34,7 +34,6 @@ from rag.utils import num_tokens_from_string, encoder, rmSpace
34
 
35
  @manager.route('/set', methods=['POST'])
36
  @login_required
37
- @validate_request("dialog_id")
38
  def set_conversation():
39
  req = request.json
40
  conv_id = req.get("conversation_id")
@@ -145,7 +144,7 @@ def message_fit_in(msg, max_length=4000):
145
 
146
  @manager.route('/completion', methods=['POST'])
147
  @login_required
148
- @validate_request("dialog_id", "messages")
149
  def completion():
150
  req = request.json
151
  msg = []
@@ -154,12 +153,20 @@ def completion():
154
  if m["role"] == "assistant" and not msg: continue
155
  msg.append({"role": m["role"], "content": m["content"]})
156
  try:
157
- e, dia = DialogService.get_by_id(req["dialog_id"])
 
 
 
 
158
  if not e:
159
  return get_data_error_result(retmsg="Dialog not found!")
160
- del req["dialog_id"]
161
  del req["messages"]
162
- return get_json_result(data=chat(dia, msg, **req))
 
 
 
 
163
  except Exception as e:
164
  return server_error_response(e)
165
 
@@ -194,8 +201,8 @@ def chat(dialog, messages, **kwargs):
194
  dialog.vector_similarity_weight, top=1024, aggs=False)
195
  knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
196
 
197
- if not knowledges and prompt_config["empty_response"]:
198
- return {"answer": prompt_config["empty_response"], "retrieval": kbinfos}
199
 
200
  kwargs["knowledge"] = "\n".join(knowledges)
201
  gen_conf = dialog.llm_setting
@@ -205,7 +212,8 @@ def chat(dialog, messages, **kwargs):
205
  gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
206
  answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
207
 
208
- answer = retrievaler.insert_citations(answer,
 
209
  [ck["content_ltks"] for ck in kbinfos["chunks"]],
210
  [ck["vector"] for ck in kbinfos["chunks"]],
211
  embd_mdl,
@@ -213,7 +221,7 @@ def chat(dialog, messages, **kwargs):
213
  vtweight=dialog.vector_similarity_weight)
214
  for c in kbinfos["chunks"]:
215
  if c.get("vector"): del c["vector"]
216
- return {"answer": answer, "retrieval": kbinfos}
217
 
218
 
219
  def use_sql(question, field_map, tenant_id, chat_mdl):
 
34
 
35
  @manager.route('/set', methods=['POST'])
36
  @login_required
 
37
  def set_conversation():
38
  req = request.json
39
  conv_id = req.get("conversation_id")
 
144
 
145
  @manager.route('/completion', methods=['POST'])
146
  @login_required
147
+ @validate_request("conversation_id", "messages")
148
  def completion():
149
  req = request.json
150
  msg = []
 
153
  if m["role"] == "assistant" and not msg: continue
154
  msg.append({"role": m["role"], "content": m["content"]})
155
  try:
156
+ e, conv = ConversationService.get_by_id(req["conversation_id"])
157
+ if not e:
158
+ return get_data_error_result(retmsg="Conversation not found!")
159
+ conv.message.append(msg[-1])
160
+ e, dia = DialogService.get_by_id(conv.dialog_id)
161
  if not e:
162
  return get_data_error_result(retmsg="Dialog not found!")
163
+ del req["conversation_id"]
164
  del req["messages"]
165
+ ans = chat(dia, msg, **req)
166
+ conv.reference.append(ans["reference"])
167
+ conv.message.append({"role": "assistant", "content": ans["answer"]})
168
+ ConversationService.update_by_id(conv.id, conv.to_dict())
169
+ return get_json_result(data=ans)
170
  except Exception as e:
171
  return server_error_response(e)
172
 
 
201
  dialog.vector_similarity_weight, top=1024, aggs=False)
202
  knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
203
 
204
+ if not knowledges and prompt_config.get("empty_response"):
205
+ return {"answer": prompt_config["empty_response"], "reference": kbinfos}
206
 
207
  kwargs["knowledge"] = "\n".join(knowledges)
208
  gen_conf = dialog.llm_setting
 
212
  gen_conf["max_tokens"] = min(gen_conf["max_tokens"], llm.max_tokens - used_token_count)
213
  answer = chat_mdl.chat(prompt_config["system"].format(**kwargs), msg, gen_conf)
214
 
215
+ if knowledges:
216
+ answer = retrievaler.insert_citations(answer,
217
  [ck["content_ltks"] for ck in kbinfos["chunks"]],
218
  [ck["vector"] for ck in kbinfos["chunks"]],
219
  embd_mdl,
 
221
  vtweight=dialog.vector_similarity_weight)
222
  for c in kbinfos["chunks"]:
223
  if c.get("vector"): del c["vector"]
224
+ return {"answer": answer, "reference": kbinfos}
225
 
226
 
227
  def use_sql(question, field_map, tenant_id, chat_mdl):
api/apps/llm_app.py CHANGED
@@ -94,11 +94,11 @@ def list():
94
  model_type = request.args.get("model_type")
95
  try:
96
  objs = TenantLLMService.query(tenant_id=current_user.id)
97
- mdlnms = set([o.to_dict()["llm_name"] for o in objs if o.api_key])
98
  llms = LLMService.get_all()
99
  llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
100
  for m in llms:
101
- m["available"] = m["llm_name"] in mdlnms
102
 
103
  res = {}
104
  for m in llms:
 
94
  model_type = request.args.get("model_type")
95
  try:
96
  objs = TenantLLMService.query(tenant_id=current_user.id)
97
+ facts = set([o.to_dict()["llm_factory"] for o in objs if o.api_key])
98
  llms = LLMService.get_all()
99
  llms = [m.to_dict() for m in llms if m.status == StatusEnum.VALID.value]
100
  for m in llms:
101
+ m["available"] = m["fid"] in facts
102
 
103
  res = {}
104
  for m in llms:
api/db/db_models.py CHANGED
@@ -500,7 +500,7 @@ class Document(DataBaseModel):
500
  token_num = IntegerField(default=0)
501
  chunk_num = IntegerField(default=0)
502
  progress = FloatField(default=0)
503
- progress_msg = CharField(max_length=512, null=True, help_text="process message", default="")
504
  process_begin_at = DateTimeField(null=True)
505
  process_duation = FloatField(default=0)
506
  run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")
@@ -518,7 +518,7 @@ class Task(DataBaseModel):
518
  begin_at = DateTimeField(null=True)
519
  process_duation = FloatField(default=0)
520
  progress = FloatField(default=0)
521
- progress_msg = CharField(max_length=255, null=True, help_text="process message", default="")
522
 
523
 
524
  class Dialog(DataBaseModel):
@@ -561,6 +561,7 @@ class Conversation(DataBaseModel):
561
  dialog_id = CharField(max_length=32, null=False, index=True)
562
  name = CharField(max_length=255, null=True, help_text="converastion name")
563
  message = JSONField(null=True)
 
564
 
565
  class Meta:
566
  db_table = "conversation"
 
500
  token_num = IntegerField(default=0)
501
  chunk_num = IntegerField(default=0)
502
  progress = FloatField(default=0)
503
+ progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
504
  process_begin_at = DateTimeField(null=True)
505
  process_duation = FloatField(default=0)
506
  run = CharField(max_length=1, null=True, help_text="start to run processing or cancel.(1: run it; 2: cancel)", default="0")
 
518
  begin_at = DateTimeField(null=True)
519
  process_duation = FloatField(default=0)
520
  progress = FloatField(default=0)
521
+ progress_msg = CharField(max_length=4096, null=True, help_text="process message", default="")
522
 
523
 
524
  class Dialog(DataBaseModel):
 
561
  dialog_id = CharField(max_length=32, null=False, index=True)
562
  name = CharField(max_length=255, null=True, help_text="converastion name")
563
  message = JSONField(null=True)
564
+ reference = JSONField(null=True, default=[])
565
 
566
  class Meta:
567
  db_table = "conversation"
api/db/services/llm_service.py CHANGED
@@ -75,7 +75,7 @@ class TenantLLMService(CommonService):
75
 
76
  model_config = cls.get_api_key(tenant_id, mdlnm)
77
  if not model_config:
78
- raise LookupError("Model({}) not found".format(mdlnm))
79
  model_config = model_config.to_dict()
80
  if llm_type == LLMType.EMBEDDING.value:
81
  if model_config["llm_factory"] not in EmbeddingModel:
 
75
 
76
  model_config = cls.get_api_key(tenant_id, mdlnm)
77
  if not model_config:
78
+ raise LookupError("Model({}) not authorized".format(mdlnm))
79
  model_config = model_config.to_dict()
80
  if llm_type == LLMType.EMBEDDING.value:
81
  if model_config["llm_factory"] not in EmbeddingModel:
deepdoc/parser/pdf_parser.py CHANGED
@@ -1,9 +1,7 @@
1
  # -*- coding: utf-8 -*-
2
- import os
3
  import random
4
 
5
  import fitz
6
- import requests
7
  import xgboost as xgb
8
  from io import BytesIO
9
  import torch
@@ -14,9 +12,8 @@ from PIL import Image
14
  import numpy as np
15
 
16
  from api.db import ParserType
17
- from deepdoc.visual import OCR, Recognizer
18
  from rag.nlp import huqie
19
- from collections import Counter
20
  from copy import deepcopy
21
  from huggingface_hub import hf_hub_download
22
 
@@ -29,29 +26,8 @@ class HuParser:
29
  self.ocr = OCR()
30
  if not hasattr(self, "model_speciess"):
31
  self.model_speciess = ParserType.GENERAL.value
32
- self.layout_labels = [
33
- "_background_",
34
- "Text",
35
- "Title",
36
- "Figure",
37
- "Figure caption",
38
- "Table",
39
- "Table caption",
40
- "Header",
41
- "Footer",
42
- "Reference",
43
- "Equation",
44
- ]
45
- self.tsr_labels = [
46
- "table",
47
- "table column",
48
- "table row",
49
- "table column header",
50
- "table projected row header",
51
- "table spanning cell",
52
- ]
53
- self.layouter = Recognizer(self.layout_labels, "layout", "/data/newpeak/medical-gpt/res/ppdet/")
54
- self.tbl_det = Recognizer(self.tsr_labels, "tsr", "/data/newpeak/medical-gpt/res/ppdet.tbl/")
55
 
56
  self.updown_cnt_mdl = xgb.Booster()
57
  if torch.cuda.is_available():
@@ -70,39 +46,6 @@ class HuParser:
70
 
71
  """
72
 
73
- def __remote_call(self, species, images, thr=0.7):
74
- url = os.environ.get("INFINIFLOW_SERVER")
75
- token = os.environ.get("INFINIFLOW_TOKEN")
76
- if not url or not token:
77
- logging.warning("INFINIFLOW_SERVER is not specified. To maximize the effectiveness, please visit https://github.com/infiniflow/ragflow, and sign in the our demo web site to get token. It's FREE! Using 'export' to set both environment variables: INFINIFLOW_SERVER and INFINIFLOW_TOKEN.")
78
- return [[] for _ in range(len(images))]
79
-
80
- def convert_image_to_bytes(PILimage):
81
- image = BytesIO()
82
- PILimage.save(image, format='png')
83
- image.seek(0)
84
- return image.getvalue()
85
-
86
- images = [convert_image_to_bytes(img) for img in images]
87
-
88
- def remote_call():
89
- nonlocal images, thr
90
- res = requests.post(url+"/v1/layout/detect/"+species, files=[("image", img) for img in images], data={"threashold": thr},
91
- headers={"Authorization": token}, timeout=len(images) * 10)
92
- res = res.json()
93
- if res["retcode"] != 0: raise RuntimeError(res["retmsg"])
94
- return res["data"]
95
-
96
- for _ in range(3):
97
- try:
98
- return remote_call()
99
- except RuntimeError as e:
100
- raise e
101
- except Exception as e:
102
- logging.error("layout_predict:"+str(e))
103
- return remote_call()
104
-
105
-
106
  def __char_width(self, c):
107
  return (c["x1"] - c["x0"]) // len(c["text"])
108
 
@@ -188,20 +131,6 @@ class HuParser:
188
  ]
189
  return fea
190
 
191
- @staticmethod
192
- def sort_Y_firstly(arr, threashold):
193
- # sort using y1 first and then x1
194
- arr = sorted(arr, key=lambda r: (r["top"], r["x0"]))
195
- for i in range(len(arr) - 1):
196
- for j in range(i, -1, -1):
197
- # restore the order using th
198
- if abs(arr[j + 1]["top"] - arr[j]["top"]) < threashold \
199
- and arr[j + 1]["x0"] < arr[j]["x0"]:
200
- tmp = deepcopy(arr[j])
201
- arr[j] = deepcopy(arr[j + 1])
202
- arr[j + 1] = deepcopy(tmp)
203
- return arr
204
-
205
  @staticmethod
206
  def sort_X_by_page(arr, threashold):
207
  # sort using y1 first and then x1
@@ -217,61 +146,6 @@ class HuParser:
217
  arr[j + 1] = tmp
218
  return arr
219
 
220
- @staticmethod
221
- def sort_R_firstly(arr, thr=0):
222
- # sort using y1 first and then x1
223
- # sorted(arr, key=lambda r: (r["top"], r["x0"]))
224
- arr = HuParser.sort_Y_firstly(arr, thr)
225
- for i in range(len(arr) - 1):
226
- for j in range(i, -1, -1):
227
- if "R" not in arr[j] or "R" not in arr[j + 1]:
228
- continue
229
- if arr[j + 1]["R"] < arr[j]["R"] \
230
- or (
231
- arr[j + 1]["R"] == arr[j]["R"]
232
- and arr[j + 1]["x0"] < arr[j]["x0"]
233
- ):
234
- tmp = arr[j]
235
- arr[j] = arr[j + 1]
236
- arr[j + 1] = tmp
237
- return arr
238
-
239
- @staticmethod
240
- def sort_X_firstly(arr, threashold, copy=True):
241
- # sort using y1 first and then x1
242
- arr = sorted(arr, key=lambda r: (r["x0"], r["top"]))
243
- for i in range(len(arr) - 1):
244
- for j in range(i, -1, -1):
245
- # restore the order using th
246
- if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threashold \
247
- and arr[j + 1]["top"] < arr[j]["top"]:
248
- tmp = deepcopy(arr[j]) if copy else arr[j]
249
- arr[j] = deepcopy(arr[j + 1]) if copy else arr[j + 1]
250
- arr[j + 1] = deepcopy(tmp) if copy else tmp
251
- return arr
252
-
253
- @staticmethod
254
- def sort_C_firstly(arr, thr=0):
255
- # sort using y1 first and then x1
256
- # sorted(arr, key=lambda r: (r["x0"], r["top"]))
257
- arr = HuParser.sort_X_firstly(arr, thr)
258
- for i in range(len(arr) - 1):
259
- for j in range(i, -1, -1):
260
- # restore the order using th
261
- if "C" not in arr[j] or "C" not in arr[j + 1]:
262
- continue
263
- if arr[j + 1]["C"] < arr[j]["C"] \
264
- or (
265
- arr[j + 1]["C"] == arr[j]["C"]
266
- and arr[j + 1]["top"] < arr[j]["top"]
267
- ):
268
- tmp = arr[j]
269
- arr[j] = arr[j + 1]
270
- arr[j + 1] = tmp
271
- return arr
272
-
273
- return sorted(arr, key=lambda r: (r.get("C", r["x0"]), r["top"]))
274
-
275
  def _has_color(self, o):
276
  if o.get("ncs", "") == "DeviceGray":
277
  if o["stroking_color"] and o["stroking_color"][0] == 1 and o["non_stroking_color"] and \
@@ -280,172 +154,6 @@ class HuParser:
280
  return False
281
  return True
282
 
283
- def __overlapped_area(self, a, b, ratio=True):
284
- tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
285
- if b["x0"] > x1 or b["x1"] < x0:
286
- return 0
287
- if b["bottom"] < tp or b["top"] > btm:
288
- return 0
289
- x0_ = max(b["x0"], x0)
290
- x1_ = min(b["x1"], x1)
291
- assert x0_ <= x1_, "Fuckedup! T:{},B:{},X0:{},X1:{} ==> {}".format(
292
- tp, btm, x0, x1, b)
293
- tp_ = max(b["top"], tp)
294
- btm_ = min(b["bottom"], btm)
295
- assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format(
296
- tp, btm, x0, x1, b)
297
- ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
298
- x0 != 0 and btm - tp != 0 else 0
299
- if ov > 0 and ratio:
300
- ov /= (x1 - x0) * (btm - tp)
301
- return ov
302
-
303
- def __find_overlapped_with_threashold(self, box, boxes, thr=0.3):
304
- if not boxes:
305
- return
306
- max_overlaped_i, max_overlaped, _max_overlaped = None, thr, 0
307
- s, e = 0, len(boxes)
308
- for i in range(s, e):
309
- ov = self.__overlapped_area(box, boxes[i])
310
- _ov = self.__overlapped_area(boxes[i], box)
311
- if (ov, _ov) < (max_overlaped, _max_overlaped):
312
- continue
313
- max_overlaped_i = i
314
- max_overlaped = ov
315
- _max_overlaped = _ov
316
-
317
- return max_overlaped_i
318
-
319
- def __find_overlapped(self, box, boxes_sorted_by_y, naive=False):
320
- if not boxes_sorted_by_y:
321
- return
322
- bxs = boxes_sorted_by_y
323
- s, e, ii = 0, len(bxs), 0
324
- while s < e and not naive:
325
- ii = (e + s) // 2
326
- pv = bxs[ii]
327
- if box["bottom"] < pv["top"]:
328
- e = ii
329
- continue
330
- if box["top"] > pv["bottom"]:
331
- s = ii + 1
332
- continue
333
- break
334
- while s < ii:
335
- if box["top"] > bxs[s]["bottom"]:
336
- s += 1
337
- break
338
- while e - 1 > ii:
339
- if box["bottom"] < bxs[e - 1]["top"]:
340
- e -= 1
341
- break
342
-
343
- max_overlaped_i, max_overlaped = None, 0
344
- for i in range(s, e):
345
- ov = self.__overlapped_area(bxs[i], box)
346
- if ov <= max_overlaped:
347
- continue
348
- max_overlaped_i = i
349
- max_overlaped = ov
350
-
351
- return max_overlaped_i
352
-
353
- def _is_garbage(self, b):
354
- patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
355
- r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
356
- "(资料|数据)来源[::]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}",
357
- "\\(cid *: *[0-9]+ *\\)"
358
- ]
359
- return any([re.search(p, b["text"]) for p in patt])
360
-
361
- def __layouts_cleanup(self, boxes, layouts, far=2, thr=0.7):
362
- def notOverlapped(a, b):
363
- return any([a["x1"] < b["x0"],
364
- a["x0"] > b["x1"],
365
- a["bottom"] < b["top"],
366
- a["top"] > b["bottom"]])
367
-
368
- i = 0
369
- while i + 1 < len(layouts):
370
- j = i + 1
371
- while j < min(i + far, len(layouts)) \
372
- and (layouts[i].get("type", "") != layouts[j].get("type", "")
373
- or notOverlapped(layouts[i], layouts[j])):
374
- j += 1
375
- if j >= min(i + far, len(layouts)):
376
- i += 1
377
- continue
378
- if self.__overlapped_area(layouts[i], layouts[j]) < thr \
379
- and self.__overlapped_area(layouts[j], layouts[i]) < thr:
380
- i += 1
381
- continue
382
-
383
- if layouts[i].get("score") and layouts[j].get("score"):
384
- if layouts[i]["score"] > layouts[j]["score"]:
385
- layouts.pop(j)
386
- else:
387
- layouts.pop(i)
388
- continue
389
-
390
- area_i, area_i_1 = 0, 0
391
- for b in boxes:
392
- if not notOverlapped(b, layouts[i]):
393
- area_i += self.__overlapped_area(b, layouts[i], False)
394
- if not notOverlapped(b, layouts[j]):
395
- area_i_1 += self.__overlapped_area(b, layouts[j], False)
396
-
397
- if area_i > area_i_1:
398
- layouts.pop(j)
399
- else:
400
- layouts.pop(i)
401
-
402
- return layouts
403
-
404
- def __table_tsr(self, images):
405
- tbls = self.tbl_det(images, thr=0.5)
406
- res = []
407
- # align left&right for rows, align top&bottom for columns
408
- for tbl in tbls:
409
- lts = [{"label": b["type"],
410
- "score": b["score"],
411
- "x0": b["bbox"][0], "x1": b["bbox"][2],
412
- "top": b["bbox"][1], "bottom": b["bbox"][-1]
413
- } for b in tbl]
414
- if not lts:
415
- continue
416
-
417
- left = [b["x0"] for b in lts if b["label"].find(
418
- "row") > 0 or b["label"].find("header") > 0]
419
- right = [b["x1"] for b in lts if b["label"].find(
420
- "row") > 0 or b["label"].find("header") > 0]
421
- if not left:
422
- continue
423
- left = np.median(left) if len(left) > 4 else np.min(left)
424
- right = np.median(right) if len(right) > 4 else np.max(right)
425
- for b in lts:
426
- if b["label"].find("row") > 0 or b["label"].find("header") > 0:
427
- if b["x0"] > left:
428
- b["x0"] = left
429
- if b["x1"] < right:
430
- b["x1"] = right
431
-
432
- top = [b["top"] for b in lts if b["label"] == "table column"]
433
- bottom = [b["bottom"] for b in lts if b["label"] == "table column"]
434
- if not top:
435
- res.append(lts)
436
- continue
437
- top = np.median(top) if len(top) > 4 else np.min(top)
438
- bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom)
439
- for b in lts:
440
- if b["label"] == "table column":
441
- if b["top"] > top:
442
- b["top"] = top
443
- if b["bottom"] < bottom:
444
- b["bottom"] = bottom
445
-
446
- res.append(lts)
447
- return res
448
-
449
  def _table_transformer_job(self, ZM):
450
  logging.info("Table processing...")
451
  imgs, pos = [], []
@@ -471,7 +179,7 @@ class HuParser:
471
  assert len(self.page_images) == len(tbcnt) - 1
472
  if not imgs:
473
  return
474
- recos = self.__table_tsr(imgs)
475
  tbcnt = np.cumsum(tbcnt)
476
  for i in range(len(tbcnt) - 1): # for page
477
  pg = []
@@ -493,10 +201,10 @@ class HuParser:
493
  self.tb_cpns.extend(pg)
494
 
495
  def gather(kwd, fzy=10, ption=0.6):
496
- eles = self.sort_Y_firstly(
497
  [r for r in self.tb_cpns if re.match(kwd, r["label"])], fzy)
498
- eles = self.__layouts_cleanup(self.boxes, eles, 5, ption)
499
- return self.sort_Y_firstly(eles, 0)
500
 
501
  # add R,H,C,SP tag to boxes within table layout
502
  headers = gather(r".*header$")
@@ -504,17 +212,17 @@ class HuParser:
504
  spans = gather(r".*spanning")
505
  clmns = sorted([r for r in self.tb_cpns if re.match(
506
  r"table column$", r["label"])], key=lambda x: (x["pn"], x["layoutno"], x["x0"]))
507
- clmns = self.__layouts_cleanup(self.boxes, clmns, 5, 0.5)
508
  for b in self.boxes:
509
  if b.get("layout_type", "") != "table":
510
  continue
511
- ii = self.__find_overlapped_with_threashold(b, rows, thr=0.3)
512
  if ii is not None:
513
  b["R"] = ii
514
  b["R_top"] = rows[ii]["top"]
515
  b["R_bott"] = rows[ii]["bottom"]
516
 
517
- ii = self.__find_overlapped_with_threashold(b, headers, thr=0.3)
518
  if ii is not None:
519
  b["H_top"] = headers[ii]["top"]
520
  b["H_bott"] = headers[ii]["bottom"]
@@ -522,13 +230,13 @@ class HuParser:
522
  b["H_right"] = headers[ii]["x1"]
523
  b["H"] = ii
524
 
525
- ii = self.__find_overlapped_with_threashold(b, clmns, thr=0.3)
526
  if ii is not None:
527
  b["C"] = ii
528
  b["C_left"] = clmns[ii]["x0"]
529
  b["C_right"] = clmns[ii]["x1"]
530
 
531
- ii = self.__find_overlapped_with_threashold(b, spans, thr=0.3)
532
  if ii is not None:
533
  b["H_top"] = spans[ii]["top"]
534
  b["H_bott"] = spans[ii]["bottom"]
@@ -542,7 +250,7 @@ class HuParser:
542
  self.boxes.append([])
543
  return
544
  bxs = [(line[0], line[1][0]) for line in bxs]
545
- bxs = self.sort_Y_firstly(
546
  [{"x0": b[0][0] / ZM, "x1": b[1][0] / ZM,
547
  "top": b[0][1] / ZM, "text": "", "txt": t,
548
  "bottom": b[-1][1] / ZM,
@@ -551,8 +259,8 @@ class HuParser:
551
  )
552
 
553
  # merge chars in the same rect
554
- for c in self.sort_X_firstly(chars, self.mean_width[pagenum - 1] // 4):
555
- ii = self.__find_overlapped(c, bxs)
556
  if ii is None:
557
  self.lefted_chars.append(c)
558
  continue
@@ -573,91 +281,11 @@ class HuParser:
573
  if self.mean_height[-1] == 0:
574
  self.mean_height[-1] = np.median([b["bottom"] - b["top"]
575
  for b in bxs])
576
-
577
  self.boxes.append(bxs)
578
 
579
  def _layouts_rec(self, ZM):
580
  assert len(self.page_images) == len(self.boxes)
581
- # Tag layout type
582
- boxes = []
583
- layouts = self.layouter(self.page_images)
584
- #save_results(self.page_images, layouts, self.layout_labels, output_dir='output/', threshold=0.7)
585
- assert len(self.page_images) == len(layouts)
586
- for pn, lts in enumerate(layouts):
587
- bxs = self.boxes[pn]
588
- lts = [{"type": b["type"],
589
- "score": float(b["score"]),
590
- "x0": b["bbox"][0] / ZM, "x1": b["bbox"][2] / ZM,
591
- "top": b["bbox"][1] / ZM, "bottom": b["bbox"][-1] / ZM,
592
- "page_number": pn,
593
- } for b in lts]
594
- lts = self.sort_Y_firstly(lts, self.mean_height[pn] / 2)
595
- lts = self.__layouts_cleanup(bxs, lts)
596
- self.page_layout.append(lts)
597
-
598
- # Tag layout type, layouts are ready
599
- def findLayout(ty):
600
- nonlocal bxs, lts
601
- lts_ = [lt for lt in lts if lt["type"] == ty]
602
- i = 0
603
- while i < len(bxs):
604
- if bxs[i].get("layout_type"):
605
- i += 1
606
- continue
607
- if self._is_garbage(bxs[i]):
608
- logging.debug("GARBAGE: " + bxs[i]["text"])
609
- bxs.pop(i)
610
- continue
611
-
612
- ii = self.__find_overlapped_with_threashold(bxs[i], lts_,
613
- thr=0.4)
614
- if ii is None: # belong to nothing
615
- bxs[i]["layout_type"] = ""
616
- i += 1
617
- continue
618
- lts_[ii]["visited"] = True
619
- if lts_[ii]["type"] in ["footer", "header", "reference"]:
620
- if lts_[ii]["type"] not in self.garbages:
621
- self.garbages[lts_[ii]["type"]] = []
622
- self.garbages[lts_[ii]["type"]].append(bxs[i]["text"])
623
- logging.debug("GARBAGE: " + bxs[i]["text"])
624
- bxs.pop(i)
625
- continue
626
-
627
- bxs[i]["layoutno"] = f"{ty}-{ii}"
628
- bxs[i]["layout_type"] = lts_[ii]["type"]
629
- i += 1
630
-
631
- for lt in ["footer", "header", "reference", "figure caption",
632
- "table caption", "title", "text", "table", "figure"]:
633
- findLayout(lt)
634
-
635
- # add box to figure layouts which has not text box
636
- for i, lt in enumerate(
637
- [lt for lt in lts if lt["type"] == "figure"]):
638
- if lt.get("visited"):
639
- continue
640
- lt = deepcopy(lt)
641
- del lt["type"]
642
- lt["text"] = ""
643
- lt["layout_type"] = "figure"
644
- lt["layoutno"] = f"figure-{i}"
645
- bxs.append(lt)
646
-
647
- boxes.extend(bxs)
648
-
649
- self.boxes = boxes
650
-
651
- garbage = set()
652
- for k in self.garbages.keys():
653
- self.garbages[k] = Counter(self.garbages[k])
654
- for g, c in self.garbages[k].items():
655
- if c > 1:
656
- garbage.add(g)
657
-
658
- logging.debug("GARBAGE:" + ",".join(garbage))
659
- self.boxes = [b for b in self.boxes if b["text"].strip() not in garbage]
660
-
661
  # cumlative Y
662
  for i in range(len(self.boxes)):
663
  self.boxes[i]["top"] += \
@@ -710,7 +338,7 @@ class HuParser:
710
  self.boxes = bxs
711
 
712
  def _naive_vertical_merge(self):
713
- bxs = self.sort_Y_firstly(self.boxes, np.median(self.mean_height) / 3)
714
  i = 0
715
  while i + 1 < len(bxs):
716
  b = bxs[i]
@@ -850,7 +478,7 @@ class HuParser:
850
  t["layout_type"] = c["layout_type"]
851
  boxes.append(t)
852
 
853
- self.boxes = self.sort_Y_firstly(boxes, 0)
854
 
855
  def _filter_forpages(self):
856
  if not self.boxes:
@@ -916,492 +544,6 @@ class HuParser:
916
  b_["top"] = b["top"]
917
  self.boxes.pop(i)
918
 
919
def _blockType(self, b):
    """Return a short tag describing what kind of text box ``b`` holds.

    The stripped text is matched against an ordered rule list and the first
    matching rule decides the tag:

    * ``"Dt"`` - year / month / day / quarter shaped strings
    * ``"Nu"`` - digits with numeric punctuation only
    * ``"Ca"`` - digits, upper-case letters and ``/ . _ ~ -`` only
    * ``"En"`` - optional capitals followed by lower-case words
    * ``"NE"`` - a numeric prefix followed by alphanumerics / unit symbols
    * ``"Sg"`` - exactly one character

    When no rule matches, the raw (unstripped) text is tokenised with
    ``huqie.qie`` and only tokens longer than one character are counted:
    more than 3 tokens gives ``"Tx"`` (fewer than 12) or ``"Lx"``; a single
    token tagged ``"nr"`` by ``huqie.tag`` gives ``"Nr"``; anything else is
    ``"Ot"``.
    """
    rules = (
        ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
        (r"^(20|19)[0-9]{2}年$", "Dt"),
        (r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
        ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
        (r"^第*[一二三四1-4]季度$", "Dt"),
        (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
        (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
        ("^[0-9.,+%/ -]+$", "Nu"),
        (r"^[0-9A-Z/\._~-]+$", "Ca"),
        (r"^[A-Z]*[a-z' -]+$", "En"),
        (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
        (r"^.{1}$", "Sg"),
    )
    stripped = b["text"].strip()
    matched = next(
        (tag for regex, tag in rules if re.search(regex, stripped)), None)
    if matched is not None:
        return matched

    # Fall back to token statistics; one-character tokens are ignored.
    tokens = [tok for tok in huqie.qie(b["text"]).split(" ") if len(tok) > 1]
    token_cnt = len(tokens)
    if token_cnt > 3:
        return "Tx" if token_cnt < 12 else "Lx"
    if token_cnt == 1 and huqie.tag(tokens[0]) == "nr":
        return "Nr"
    return "Ot"
948
-
949
def __cal_spans(self, boxes, rows, cols, tbl, html=True):
    """Work out row/column spans of spanning cells and merge them in ``tbl``.

    Args:
        boxes: every text box of the table; boxes carrying an ``"SP"`` key
            are treated as spanning cells and must also carry ``"rn"``,
            ``"cn"`` and the ``H_left`` / ``H_right`` / ``H_top`` /
            ``H_bott`` extent of the spanning region.
        rows: list of rows, each a list of boxes.
        cols: list of columns, each a list of boxes.
        tbl: ``tbl[r][c]`` is the list of boxes in that cell; modified in
            place.
        html: when true, the cells swallowed by a span are set to ``None``
            (so they are skipped when rendering HTML); otherwise they all
            reference the merged box list.

    Returns:
        The same ``tbl`` object after merging.
    """
    # Mean geometric extent of every column / row. The recognised structure
    # coordinates are preferred and the raw box coordinates are the fallback.
    clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
            for cln in cols]
    crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
            for cln in cols]
    rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
            for row in rows]
    # The row bottom is stored under "R_bott" by the code that tags boxes;
    # the previous "R_btm" lookup never matched and silently fell back to the
    # box bottom.
    rbtm = [np.mean([c.get("R_bott", c["bottom"])
                     for c in row]) for row in rows]
    for b in boxes:
        if "SP" not in b:
            continue
        b["colspan"] = [b["cn"]]
        b["rowspan"] = [b["rn"]]
        # A column belongs to the span when its horizontal centre lies inside
        # the spanning region.
        for j in range(0, len(clft)):
            if j == b["cn"]:
                continue
            if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
                continue
            if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
                continue
            b["colspan"].append(j)
        # Same test vertically for rows.
        for j in range(0, len(rtop)):
            if j == b["rn"]:
                continue
            if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
                continue
            if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
                continue
            b["rowspan"].append(j)

    def join(arr):
        # Concatenated text of a cell; "" for an empty / None cell.
        if not arr:
            return ""
        return "".join([t["text"] for t in arr])

    # Merge the cells covered by each span into its top-left cell.
    for i in range(len(tbl)):
        for j, arr in enumerate(tbl[i]):
            if not arr:
                continue
            if all(["rowspan" not in a and "colspan" not in a for a in arr]):
                continue
            rowspan, colspan = [], []
            for a in arr:
                # Only list-valued spans are still unresolved; integer values
                # were written by an earlier merge.
                if isinstance(a.get("rowspan", 0), list):
                    rowspan.extend(a["rowspan"])
                if isinstance(a.get("colspan", 0), list):
                    colspan.extend(a["colspan"])
            rowspan, colspan = set(rowspan), set(colspan)
            if len(rowspan) < 2 and len(colspan) < 2:
                # Nothing is actually spanned: drop the markers.
                for a in arr:
                    if "rowspan" in a:
                        del a["rowspan"]
                    if "colspan" in a:
                        del a["colspan"]
                continue
            rowspan, colspan = sorted(rowspan), sorted(colspan)
            # Make the span contiguous between its extremes.
            rowspan = list(range(rowspan[0], rowspan[-1] + 1))
            colspan = list(range(colspan[0], colspan[-1] + 1))
            assert i in rowspan, rowspan
            assert j in colspan, colspan
            arr = []
            for r in rowspan:
                for c in colspan:
                    arr_txt = join(arr)
                    # Skip a cell whose text equals what has been gathered so
                    # far (avoids duplicating the same content).
                    if tbl[r][c] and join(tbl[r][c]) != arr_txt:
                        arr.extend(tbl[r][c])
                    tbl[r][c] = None if html else arr
            for a in arr:
                if len(rowspan) > 1:
                    a["rowspan"] = len(rowspan)
                elif "rowspan" in a:
                    del a["rowspan"]
                if len(colspan) > 1:
                    a["colspan"] = len(colspan)
                elif "colspan" in a:
                    del a["colspan"]
            tbl[rowspan[0]][colspan[0]] = arr

    return tbl
1033
-
1034
def __construct_table(self, boxes, html=False):
    """Rebuild a table from the text boxes that fall inside one table layout.

    Steps: pull caption boxes out of ``boxes`` (``boxes`` is modified in
    place), tag every remaining box with a block type, group the boxes into
    rows and columns, place them in a ``rows x cols`` grid, re-home cells that
    sit alone in their column / row, decide which rows are headers and finally
    render the grid.

    Args:
        boxes: list of box dicts (``text``, ``x0``, ``x1``, ``top``,
            ``bottom``, ``page_number``, ``layout_type`` and, when the table
            structure was recognised, ``R`` / ``C`` / ``H`` / ``SP`` tags).
        html: render as HTML when true, as descriptive text lines otherwise.

    Returns:
        ``[]`` when only captions were present; a one-element list holding
        the HTML string when ``html`` is true; otherwise the list of text
        lines produced by ``__desc_table``.
    """
    # Collect caption text and remove caption boxes from the table content.
    cap = ""
    i = 0
    while i < len(boxes):
        if self.is_caption(boxes[i]):
            cap += boxes[i]["text"]
            boxes.pop(i)
            i -= 1
        i += 1

    if not boxes:
        return []
    for b in boxes:
        b["btype"] = self._blockType(b)
    # Most frequent block type of the table; used below for header detection.
    max_type = Counter([b["btype"] for b in boxes]).items()
    max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
    logging.debug("MAXTYPE: " + max_type)

    # ---- group boxes into rows -------------------------------------------
    rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
    rowh = np.min(rowh) if rowh else 0
    # boxes = self.sort_Y_firstly(boxes, rowh/5)
    boxes = self.sort_R_firstly(boxes, rowh / 2)
    boxes[0]["rn"] = 0
    rows = [[boxes[0]]]
    btm = boxes[0]["bottom"]
    for b in boxes[1:]:
        b["rn"] = len(rows) - 1
        lst_r = rows[-1]
        # A new row starts when the recognised row id changes, or when the box
        # begins (almost) below the running bottom and at least one of the two
        # boxes has no row id (the differing defaults make untagged boxes
        # compare unequal).
        if lst_r[-1].get("R", "") != b.get("R", "") \
                or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
                    ):  # new row
            btm = b["bottom"]
            b["rn"] += 1
            rows.append([b])
            continue
        btm = (btm + b["bottom"]) / 2.
        rows[-1].append(b)

    # ---- group boxes into columns ----------------------------------------
    colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
    colwm = np.min(colwm) if colwm else 0
    crosspage = len(set([b["page_number"] for b in boxes])) > 1
    if crosspage:
        boxes = self.sort_X_firstly(boxes, colwm / 2, False)
    else:
        boxes = self.sort_C_firstly(boxes, colwm / 2)
    boxes[0]["cn"] = 0
    cols = [[boxes[0]]]
    right = boxes[0]["x1"]
    for b in boxes[1:]:
        b["cn"] = len(cols) - 1
        lst_c = cols[-1]
        # A new column starts when the recognised column id advances by one on
        # the same page, or when the box starts right of the running right
        # edge and the column ids differ.
        if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][
            "page_number"]) \
                or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")):  # new col
            right = b["x1"]
            b["cn"] += 1
            cols.append([b])
            continue
        right = (right + b["x1"]) / 2.
        cols[-1].append(b)

    # ---- fill the grid ----------------------------------------------------
    tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
    for b in boxes:
        tbl[b["rn"]][b["cn"]].append(b)

    if len(rows) >= 4:
        # remove single in column: a column holding exactly one non-empty cell
        # is folded into the nearer neighbouring column.
        j = 0
        while j < len(tbl[0]):
            e, ii = 0, 0
            for i in range(len(tbl)):
                if tbl[i][j]:
                    e += 1
                    ii = i
                if e > 1:
                    break
            if e > 1:
                j += 1
                continue
            # f / ff: the left / right neighbour cell in that row is already
            # occupied (or does not exist), so nothing can be moved there.
            f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
                 [j - 1][0].get("text")) or j == 0
            ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
                  [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
            if f and ff:
                j += 1
                continue
            bx = tbl[ii][j][0]
            logging.debug("Relocate column single: " + bx["text"])
            # j column only has one value
            left, right = 100000, 100000
            if j > 0 and not f:
                for i in range(len(tbl)):
                    if tbl[i][j - 1]:
                        left = min(left, np.min(
                            [bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
            if j + 1 < len(tbl[0]) and not ff:
                for i in range(len(tbl)):
                    if tbl[i][j + 1]:
                        right = min(right, np.min(
                            [a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
            assert left < 100000 or right < 100000
            if left < right:
                # Merge into the left column; shift column numbers from j on.
                for jj in range(j, len(tbl[0])):
                    for i in range(len(tbl)):
                        for a in tbl[i][jj]:
                            a["cn"] -= 1
                if tbl[ii][j - 1]:
                    tbl[ii][j - 1].extend(tbl[ii][j])
                else:
                    tbl[ii][j - 1] = tbl[ii][j]
                for i in range(len(tbl)):
                    tbl[i].pop(j)

            else:
                # Merge into the right column; shift column numbers after j.
                for jj in range(j + 1, len(tbl[0])):
                    for i in range(len(tbl)):
                        for a in tbl[i][jj]:
                            a["cn"] -= 1
                if tbl[ii][j + 1]:
                    tbl[ii][j + 1].extend(tbl[ii][j])
                else:
                    tbl[ii][j + 1] = tbl[ii][j]
                for i in range(len(tbl)):
                    tbl[i].pop(j)
            cols.pop(j)
    assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (
        len(cols), len(tbl[0]))

    if len(cols) >= 4:
        # remove single in row: same idea as above, applied to rows.
        i = 0
        while i < len(tbl):
            e, jj = 0, 0
            for j in range(len(tbl[i])):
                if tbl[i][j]:
                    e += 1
                    jj = j
                if e > 1:
                    break
            if e > 1:
                i += 1
                continue
            f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
                 [jj][0].get("text")) or i == 0
            ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
                  [jj][0].get("text")) or i + 1 >= len(tbl)
            if f and ff:
                i += 1
                continue

            bx = tbl[i][jj][0]
            logging.debug("Relocate row single: " + bx["text"])
            # i row only has one value
            up, down = 100000, 100000
            if i > 0 and not f:
                for j in range(len(tbl[i - 1])):
                    if tbl[i - 1][j]:
                        up = min(up, np.min(
                            [bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
            if i + 1 < len(tbl) and not ff:
                for j in range(len(tbl[i + 1])):
                    if tbl[i + 1][j]:
                        down = min(down, np.min(
                            [a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
            assert up < 100000 or down < 100000
            if up < down:
                # Merge into the row above; shift row numbers from i on.
                for ii in range(i, len(tbl)):
                    for j in range(len(tbl[ii])):
                        for a in tbl[ii][j]:
                            a["rn"] -= 1
                if tbl[i - 1][jj]:
                    tbl[i - 1][jj].extend(tbl[i][jj])
                else:
                    tbl[i - 1][jj] = tbl[i][jj]
                tbl.pop(i)

            else:
                # Merge into the row below; shift row numbers after i.
                for ii in range(i + 1, len(tbl)):
                    for j in range(len(tbl[ii])):
                        for a in tbl[ii][j]:
                            a["rn"] -= 1
                if tbl[i + 1][jj]:
                    tbl[i + 1][jj].extend(tbl[i][jj])
                else:
                    tbl[i + 1][jj] = tbl[i][jj]
                tbl.pop(i)
            rows.pop(i)

    # which rows are headers: a row is a header when more than half of its
    # non-empty cells look like header cells.
    hdset = set([])
    for i in range(len(tbl)):
        cnt, h = 0, 0
        for j, arr in enumerate(tbl[i]):
            if not arr:
                continue
            cnt += 1
            if max_type == "Nu" and arr[0]["btype"] == "Nu":
                continue
            # "H" stores the index of the matched header region, and index 0
            # is a valid match, so test for the key rather than its truthiness.
            if any(["H" in a for a in arr]) \
                    or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
                h += 1
        if h / cnt > 0.5:
            hdset.add(i)

    if html:
        return [self.__html_table(cap, hdset,
                                  self.__cal_spans(boxes, rows,
                                                   cols, tbl, True)
                                  )]

    return self.__desc_table(cap, hdset,
                             self.__cal_spans(boxes, rows, cols, tbl, False))
1246
-
1247
def __html_table(self, cap, hdset, tbl):
    """Serialise the cell grid ``tbl`` into an HTML ``<table>`` string.

    Args:
        cap: caption text; a ``<caption>`` element is emitted when non-empty.
        hdset: set of header row indexes. It is also extended in place with
            the cell texts of every emitted header row, so that a later header
            row consisting only of already-seen texts is dropped.
        tbl: ``tbl[r][c]`` is a list of boxes, an empty list (empty cell) or
            ``None`` (cell swallowed by a span, emits nothing).

    Returns:
        The HTML markup, one line per row.
    """
    lines = ["<table>" + (f"<caption>{cap}</caption>" if cap else "")]
    for rn, cells in enumerate(tbl):
        is_header = rn in hdset
        tag = "th" if is_header else "td"
        pieces, texts = [], []
        for cell in cells:
            if cell is None:
                continue
            if not cell:
                pieces.append(f"<{tag}></{tag}>")
                continue
            # Reading-order tolerance: half of the smallest box height in the
            # cell, capped by half of the page's mean line height.
            tolerance = min(
                np.min([bx["bottom"] - bx["top"] for bx in cell]) / 2,
                self.mean_height[cell[0]["page_number"] - 1] / 2)
            text = "".join(
                [bx["text"] for bx in self.sort_Y_firstly(cell, tolerance)])
            texts.append(text)
            first = cell[0]
            attrs = ""
            if first.get("colspan"):
                attrs = "colspan={}".format(first["colspan"])
            if first.get("rowspan"):
                attrs += " rowspan={}".format(first["rowspan"])
            pieces.append(f"<{tag} {attrs} >" + text + f"</{tag}>")

        if is_header:
            # A header row repeating only texts already seen is skipped.
            if all([t in hdset for t in texts]):
                continue
            hdset.update(texts)

        lines.append(("<tr>" + "".join(pieces) + "</tr>") if pieces else "")
    lines.append("</table>")
    return "\n".join(lines)
1291
-
1292
def __desc_table(self, cap, hdr_rowno, tbl):
    """Turn the cell grid ``tbl`` into descriptive text lines.

    Every non-header row becomes a string of ``header:value`` pairs, using
    the nearest header row above it; tables of at most two columns with no
    usable header are rendered as key/value lines instead.

    Args:
        cap: caption text; when non-empty a source suffix naming the caption
            is appended to every line.
        hdr_rowno: set of header row indexes; header rows whose cells are all
            empty are removed from it in place.
        tbl: ``tbl[r][c]`` is the (possibly empty / None) list of boxes of a
            cell.

    Returns:
        List of text lines describing the table body.
    """
    # get text of every colomn in header row to become header text
    clmno = len(tbl[0])
    rowno = len(tbl)
    headers = {}
    # NOTE(review): hdrset is filled below but never read in this function.
    hdrset = set()
    lst_hdr = []
    # Connector used when stacking multi-level header texts.
    de = "的" if not self.is_english else " for "
    for r in sorted(list(hdr_rowno)):
        headers[r] = ["" for _ in range(clmno)]
        for i in range(clmno):
            if not tbl[r][i]:
                continue
            txt = "".join([a["text"].strip() for a in tbl[r][i]])
            headers[r][i] = txt
            hdrset.add(txt)
        if all([not t for t in headers[r]]):
            # A header row without any text is not a header after all.
            del headers[r]
            hdr_rowno.remove(r)
            continue
        # Empty header cells inherit the text of the previous header row,
        # column by column, for as many columns as that row provides.
        for j in range(clmno):
            if headers[r][j]:
                continue
            if j >= len(lst_hdr):
                break
            headers[r][j] = lst_hdr[j]
        lst_hdr = headers[r]
    # Consecutive header rows: fold the upper header text into the lower one
    # unless the lower one already contains it.
    for i in range(rowno):
        if i not in hdr_rowno:
            continue
        for j in range(i + 1, rowno):
            if j not in hdr_rowno:
                break
            for k in range(clmno):
                if not headers[j - 1][k]:
                    continue
                if headers[j][k].find(headers[j - 1][k]) >= 0:
                    continue
                if len(headers[j][k]) > len(headers[j - 1][k]):
                    headers[j][k] += (de if headers[j][k]
                                      else "") + headers[j - 1][k]
                else:
                    headers[j][k] = headers[j - 1][k] \
                        + (de if headers[j - 1][k] else "") \
                        + headers[j][k]

    logging.debug(
        f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
    row_txt = []
    for i in range(rowno):
        if i in hdr_rowno:
            continue
        rtxt = []

        def append(delimer):
            # Join the collected cell texts of this row and either glue the
            # result onto the previous line (when both together stay under 64
            # characters) or start a new line.
            nonlocal rtxt, row_txt
            rtxt = delimer.join(rtxt)
            if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
                row_txt[-1] += "\n" + rtxt
            else:
                row_txt.append(rtxt)

        # Pick the closest header row above row i; r stays 0 when none exists.
        r = 0
        if len(headers.items()):
            _arr = [(i - r, r) for r, _ in headers.items() if r < i]
            if _arr:
                _, r = min(_arr, key=lambda x: x[0])

        if r not in headers and clmno <= 2:
            # No header and at most two columns: emit the row as "key:value".
            for j in range(clmno):
                if not tbl[i][j]:
                    continue
                txt = "".join([a["text"].strip() for a in tbl[i][j]])
                if txt:
                    rtxt.append(txt)
            if rtxt:
                append(":")
            continue

        for j in range(clmno):
            if not tbl[i][j]:
                continue
            txt = "".join([a["text"].strip() for a in tbl[i][j]])
            if not txt:
                continue
            # Prefix the cell text with its column header when there is one.
            ctt = headers[r][j] if r in headers else ""
            if ctt:
                ctt += ":"
            ctt += txt
            if ctt:
                rtxt.append(ctt)

        if rtxt:
            row_txt.append("; ".join(rtxt))

    if cap:
        # Append the table caption as the source of every line.
        if self.is_english:
            from_ = " in "
        else:
            from_ = "来自"
        row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
    return row_txt
1394
-
1395
@staticmethod
def is_caption(bx):
    """Tell whether box ``bx`` is a figure/table caption.

    A box counts as a caption when its stripped text starts with one or more
    of the characters 图 / 表 followed by at least two characters from the set
    of spaces, digits and colons, or when its ``layout_type`` contains the
    word ``caption``.

    Args:
        bx: box dict with ``text`` and ``layout_type`` string entries.

    Returns:
        ``True`` for a caption box, ``False`` otherwise.
    """
    caption_patterns = (
        r"[图表]+[ 0-9::]{2,}",
    )
    text = bx["text"].strip()
    for pattern in caption_patterns:
        if re.match(pattern, text):
            return True
    return "caption" in bx["layout_type"]
1404
-
1405
  def _extract_table_figure(self, need_image, ZM, return_html):
1406
  tables = {}
1407
  figures = {}
@@ -1415,7 +557,7 @@ class HuParser:
1415
  continue
1416
  lout_no = str(self.boxes[i]["page_number"]) + \
1417
  "-" + str(self.boxes[i]["layoutno"])
1418
- if self.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title",
1419
  "figure caption", "reference"]:
1420
  nomerge_lout_no.append(lst_lout_no)
1421
  if self.boxes[i]["layout_type"] == "table":
@@ -1470,7 +612,7 @@ class HuParser:
1470
  while i < len(self.boxes):
1471
  c = self.boxes[i]
1472
  # mh = self.mean_height[c["page_number"]-1]
1473
- if not self.is_caption(c):
1474
  i += 1
1475
  continue
1476
 
@@ -1529,7 +671,7 @@ class HuParser:
1529
  "bottom": np.max([b["bottom"] for b in bxs]) - ht
1530
  }
1531
  louts = [l for l in self.page_layout[pn] if l["type"] == ltype]
1532
- ii = self.__find_overlapped(b, louts, naive=True)
1533
  if ii is not None:
1534
  b = louts[ii]
1535
  else:
@@ -1581,7 +723,7 @@ class HuParser:
1581
  if not bxs:
1582
  continue
1583
  res.append((cropout(bxs, "table"),
1584
- self.__construct_table(bxs, html=return_html)))
1585
 
1586
  return res
1587
 
 
1
  # -*- coding: utf-8 -*-
 
2
  import random
3
 
4
  import fitz
 
5
  import xgboost as xgb
6
  from io import BytesIO
7
  import torch
 
12
  import numpy as np
13
 
14
  from api.db import ParserType
15
+ from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
16
  from rag.nlp import huqie
 
17
  from copy import deepcopy
18
  from huggingface_hub import hf_hub_download
19
 
 
26
  self.ocr = OCR()
27
  if not hasattr(self, "model_speciess"):
28
  self.model_speciess = ParserType.GENERAL.value
29
+ self.layouter = LayoutRecognizer("layout."+self.model_speciess)
30
+ self.tbl_det = TableStructureRecognizer()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  self.updown_cnt_mdl = xgb.Booster()
33
  if torch.cuda.is_available():
 
46
 
47
  """
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def __char_width(self, c):
50
  return (c["x1"] - c["x0"]) // len(c["text"])
51
 
 
131
  ]
132
  return fea
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  @staticmethod
135
  def sort_X_by_page(arr, threashold):
136
  # sort using y1 first and then x1
 
146
  arr[j + 1] = tmp
147
  return arr
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def _has_color(self, o):
150
  if o.get("ncs", "") == "DeviceGray":
151
  if o["stroking_color"] and o["stroking_color"][0] == 1 and o["non_stroking_color"] and \
 
154
  return False
155
  return True
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def _table_transformer_job(self, ZM):
158
  logging.info("Table processing...")
159
  imgs, pos = [], []
 
179
  assert len(self.page_images) == len(tbcnt) - 1
180
  if not imgs:
181
  return
182
+ recos = self.tbl_det(imgs)
183
  tbcnt = np.cumsum(tbcnt)
184
  for i in range(len(tbcnt) - 1): # for page
185
  pg = []
 
201
  self.tb_cpns.extend(pg)
202
 
203
  def gather(kwd, fzy=10, ption=0.6):
204
+ eles = Recognizer.sort_Y_firstly(
205
  [r for r in self.tb_cpns if re.match(kwd, r["label"])], fzy)
206
+ eles = Recognizer.layouts_cleanup(self.boxes, eles, 5, ption)
207
+ return Recognizer.sort_Y_firstly(eles, 0)
208
 
209
  # add R,H,C,SP tag to boxes within table layout
210
  headers = gather(r".*header$")
 
212
  spans = gather(r".*spanning")
213
  clmns = sorted([r for r in self.tb_cpns if re.match(
214
  r"table column$", r["label"])], key=lambda x: (x["pn"], x["layoutno"], x["x0"]))
215
+ clmns = Recognizer.layouts_cleanup(self.boxes, clmns, 5, 0.5)
216
  for b in self.boxes:
217
  if b.get("layout_type", "") != "table":
218
  continue
219
+ ii = Recognizer.find_overlapped_with_threashold(b, rows, thr=0.3)
220
  if ii is not None:
221
  b["R"] = ii
222
  b["R_top"] = rows[ii]["top"]
223
  b["R_bott"] = rows[ii]["bottom"]
224
 
225
+ ii = Recognizer.find_overlapped_with_threashold(b, headers, thr=0.3)
226
  if ii is not None:
227
  b["H_top"] = headers[ii]["top"]
228
  b["H_bott"] = headers[ii]["bottom"]
 
230
  b["H_right"] = headers[ii]["x1"]
231
  b["H"] = ii
232
 
233
+ ii = Recognizer.find_overlapped_with_threashold(b, clmns, thr=0.3)
234
  if ii is not None:
235
  b["C"] = ii
236
  b["C_left"] = clmns[ii]["x0"]
237
  b["C_right"] = clmns[ii]["x1"]
238
 
239
+ ii = Recognizer.find_overlapped_with_threashold(b, spans, thr=0.3)
240
  if ii is not None:
241
  b["H_top"] = spans[ii]["top"]
242
  b["H_bott"] = spans[ii]["bottom"]
 
250
  self.boxes.append([])
251
  return
252
  bxs = [(line[0], line[1][0]) for line in bxs]
253
+ bxs = Recognizer.sort_Y_firstly(
254
  [{"x0": b[0][0] / ZM, "x1": b[1][0] / ZM,
255
  "top": b[0][1] / ZM, "text": "", "txt": t,
256
  "bottom": b[-1][1] / ZM,
 
259
  )
260
 
261
  # merge chars in the same rect
262
+ for c in Recognizer.sort_X_firstly(chars, self.mean_width[pagenum - 1] // 4):
263
+ ii = Recognizer.find_overlapped(c, bxs)
264
  if ii is None:
265
  self.lefted_chars.append(c)
266
  continue
 
281
  if self.mean_height[-1] == 0:
282
  self.mean_height[-1] = np.median([b["bottom"] - b["top"]
283
  for b in bxs])
 
284
  self.boxes.append(bxs)
285
 
286
  def _layouts_rec(self, ZM):
287
  assert len(self.page_images) == len(self.boxes)
288
+ self.boxes, self.page_layout = self.layouter(self.page_images, self.boxes, ZM)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  # cumlative Y
290
  for i in range(len(self.boxes)):
291
  self.boxes[i]["top"] += \
 
338
  self.boxes = bxs
339
 
340
  def _naive_vertical_merge(self):
341
+ bxs = Recognizer.sort_Y_firstly(self.boxes, np.median(self.mean_height) / 3)
342
  i = 0
343
  while i + 1 < len(bxs):
344
  b = bxs[i]
 
478
  t["layout_type"] = c["layout_type"]
479
  boxes.append(t)
480
 
481
+ self.boxes = Recognizer.sort_Y_firstly(boxes, 0)
482
 
483
  def _filter_forpages(self):
484
  if not self.boxes:
 
544
  b_["top"] = b["top"]
545
  self.boxes.pop(i)
546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
547
  def _extract_table_figure(self, need_image, ZM, return_html):
548
  tables = {}
549
  figures = {}
 
557
  continue
558
  lout_no = str(self.boxes[i]["page_number"]) + \
559
  "-" + str(self.boxes[i]["layoutno"])
560
+ if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption", "title",
561
  "figure caption", "reference"]:
562
  nomerge_lout_no.append(lst_lout_no)
563
  if self.boxes[i]["layout_type"] == "table":
 
612
  while i < len(self.boxes):
613
  c = self.boxes[i]
614
  # mh = self.mean_height[c["page_number"]-1]
615
+ if not TableStructureRecognizer.is_caption(c):
616
  i += 1
617
  continue
618
 
 
671
  "bottom": np.max([b["bottom"] for b in bxs]) - ht
672
  }
673
  louts = [l for l in self.page_layout[pn] if l["type"] == ltype]
674
+ ii = Recognizer.find_overlapped(b, louts, naive=True)
675
  if ii is not None:
676
  b = louts[ii]
677
  else:
 
723
  if not bxs:
724
  continue
725
  res.append((cropout(bxs, "table"),
726
+ self.tbl_det.construct_table(bxs, html=return_html, is_english=self.is_english)))
727
 
728
  return res
729
 
deepdoc/vision/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .ocr import OCR
2
+ from .recognizer import Recognizer
3
+ from .layout_recognizer import LayoutRecognizer
4
+ from .table_structure_recognizer import TableStructureRecognizer
deepdoc/vision/layout_recognizer.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from collections import Counter
4
+ from copy import deepcopy
5
+
6
+ import numpy as np
7
+
8
+ from api.utils.file_utils import get_project_base_directory
9
+ from .recognizer import Recognizer
10
+
11
+
12
class LayoutRecognizer(Recognizer):
    """Detect page layout regions and tag OCR text boxes with their region type."""

    def __init__(self, domain):
        """Set up the layout label list and initialise the base recognizer.

        Args:
            domain: name passed through to ``Recognizer.__init__`` together
                with the label list and the ``rag/res/deepdoc/`` directory
                under the project base directory.
        """
        self.layout_labels = [
            "_background_",
            "Text",
            "Title",
            "Figure",
            "Figure caption",
            "Table",
            "Table caption",
            "Header",
            "Footer",
            "Reference",
            "Equation",
        ]
        super().__init__(self.layout_labels, domain,
                         os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))

    def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
        """Run layout detection on page images and tag the OCR boxes.

        Args:
            image_list: one image per page.
            ocr_res: one list of OCR box dicts per page (must have the same
                length as ``image_list``); the per-page lists are modified in
                place.
            scale_factor: detected bounding-box coordinates are divided by
                this value before being compared with the OCR boxes.
            thr: passed to the base recognizer call.
            batch_size: passed to the base recognizer call.

        Returns:
            ``(boxes, page_layout)``: the flat list of surviving boxes of all
            pages, tagged with ``layout_type`` (and ``layoutno`` when
            matched), and the per-page list of cleaned-up layout regions.
        """
        def __is_garbage(b):
            # Text matching any of these patterns (bullet runs, copyright /
            # disclaimer / address markers, dot leaders, "n / m" and "n of m"
            # page counters, http links, data-source notes, e-mail addresses,
            # "(cid:N)" residue) is dropped outright.
            patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
                    r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
                    "(资料|数据)来源[::]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}",
                    "\\(cid *: *[0-9]+ *\\)"
                    ]
            return any([re.search(p, b["text"]) for p in patt])

        layouts = super().__call__(image_list, thr, batch_size)
        # save_results(image_list, layouts, self.layout_labels, output_dir='output/', threshold=0.7)
        assert len(image_list) == len(ocr_res)
        # Tag layout type
        boxes = []
        assert len(image_list) == len(layouts)
        # Texts found inside footer / header / reference regions, per type.
        garbages = {}
        page_layout = []
        for pn, lts in enumerate(layouts):
            bxs = ocr_res[pn]
            # Normalise the detections into the box-dict format, scaled back
            # by scale_factor.
            lts = [{"type": b["type"],
                    "score": float(b["score"]),
                    "x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor,
                    "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
                    "page_number": pn,
                    } for b in lts]
            # NOTE(review): np.mean of an empty list (page without any
            # detected layout) yields nan with a RuntimeWarning - confirm
            # that this is acceptable.
            lts = self.sort_Y_firstly(lts, np.mean([l["bottom"]-l["top"] for l in lts]) / 2)
            lts = self.layouts_cleanup(bxs, lts)
            page_layout.append(lts)

            # Tag layout type, layouts are ready
            def findLayout(ty):
                # Assign every still-untagged box of this page to the layout
                # region of type `ty` it overlaps with, if any.
                nonlocal bxs, lts, self
                lts_ = [lt for lt in lts if lt["type"] == ty]
                i = 0
                while i < len(bxs):
                    if bxs[i].get("layout_type"):
                        i += 1
                        continue
                    if __is_garbage(bxs[i]):
                        bxs.pop(i)
                        continue

                    ii = self.find_overlapped_with_threashold(bxs[i], lts_,
                                                              thr=0.4)
                    if ii is None:  # belong to nothing
                        # "" is falsy, so the box is tried again for the next
                        # layout type.
                        bxs[i]["layout_type"] = ""
                        i += 1
                        continue
                    lts_[ii]["visited"] = True
                    if lts_[ii]["type"] in ["footer", "header", "reference"]:
                        # Remember the text and remove the box from the page.
                        if lts_[ii]["type"] not in garbages:
                            garbages[lts_[ii]["type"]] = []
                        garbages[lts_[ii]["type"]].append(bxs[i]["text"])
                        bxs.pop(i)
                        continue

                    bxs[i]["layoutno"] = f"{ty}-{ii}"
                    bxs[i]["layout_type"] = lts_[ii]["type"]
                    i += 1

            # The order matters: a box is claimed by the first type in this
            # list whose region it overlaps.
            for lt in ["footer", "header", "reference", "figure caption",
                       "table caption", "title", "text", "table", "figure", "equation"]:
                findLayout(lt)

            # add box to figure layouts which has not text box
            for i, lt in enumerate(
                    [lt for lt in lts if lt["type"] == "figure"]):
                if lt.get("visited"):
                    continue
                lt = deepcopy(lt)
                del lt["type"]
                lt["text"] = ""
                lt["layout_type"] = "figure"
                lt["layoutno"] = f"figure-{i}"
                bxs.append(lt)

            boxes.extend(bxs)

        ocr_res = boxes

        # A footer / header / reference text seen more than once (within its
        # type) is treated as boilerplate and filtered from the result too.
        garbag_set = set()
        for k in garbages.keys():
            garbages[k] = Counter(garbages[k])
            for g, c in garbages[k].items():
                if c > 1:
                    garbag_set.add(g)

        ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
        return ocr_res, page_layout
119
+
deepdoc/{visual → vision}/ocr.py RENAMED
@@ -74,7 +74,7 @@ class TextRecognizer(object):
74
  self.rec_batch_num = 16
75
  postprocess_params = {
76
  'name': 'CTCLabelDecode',
77
- "character_dict_path": os.path.join(get_project_base_directory(), "rag/res", "ocr.res"),
78
  "use_space_char": True
79
  }
80
  self.postprocess_op = build_post_process(postprocess_params)
@@ -450,7 +450,7 @@ class OCR(object):
450
 
451
  """
452
  if not model_dir:
453
- model_dir = snapshot_download(repo_id="InfiniFlow/ocr")
454
 
455
  self.text_detector = TextDetector(model_dir)
456
  self.text_recognizer = TextRecognizer(model_dir)
 
74
  self.rec_batch_num = 16
75
  postprocess_params = {
76
  'name': 'CTCLabelDecode',
77
+ "character_dict_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "ocr.res"),
78
  "use_space_char": True
79
  }
80
  self.postprocess_op = build_post_process(postprocess_params)
 
450
 
451
  """
452
  if not model_dir:
453
+ model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
454
 
455
  self.text_detector = TextDetector(model_dir)
456
  self.text_recognizer = TextRecognizer(model_dir)
deepdoc/{visual → vision}/ocr.res RENAMED
File without changes
deepdoc/{visual → vision}/operators.py RENAMED
File without changes
deepdoc/{visual → vision}/postprocess.py RENAMED
File without changes
deepdoc/{visual → vision}/recognizer.py RENAMED
@@ -12,9 +12,12 @@
12
  #
13
 
14
  import os
 
 
15
  import onnxruntime as ort
16
  from huggingface_hub import snapshot_download
17
 
 
18
  from .operators import *
19
  from rag.settings import cron_logger
20
 
@@ -45,6 +48,140 @@ class Recognizer(object):
45
  self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
46
  self.label_list = label_list
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  def create_inputs(self, imgs, im_info):
49
  """generate input for different model type
50
  Args:
@@ -85,6 +222,58 @@ class Recognizer(object):
85
  inputs['image'] = np.stack(padding_imgs, axis=0)
86
  return inputs
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def preprocess(self, image_list):
89
  preprocess_ops = []
90
  for op_info in [
@@ -103,7 +292,6 @@ class Recognizer(object):
103
  inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
104
  return inputs
105
 
106
-
107
  def __call__(self, image_list, thr=0.7, batch_size=16):
108
  res = []
109
  imgs = []
 
12
  #
13
 
14
  import os
15
+ from copy import deepcopy
16
+
17
  import onnxruntime as ort
18
  from huggingface_hub import snapshot_download
19
 
20
+ from . import seeit
21
  from .operators import *
22
  from rag.settings import cron_logger
23
 
 
48
  self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
49
  self.label_list = label_list
50
 
51
+ @staticmethod
52
+ def sort_Y_firstly(arr, threashold):
53
+ # sort using y1 first and then x1
54
+ arr = sorted(arr, key=lambda r: (r["top"], r["x0"]))
55
+ for i in range(len(arr) - 1):
56
+ for j in range(i, -1, -1):
57
+ # restore the order using th
58
+ if abs(arr[j + 1]["top"] - arr[j]["top"]) < threashold \
59
+ and arr[j + 1]["x0"] < arr[j]["x0"]:
60
+ tmp = deepcopy(arr[j])
61
+ arr[j] = deepcopy(arr[j + 1])
62
+ arr[j + 1] = deepcopy(tmp)
63
+ return arr
64
+
65
+ @staticmethod
66
+ def sort_X_firstly(arr, threashold, copy=True):
67
+ # sort using y1 first and then x1
68
+ arr = sorted(arr, key=lambda r: (r["x0"], r["top"]))
69
+ for i in range(len(arr) - 1):
70
+ for j in range(i, -1, -1):
71
+ # restore the order using th
72
+ if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threashold \
73
+ and arr[j + 1]["top"] < arr[j]["top"]:
74
+ tmp = deepcopy(arr[j]) if copy else arr[j]
75
+ arr[j] = deepcopy(arr[j + 1]) if copy else arr[j + 1]
76
+ arr[j + 1] = deepcopy(tmp) if copy else tmp
77
+ return arr
78
+
79
@staticmethod
def sort_C_firstly(arr, thr=0):
    """Order boxes by column index ``"C"``, then by ``"top"``.

    The boxes are first arranged with ``Recognizer.sort_X_firstly`` (so
    boxes without a ``"C"`` key keep a position derived from ``x0``/``top``).
    Bubble-style passes then swap neighbours that both carry ``"C"`` when the
    later one has a smaller column index, or the same column index and a
    smaller ``top``.  Pairs in which either box lacks ``"C"`` are left alone.

    Args:
        arr: iterable of box dicts with ``"x0"`` and ``"top"``; ``"C"`` is
            optional.
        thr: tolerance forwarded to ``sort_X_firstly``.

    Returns:
        The reordered list produced from ``sort_X_firstly``.
    """
    arr = Recognizer.sort_X_firstly(arr, thr)
    for i in range(len(arr) - 1):
        for j in range(i, -1, -1):
            if "C" not in arr[j] or "C" not in arr[j + 1]:
                continue
            if arr[j + 1]["C"] < arr[j]["C"] \
                    or (
                        arr[j + 1]["C"] == arr[j]["C"]
                        and arr[j + 1]["top"] < arr[j]["top"]
                    ):
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
    # The original had a second, unreachable ``return sorted(...)`` after
    # this statement; it has been dropped as dead code.
    return arr
100
+
101
@staticmethod
def sort_R_firstly(arr, thr=0):
    """Order boxes by row index ``"R"``, then by ``"x0"``.

    The boxes are first arranged with ``Recognizer.sort_Y_firstly``.
    Bubble-style passes then swap neighbours that both carry ``"R"`` when
    the later one has a smaller row index, or the same row index and a
    smaller ``x0``.  Pairs in which either box lacks ``"R"`` are left alone.

    Args:
        arr: iterable of box dicts with ``"top"`` and ``"x0"``; ``"R"`` is
            optional.
        thr: tolerance forwarded to ``sort_Y_firstly``.

    Returns:
        The reordered list produced from ``sort_Y_firstly``.
    """
    ordered = Recognizer.sort_Y_firstly(arr, thr)
    for last in range(len(ordered) - 1):
        for k in range(last, -1, -1):
            prev, nxt = ordered[k], ordered[k + 1]
            if "R" not in prev or "R" not in nxt:
                continue
            if nxt["R"] < prev["R"] or (
                    nxt["R"] == prev["R"] and nxt["x0"] < prev["x0"]):
                ordered[k], ordered[k + 1] = nxt, prev
    return ordered
119
+
120
+ @staticmethod
121
+ def overlapped_area(a, b, ratio=True):
122
+ tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
123
+ if b["x0"] > x1 or b["x1"] < x0:
124
+ return 0
125
+ if b["bottom"] < tp or b["top"] > btm:
126
+ return 0
127
+ x0_ = max(b["x0"], x0)
128
+ x1_ = min(b["x1"], x1)
129
+ assert x0_ <= x1_, "Fuckedup! T:{},B:{},X0:{},X1:{} ==> {}".format(
130
+ tp, btm, x0, x1, b)
131
+ tp_ = max(b["top"], tp)
132
+ btm_ = min(b["bottom"], btm)
133
+ assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format(
134
+ tp, btm, x0, x1, b)
135
+ ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
136
+ x0 != 0 and btm - tp != 0 else 0
137
+ if ov > 0 and ratio:
138
+ ov /= (x1 - x0) * (btm - tp)
139
+ return ov
140
+
141
@staticmethod
def layouts_cleanup(boxes, layouts, far=2, thr=0.7):
    """Drop one of two nearby, same-type layout regions that overlap heavily.

    For every region ``layouts[i]`` the following ``far - 1`` regions are
    scanned for the first one that has the same ``"type"`` and intersects
    it.  If either region covers at least ``thr`` of the other, one of the
    two is removed from ``layouts``:

    * when both carry a truthy ``"score"``, the lower-scored one goes
      (on a tie, ``layouts[i]`` goes);
    * otherwise the one sharing less total area with ``boxes`` goes
      (on a tie, ``layouts[i]`` goes).

    Args:
        boxes: box dicts with ``"top"``, ``"bottom"``, ``"x0"``, ``"x1"``.
        layouts: list of region dicts with the same geometry keys and
            optional ``"type"`` / ``"score"``; modified in place.
        far: look-ahead window size (regions ``i + 1 .. i + far - 1``).
        thr: overlap ratio from which two regions count as duplicates.

    Returns:
        The same ``layouts`` list after removals.
    """
    def notOverlapped(a, b):
        # True when the two rectangles are separated on at least one axis.
        return any([a["x1"] < b["x0"],
                    a["x0"] > b["x1"],
                    a["bottom"] < b["top"],
                    a["top"] > b["bottom"]])

    i = 0
    while i + 1 < len(layouts):
        j = i + 1
        # Advance j to the first same-type, intersecting region in the window.
        while j < min(i + far, len(layouts)) \
                and (layouts[i].get("type", "") != layouts[j].get("type", "")
                     or notOverlapped(layouts[i], layouts[j])):
            j += 1
        if j >= min(i + far, len(layouts)):
            # No candidate inside the window.
            i += 1
            continue
        if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr \
                and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr:
            # They intersect, but neither covers enough of the other.
            i += 1
            continue

        if layouts[i].get("score") and layouts[j].get("score"):
            if layouts[i]["score"] > layouts[j]["score"]:
                layouts.pop(j)
            else:
                layouts.pop(i)
            # i is not advanced: the region now at index i is re-examined.
            continue

        # No usable scores: keep the region that shares more area with boxes.
        area_i, area_i_1 = 0, 0
        for b in boxes:
            if not notOverlapped(b, layouts[i]):
                area_i += Recognizer.overlapped_area(b, layouts[i], False)
            if not notOverlapped(b, layouts[j]):
                area_i_1 += Recognizer.overlapped_area(b, layouts[j], False)

        if area_i > area_i_1:
            layouts.pop(j)
        else:
            layouts.pop(i)

    return layouts
184
+
185
  def create_inputs(self, imgs, im_info):
186
  """generate input for different model type
187
  Args:
 
222
  inputs['image'] = np.stack(padding_imgs, axis=0)
223
  return inputs
224
 
225
@staticmethod
def find_overlapped(box, boxes_sorted_by_y, naive=False):
    """Return the index of the candidate most covered by ``box``.

    Args:
        box: dict with ``"top"``, ``"bottom"``, ``"x0"``, ``"x1"``.
        boxes_sorted_by_y: candidate boxes; the binary search below
            presumably relies on them being ordered vertically, as the
            parameter name says.
        naive: when true the binary search is skipped and (almost) the
            whole list is scanned.

    Returns:
        Index into ``boxes_sorted_by_y`` of the candidate with the largest
        positive ``Recognizer.overlapped_area(candidate, box)``, or ``None``
        when the list is empty or nothing overlaps.
    """
    if not boxes_sorted_by_y:
        return
    bxs = boxes_sorted_by_y
    s, e, ii = 0, len(bxs), 0
    # Bisect on vertical position until a candidate that vertically
    # intersects ``box`` is found (or the range is exhausted).
    while s < e and not naive:
        ii = (e + s) // 2
        pv = bxs[ii]
        if box["bottom"] < pv["top"]:
            e = ii
            continue
        if box["top"] > pv["bottom"]:
            s = ii + 1
            continue
        break
    # NOTE(review): indentation was lost in the rendering this was taken
    # from; ``break`` is placed at loop level because that is the only
    # reading that always terminates.  Each loop therefore trims the range
    # by at most one element -- confirm against the repository.
    while s < ii:
        if box["top"] > bxs[s]["bottom"]:
            s += 1
        break
    while e - 1 > ii:
        if box["bottom"] < bxs[e - 1]["top"]:
            e -= 1
        break

    # Linear scan of the remaining range for the best overlap.
    max_overlaped_i, max_overlaped = None, 0
    for i in range(s, e):
        ov = Recognizer.overlapped_area(bxs[i], box)
        if ov <= max_overlaped:
            continue
        max_overlaped_i = i
        max_overlaped = ov

    return max_overlaped_i
259
+
260
@staticmethod
def find_overlapped_with_threashold(box, boxes, thr=0.3):
    """Return the index of the candidate that best overlaps ``box``.

    Each candidate is scored by the pair
    ``(overlapped_area(box, cand), overlapped_area(cand, box))`` and pairs
    are compared lexicographically.  The running best starts at
    ``(thr, 0)``, so a candidate is only accepted once its score is not
    below that starting pair; among equal scores the last candidate wins.

    Args:
        box: dict with ``"top"``, ``"bottom"``, ``"x0"``, ``"x1"``.
        boxes: candidate boxes with the same keys.
        thr: minimum share of ``box`` a candidate has to cover.

    Returns:
        Index into ``boxes`` of the selected candidate, or ``None`` when
        ``boxes`` is empty or no candidate qualifies.
    """
    if not boxes:
        return

    best_index = None
    best_score = (thr, 0)
    for idx in range(0, len(boxes)):
        candidate = boxes[idx]
        score = (Recognizer.overlapped_area(box, candidate),
                 Recognizer.overlapped_area(candidate, box))
        if score < best_score:
            continue
        best_index, best_score = idx, score

    return best_index
276
+
277
  def preprocess(self, image_list):
278
  preprocess_ops = []
279
  for op_info in [
 
292
  inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
293
  return inputs
294
 
 
295
  def __call__(self, image_list, thr=0.7, batch_size=16):
296
  res = []
297
  imgs = []
deepdoc/{visual → vision}/seeit.py RENAMED
File without changes
deepdoc/vision/table_structure_recognizer.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import re
4
+ from collections import Counter
5
+ from copy import deepcopy
6
+
7
+ import numpy as np
8
+
9
+ from api.utils.file_utils import get_project_base_directory
10
+ from rag.nlp import huqie
11
+ from .recognizer import Recognizer
12
+
13
+
14
+ class TableStructureRecognizer(Recognizer):
15
def __init__(self):
    """Set up the table-structure label set and initialise the base recognizer.

    The labels are stored on ``self.labels`` and handed to the parent
    initialiser together with the task name ``"tsr"`` and the
    ``rag/res/deepdoc/`` directory under the project base directory.
    """
    self.labels = [
        "table",
        "table column",
        "table row",
        "table column header",
        "table projected row header",
        "table spanning cell",
    ]
    model_dir = os.path.join(get_project_base_directory(), "rag/res/deepdoc/")
    super().__init__(self.labels, "tsr", model_dir)
26
+
27
def __call__(self, images, thr=0.5):
    """Detect table components per image and square up rows and columns.

    The parent recognizer is run on ``images``.  Every detection is turned
    into a dict with ``label``, ``score``, ``x0``, ``x1``, ``top`` and
    ``bottom``.  Row-like components (label contains ``"row"`` or
    ``"header"`` past its first character) are widened to a common
    left/right edge, and ``"table column"`` components are stretched to a
    common top/bottom edge.  The shared edge is the median when more than
    four components contribute, otherwise the extreme value.

    Args:
        images: forwarded unchanged to the parent ``__call__``.
        thr: threshold forwarded to the parent ``__call__``.

    Returns:
        A list with one list of component dicts per retained image.  Images
        without detections, or without any row-like component, contribute
        no entry.
    """
    def is_row_like(label):
        # ``find(...) > 0`` is deliberate: a hit at index 0 does not count.
        return label.find("row") > 0 or label.find("header") > 0

    detections_per_image = super().__call__(images, thr)
    aligned = []
    for detections in detections_per_image:
        comps = [
            {
                "label": d["type"],
                "score": d["score"],
                "x0": d["bbox"][0], "x1": d["bbox"][2],
                "top": d["bbox"][1], "bottom": d["bbox"][-1],
            }
            for d in detections
        ]
        if not comps:
            continue

        # Rows and headers share one left edge and one right edge.
        row_like = [c for c in comps if is_row_like(c["label"])]
        if not row_like:
            continue
        lefts = [c["x0"] for c in row_like]
        rights = [c["x1"] for c in row_like]
        left = np.median(lefts) if len(lefts) > 4 else np.min(lefts)
        right = np.median(rights) if len(rights) > 4 else np.max(rights)
        for c in row_like:
            if c["x0"] > left:
                c["x0"] = left
            if c["x1"] < right:
                c["x1"] = right

        # Columns share one top edge and one bottom edge.
        columns = [c for c in comps if c["label"] == "table column"]
        if columns:
            tops = [c["top"] for c in columns]
            bottoms = [c["bottom"] for c in columns]
            top = np.median(tops) if len(tops) > 4 else np.min(tops)
            bottom = np.median(bottoms) if len(bottoms) > 4 else np.max(bottoms)
            for c in columns:
                if c["top"] > top:
                    c["top"] = top
                if c["bottom"] < bottom:
                    c["bottom"] = bottom

        aligned.append(comps)
    return aligned
71
+
72
+ @staticmethod
73
+ def is_caption(bx):
74
+ patt = [
75
+ r"[图表]+[ 0-9::]{2,}"
76
+ ]
77
+ if any([re.match(p, bx["text"].strip()) for p in patt]) \
78
+ or bx["layout_type"].find("caption") >= 0:
79
+ return True
80
+ return False
81
+
82
def __blockType(self, b):
    """Classify the text of a cell box into a short type code.

    The stripped text is tried against an ordered list of regular
    expressions; the first hit decides the code (``"Dt"`` for the
    year/month/day/quarter shapes, ``"Nu"`` for digits with numeric
    punctuation, then ``"Ca"``, ``"En"``, ``"NE"``, and ``"Sg"`` for a
    single character).  Without a hit the text is tokenised with
    ``huqie.qie`` and tokens longer than one character are counted:
    4-11 tokens give ``"Tx"``, 12 or more give ``"Lx"``, exactly one token
    tagged ``"nr"`` by ``huqie.tag`` gives ``"Nr"``, anything else
    ``"Ot"``.

    Args:
        b: dict with a ``"text"`` string entry.

    Returns:
        The type code as a string.
    """
    typed_patterns = [
        ("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
        (r"^(20|19)[0-9]{2}年$", "Dt"),
        (r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
        ("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
        (r"^第*[一二三四1-4]季度$", "Dt"),
        (r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
        (r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
        ("^[0-9.,+%/ -]+$", "Nu"),
        (r"^[0-9A-Z/\._~-]+$", "Ca"),
        (r"^[A-Z]*[a-z' -]+$", "En"),
        (r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
        (r"^.{1}$", "Sg")
    ]
    stripped = b["text"].strip()
    for pattern, code in typed_patterns:
        if re.search(pattern, stripped):
            return code

    # No pattern matched: fall back to counting multi-character tokens.
    tokens = [t for t in huqie.qie(b["text"]).split(" ") if len(t) > 1]
    token_count = len(tokens)
    if token_count > 3:
        return "Tx" if token_count < 12 else "Lx"
    if token_count == 1 and huqie.tag(tokens[0]) == "nr":
        return "Nr"
    return "Ot"
111
+
112
def construct_table(self, boxes, is_english=False, html=False):
    """Arrange recognised text boxes into a grid and render the table.

    Steps, in order:

    1. Caption boxes (see ``is_caption``) are removed from ``boxes`` and
       their text is concatenated into the caption.
    2. Every remaining box gets a ``"btype"`` and the most frequent type
       is remembered.
    3. Boxes are sorted row-wise and given a row number ``"rn"``, then
       sorted column-wise and given a column number ``"cn"``.
    4. A ``rows x cols`` grid of box lists is built; with at least four
       rows (columns), a column (row) holding a single value may be merged
       into a neighbour.
    5. Header rows are detected and the grid is rendered.

    Args:
        boxes: list of box dicts; read keys include ``"text"``,
            ``"layout_type"``, ``"top"``, ``"bottom"``, ``"x0"``, ``"x1"``,
            ``"page_number"`` and, when present, ``"R"``/``"R_top"``/
            ``"R_bott"``, ``"C"``/``"C_left"``/``"C_right"``, ``"H"``.
            The list is modified in place (captions are popped) and boxes
            gain ``"btype"``, ``"rn"`` and ``"cn"``.
        is_english: selects the connective words used by the plain-text
            rendering.
        html: when true an HTML table is produced instead of text lines.

    Returns:
        ``[]`` when only captions were supplied; a one-element list with
        the HTML string when ``html`` is true; otherwise the list of
        description strings built by ``__desc_table``.
    """
    # 1. Pull captions out of the box list.
    cap = ""
    i = 0
    while i < len(boxes):
        if self.is_caption(boxes[i]):
            cap += boxes[i]["text"]
            boxes.pop(i)
            i -= 1
        i += 1

    if not boxes:
        return []
    # 2. Type every box and find the dominant type.
    for b in boxes:
        b["btype"] = self.__blockType(b)
    max_type = Counter([b["btype"] for b in boxes]).items()
    max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
    logging.debug("MAXTYPE: " + max_type)

    # 3a. Row assignment; half the smallest row height is the sort tolerance.
    rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
    rowh = np.min(rowh) if rowh else 0
    boxes = self.sort_R_firstly(boxes, rowh / 2)
    boxes[0]["rn"] = 0
    rows = [[boxes[0]]]
    btm = boxes[0]["bottom"]
    for b in boxes[1:]:
        b["rn"] = len(rows) - 1
        lst_r = rows[-1]
        if lst_r[-1].get("R", "") != b.get("R", "") \
                or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
                    ):  # new row
            btm = b["bottom"]
            b["rn"] += 1
            rows.append([b])
            continue
        # Same row: keep a running average of the row's bottom edge.
        btm = (btm + b["bottom"]) / 2.
        rows[-1].append(b)

    # 3b. Column assignment; half the smallest column width is the tolerance.
    colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
    colwm = np.min(colwm) if colwm else 0
    crosspage = len(set([b["page_number"] for b in boxes])) > 1
    if crosspage:
        # Boxes from several pages: order by x position only, without
        # copying, so the "rn" written above stays on the same dicts.
        boxes = self.sort_X_firstly(boxes, colwm / 2, False)
    else:
        boxes = self.sort_C_firstly(boxes, colwm / 2)
    boxes[0]["cn"] = 0
    cols = [[boxes[0]]]
    right = boxes[0]["x1"]
    for b in boxes[1:]:
        b["cn"] = len(cols) - 1
        lst_c = cols[-1]
        if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][
            "page_number"]) \
                or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")):  # new col
            right = b["x1"]
            b["cn"] += 1
            cols.append([b])
            continue
        # Same column: keep a running average of the column's right edge.
        right = (right + b["x1"]) / 2.
        cols[-1].append(b)

    # 4. Build the grid: tbl[row][col] is the list of boxes in that cell.
    tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
    for b in boxes:
        tbl[b["rn"]][b["cn"]].append(b)

    if len(rows) >= 4:
        # remove single in column
        j = 0
        while j < len(tbl[0]):
            # e counts occupied cells in column j (capped at 2);
            # ii remembers the row of the last occupied cell seen.
            e, ii = 0, 0
            for i in range(len(tbl)):
                if tbl[i][j]:
                    e += 1
                    ii = i
                if e > 1:
                    break
            if e > 1:
                j += 1
                continue
            # f / ff: the left / right neighbour cell in row ii already
            # holds text, or there is no such neighbour.
            f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
                 [j - 1][0].get("text")) or j == 0
            ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
                  [j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
            if f and ff:
                j += 1
                continue
            bx = tbl[ii][j][0]
            logging.debug("Relocate column single: " + bx["text"])
            # j column only has one value
            # Smallest horizontal gap to the boxes of each neighbour column.
            left, right = 100000, 100000
            if j > 0 and not f:
                for i in range(len(tbl)):
                    if tbl[i][j - 1]:
                        left = min(left, np.min(
                            [bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
            if j + 1 < len(tbl[0]) and not ff:
                for i in range(len(tbl)):
                    if tbl[i][j + 1]:
                        right = min(right, np.min(
                            [a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
            assert left < 100000 or right < 100000
            if left < right:
                # Merge into the left neighbour and shift column numbers.
                for jj in range(j, len(tbl[0])):
                    for i in range(len(tbl)):
                        for a in tbl[i][jj]:
                            a["cn"] -= 1
                if tbl[ii][j - 1]:
                    tbl[ii][j - 1].extend(tbl[ii][j])
                else:
                    tbl[ii][j - 1] = tbl[ii][j]
                for i in range(len(tbl)):
                    tbl[i].pop(j)

            else:
                # Merge into the right neighbour and shift column numbers.
                for jj in range(j + 1, len(tbl[0])):
                    for i in range(len(tbl)):
                        for a in tbl[i][jj]:
                            a["cn"] -= 1
                if tbl[ii][j + 1]:
                    tbl[ii][j + 1].extend(tbl[ii][j])
                else:
                    tbl[ii][j + 1] = tbl[ii][j]
                for i in range(len(tbl)):
                    tbl[i].pop(j)
            cols.pop(j)
    assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (
        len(cols), len(tbl[0]))

    if len(cols) >= 4:
        # remove single in row
        i = 0
        while i < len(tbl):
            # e counts occupied cells in row i (capped at 2);
            # jj remembers the column of the last occupied cell seen.
            e, jj = 0, 0
            for j in range(len(tbl[i])):
                if tbl[i][j]:
                    e += 1
                    jj = j
                if e > 1:
                    break
            if e > 1:
                i += 1
                continue
            # f / ff: the cell above / below already holds text, or there
            # is no such neighbour.
            f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
                 [jj][0].get("text")) or i == 0
            ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
                  [jj][0].get("text")) or i + 1 >= len(tbl)
            if f and ff:
                i += 1
                continue

            bx = tbl[i][jj][0]
            logging.debug("Relocate row single: " + bx["text"])
            # i row only has one value
            # Smallest vertical gap to the boxes of each neighbour row.
            up, down = 100000, 100000
            if i > 0 and not f:
                for j in range(len(tbl[i - 1])):
                    if tbl[i - 1][j]:
                        up = min(up, np.min(
                            [bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
            if i + 1 < len(tbl) and not ff:
                for j in range(len(tbl[i + 1])):
                    if tbl[i + 1][j]:
                        down = min(down, np.min(
                            [a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
            assert up < 100000 or down < 100000
            if up < down:
                # Merge into the row above and shift row numbers.
                for ii in range(i, len(tbl)):
                    for j in range(len(tbl[ii])):
                        for a in tbl[ii][j]:
                            a["rn"] -= 1
                if tbl[i - 1][jj]:
                    tbl[i - 1][jj].extend(tbl[i][jj])
                else:
                    tbl[i - 1][jj] = tbl[i][jj]
                tbl.pop(i)

            else:
                # Merge into the row below and shift row numbers.
                for ii in range(i + 1, len(tbl)):
                    for j in range(len(tbl[ii])):
                        for a in tbl[ii][j]:
                            a["rn"] -= 1
                if tbl[i + 1][jj]:
                    tbl[i + 1][jj].extend(tbl[i][jj])
                else:
                    tbl[i + 1][jj] = tbl[i][jj]
                tbl.pop(i)
            rows.pop(i)

    # which rows are headers
    hdset = set([])
    for i in range(len(tbl)):
        # cnt: occupied cells in the row; h: cells that look like headers.
        cnt, h = 0, 0
        for j, arr in enumerate(tbl[i]):
            if not arr:
                continue
            cnt += 1
            if max_type == "Nu" and arr[0]["btype"] == "Nu":
                continue
            if any([a.get("H") for a in arr]) \
                    or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
                h += 1
        # NOTE(review): a row with no occupied cell would make cnt == 0 and
        # raise ZeroDivisionError here -- confirm such rows cannot occur.
        if h / cnt > 0.5:
            hdset.add(i)

    # 5. Render.
    if html:
        return [self.__html_table(cap, hdset,
                                  self.__cal_spans(boxes, rows,
                                                   cols, tbl, True)
                                  )]

    return self.__desc_table(cap, hdset,
                             self.__cal_spans(boxes, rows, cols, tbl, False),
                             is_english)
324
+
325
def __html_table(self, cap, hdset, tbl):
    """Render the cell grid as an HTML ``<table>`` string.

    Args:
        cap: caption text; emitted as ``<caption>`` when non-empty.
        hdset: set of header row indices.  It is also used as a memory of
            header texts already emitted (the texts are added to it), so a
            header row whose texts have all been seen before is skipped.
        tbl: grid of cells; a cell is ``None`` (covered by a span, emits
            nothing), an empty list (emits an empty cell) or a list of box
            dicts with ``"text"``, ``"top"``, ``"bottom"``, ``"x0"`` and
            optional integer ``"colspan"`` / ``"rowspan"``.

    Returns:
        The HTML markup, one table row per line.
    """
    head = "<table>"
    if cap:
        head += f"<caption>{cap}</caption>"
    lines = [head]

    for row_idx, cells in enumerate(tbl):
        is_header = row_idx in hdset
        tag = "th" if is_header else "td"
        markup = "<tr>"
        texts = []
        for cell in cells:
            if cell is None:
                continue
            if not cell:
                markup += f"<{tag}></{tag}>"
                continue
            # Read the boxes of a cell line by line; the line tolerance is
            # half the smallest box height, capped at 10.
            tolerance = min(
                np.min([c["bottom"] - c["top"] for c in cell]) / 2, 10)
            text = "".join(
                [c["text"] for c in self.sort_Y_firstly(cell, tolerance)])
            texts.append(text)

            first = cell[0]
            span = ""
            if first.get("colspan"):
                span = "colspan={}".format(first["colspan"])
            if first.get("rowspan"):
                span += " rowspan={}".format(first["rowspan"])
            markup += f"<{tag} {span} >" + text + f"</{tag}>"

        if is_header:
            # Drop a header row that only repeats texts already emitted.
            if all([t in hdset for t in texts]):
                continue
            hdset.update(texts)

        lines.append(markup + "</tr>" if markup != "<tr>" else "")

    lines.append("</table>")
    return "\n".join(lines)
368
+
369
def __desc_table(self, cap, hdr_rowno, tbl, is_english):
    """Render the cell grid as plain-text lines, one per data row.

    Header rows supply a label per column; each data row becomes
    ``"header:value; header:value; ..."`` using the closest header row
    above it.  Tables with at most two columns and no applicable header
    are rendered as ``"left:right"`` and short consecutive lines are packed
    together.

    Args:
        cap: caption text; when non-empty a source suffix naming the
            caption is appended to every line.
        hdr_rowno: set of header row indices; rows whose cells are all
            empty are removed from it in place.
        tbl: grid of cells, each a (possibly empty) list of box dicts with
            a ``"text"`` entry.
        is_english: selects the connective words (``" for "`` / ``" in "``
            versus their Chinese counterparts).

    Returns:
        List of description strings.
    """
    # get text of every column in header row to become header text
    clmno = len(tbl[0])
    rowno = len(tbl)
    headers = {}
    hdrset = set()
    lst_hdr = []
    de = "的" if not is_english else " for "
    for r in sorted(list(hdr_rowno)):
        headers[r] = ["" for _ in range(clmno)]
        for i in range(clmno):
            if not tbl[r][i]:
                continue
            txt = "".join([a["text"].strip() for a in tbl[r][i]])
            headers[r][i] = txt
            hdrset.add(txt)
        if all([not t for t in headers[r]]):
            # A header row without any text is not a header after all.
            del headers[r]
            hdr_rowno.remove(r)
            continue
        # Fill empty header cells from the previous header row.
        for j in range(clmno):
            if headers[r][j]:
                continue
            if j >= len(lst_hdr):
                break
            headers[r][j] = lst_hdr[j]
        lst_hdr = headers[r]
    # Combine runs of consecutive header rows: each lower row absorbs the
    # text of the row directly above it, column by column.
    for i in range(rowno):
        if i not in hdr_rowno:
            continue
        for j in range(i + 1, rowno):
            if j not in hdr_rowno:
                break
            for k in range(clmno):
                if not headers[j - 1][k]:
                    continue
                if headers[j][k].find(headers[j - 1][k]) >= 0:
                    continue
                if len(headers[j][k]) > len(headers[j - 1][k]):
                    headers[j][k] += (de if headers[j][k]
                                      else "") + headers[j - 1][k]
                else:
                    headers[j][k] = headers[j - 1][k] \
                        + (de if headers[j - 1][k] else "") \
                        + headers[j][k]

    logging.debug(
        f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
    row_txt = []
    for i in range(rowno):
        if i in hdr_rowno:
            continue
        rtxt = []

        def append(delimer):
            # Join the row's parts; pack the result onto the previous output
            # line while the combined length stays below 64 characters.
            nonlocal rtxt, row_txt
            rtxt = delimer.join(rtxt)
            if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
                row_txt[-1] += "\n" + rtxt
            else:
                row_txt.append(rtxt)

        # r: index of the closest header row above row i (0 when none).
        r = 0
        if len(headers.items()):
            _arr = [(i - r, r) for r, _ in headers.items() if r < i]
            if _arr:
                _, r = min(_arr, key=lambda x: x[0])

        if r not in headers and clmno <= 2:
            # Narrow table without headers: emit the cell texts joined by ":".
            for j in range(clmno):
                if not tbl[i][j]:
                    continue
                txt = "".join([a["text"].strip() for a in tbl[i][j]])
                if txt:
                    rtxt.append(txt)
            if rtxt:
                append(":")
            continue

        for j in range(clmno):
            if not tbl[i][j]:
                continue
            txt = "".join([a["text"].strip() for a in tbl[i][j]])
            if not txt:
                continue
            # Prefix the value with its column header when there is one.
            ctt = headers[r][j] if r in headers else ""
            if ctt:
                ctt += ":"
            ctt += txt
            if ctt:
                rtxt.append(ctt)

        if rtxt:
            row_txt.append("; ".join(rtxt))

    if cap:
        if is_english:
            from_ = " in "
        else:
            from_ = "来自"
        row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
    return row_txt
471
+
472
def __cal_spans(self, boxes, rows, cols, tbl, html=True):
    """Work out row/column spans and merge the cells a spanning box covers.

    Pass 1: every box that carries ``"SP"`` gets list-valued ``"colspan"``
    and ``"rowspan"`` entries holding the column/row indices whose midpoint
    lies inside the box's ``H_left..H_right`` / ``H_top..H_bott`` extent
    (its own ``"cn"``/``"rn"`` always included).

    Pass 2: for every cell containing such a box the index lists are turned
    into contiguous ranges, the boxes of all covered cells are gathered
    into the top-left cell of the range, and ``"rowspan"``/``"colspan"``
    are rewritten as integer counts (or removed when the count is 1).

    Args:
        boxes: all box dicts of the table; modified in place.
        rows: list of rows, each a list of box dicts.
        cols: list of columns, each a list of box dicts.
        tbl: grid of cells (lists of box dicts); modified in place.
        html: when true covered cells are set to ``None`` (only the
            top-left cell keeps content); when false every covered cell
            references the merged list.

    Returns:
        The same ``tbl`` object.

    Raises:
        AssertionError: if a cell is not inside its own computed span.
    """
    # Mean edges of every column and row; region edges are preferred over
    # the cell's own edges when the box carries them.
    clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
            for cln in cols]
    crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
            for cln in cols]
    rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
            for row in rows]
    # Fixed: the row-bottom key is "R_bott" (as read in construct_table),
    # not "R_btm"; the old key never matched, so the cell bottom was
    # always used while the top came from the row region.
    rbtm = [np.mean([c.get("R_bott", c["bottom"])
                     for c in row]) for row in rows]
    for b in boxes:
        if "SP" not in b:
            continue
        b["colspan"] = [b["cn"]]
        b["rowspan"] = [b["rn"]]
        # col span: columns whose horizontal midpoint lies inside the box
        for j in range(0, len(clft)):
            if j == b["cn"]:
                continue
            if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
                continue
            if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
                continue
            b["colspan"].append(j)
        # row span: rows whose vertical midpoint lies inside the box
        for j in range(0, len(rtop)):
            if j == b["rn"]:
                continue
            if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
                continue
            if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
                continue
            b["rowspan"].append(j)

    def join(arr):
        # Concatenated text of a cell; "" for an empty or None cell.
        if not arr:
            return ""
        return "".join([t["text"] for t in arr])

    # remove the spanned cells
    for i in range(len(tbl)):
        for j, arr in enumerate(tbl[i]):
            if not arr:
                continue
            if all(["rowspan" not in a and "colspan" not in a for a in arr]):
                continue
            # Collect the index lists of every box in the cell; integer
            # values (already finalised) are ignored.
            rowspan, colspan = [], []
            for a in arr:
                if isinstance(a.get("rowspan", 0), list):
                    rowspan.extend(a["rowspan"])
                if isinstance(a.get("colspan", 0), list):
                    colspan.extend(a["colspan"])
            rowspan, colspan = set(rowspan), set(colspan)
            if len(rowspan) < 2 and len(colspan) < 2:
                # Nothing is actually spanned: drop the markers.
                for a in arr:
                    if "rowspan" in a:
                        del a["rowspan"]
                    if "colspan" in a:
                        del a["colspan"]
                continue
            # Close gaps so that the span is one contiguous rectangle.
            rowspan, colspan = sorted(rowspan), sorted(colspan)
            rowspan = list(range(rowspan[0], rowspan[-1] + 1))
            colspan = list(range(colspan[0], colspan[-1] + 1))
            assert i in rowspan, rowspan
            assert j in colspan, colspan
            arr = []
            for r in rowspan:
                for c in colspan:
                    arr_txt = join(arr)
                    # Skip a cell whose text equals what was gathered so far.
                    if tbl[r][c] and join(tbl[r][c]) != arr_txt:
                        arr.extend(tbl[r][c])
                    tbl[r][c] = None if html else arr
            for a in arr:
                if len(rowspan) > 1:
                    a["rowspan"] = len(rowspan)
                elif "rowspan" in a:
                    del a["rowspan"]
                if len(colspan) > 1:
                    a["colspan"] = len(colspan)
                elif "colspan" in a:
                    del a["colspan"]
            tbl[rowspan[0]][colspan[0]] = arr

    return tbl
556
+
deepdoc/visual/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .ocr import OCR
2
- from .recognizer import Recognizer
 
 
 
rag/svr/task_broker.py CHANGED
@@ -21,7 +21,7 @@ from datetime import datetime
21
  from api.db.db_models import Task
22
  from api.db.db_utils import bulk_insert_into_db
23
  from api.db.services.task_service import TaskService
24
- from deepdoc.parser import HuParser
25
  from rag.settings import cron_logger
26
  from rag.utils import MINIO
27
  from rag.utils import findMaxTm
@@ -80,7 +80,7 @@ def dispatch():
80
 
81
  tsks = []
82
  if r["type"] == FileType.PDF.value:
83
- pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
84
  for s,e in r["parser_config"].get("pages", [(0,100000)]):
85
  e = min(e, pages)
86
  for p in range(s, e, 10):
 
21
  from api.db.db_models import Task
22
  from api.db.db_utils import bulk_insert_into_db
23
  from api.db.services.task_service import TaskService
24
+ from deepdoc.parser import PdfParser
25
  from rag.settings import cron_logger
26
  from rag.utils import MINIO
27
  from rag.utils import findMaxTm
 
80
 
81
  tsks = []
82
  if r["type"] == FileType.PDF.value:
83
+ pages = PdfParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
84
  for s,e in r["parser_config"].get("pages", [(0,100000)]):
85
  e = min(e, pages)
86
  for p in range(s, e, 10):