File size: 5,237 Bytes
a228fac
 
 
 
0dd8e27
a228fac
 
12af45e
a228fac
 
 
 
0d655c9
a228fac
 
 
 
 
 
 
 
 
0dd8e27
a228fac
 
 
 
 
 
 
 
7b11b8d
 
 
 
a228fac
 
0d655c9
12af45e
 
 
 
 
 
 
 
a228fac
 
9d6cf42
 
12af45e
 
 
 
 
 
 
 
 
 
 
 
 
 
0d655c9
 
12af45e
 
 
 
 
 
 
 
 
 
 
 
a228fac
 
12af45e
a228fac
0dd8e27
a228fac
 
 
0dd8e27
a228fac
0dd8e27
a228fac
0dd8e27
 
 
 
 
 
 
 
 
 
 
a228fac
0dd8e27
a228fac
 
 
 
0dd8e27
 
 
 
 
 
 
a228fac
0dd8e27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from config import Settings
from preprocess import Preprocessor
import ocr
from PIL import Image
from transformers import LiltForTokenClassification, AutoTokenizer
import token_classification
import torch
from fastapi import FastAPI, UploadFile, Form, HTTPException
from contextlib import asynccontextmanager
import json
import io
from models import LiLTRobertaLikeForRelationExtraction
from base64 import b64decode 
# Module-level registry of shared resources (settings, torch device, OCR
# client, preprocessor, tokenizer and the two models). It is filled in by
# `lifespan` at startup and emptied again at shutdown.
config = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the shared resources at startup and release them at shutdown.

    Fills the module-level ``config`` dict with the settings object, the
    torch device, the OCR client, the preprocessor, the tokenizer and the two
    pretrained models, then yields while the application is serving.

    Args:
        app: The FastAPI application this lifespan is attached to (not used
            inside the body).
    """
    try:
        settings = Settings()
        config['settings'] = settings
        # Prefer the GPU when one is available, otherwise fall back to CPU.
        config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        config['vision_client'] = ocr.VisionClient(settings.GCV_AUTH)
        config['processor'] = Preprocessor(settings.TOKENIZER)
        config['tokenizer'] = AutoTokenizer.from_pretrained(settings.TOKENIZER)
        config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)
        config['re_model'] = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
        yield
    finally:
        # Clean up and release the resources. The `finally` makes sure this
        # also happens when startup fails partway through or an exception is
        # thrown into the generator at the yield point.
        config.clear()

# Application instance; `lifespan` is registered so the shared resources are
# loaded before requests are served and released on shutdown.
app = FastAPI(lifespan=lifespan)

@app.get("/")
def api_home():
    """Return the static welcome payload for the root route."""
    greeting = 'Welcome to Sri-Doc space'
    return dict(detail=greeting)

@app.post("/submit-doc")
async def ProcessDocument(file: UploadFile):
    """Run OCR, token labelling and relation extraction on an uploaded image.

    Args:
        file: The uploaded document image.

    Returns:
        The dict produced by ``ExtractRelations``.

    Raises:
        HTTPException: 400 when ``ApplyOCR`` rejects the content, when
            ``len(ocr_df)`` is below 2, or when labelling / relation
            extraction fails.
    """
    content = await file.read()
    ocr_df, image = ApplyOCR(content)
    if len(ocr_df) < 2:
        raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
    try:
        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    except Exception as err:
        # Catch Exception rather than using a bare except so that task
        # cancellation and interpreter-exit signals are not turned into a 400;
        # chaining keeps the original failure available as __cause__.
        raise HTTPException(status_code=400, detail="Invalid Image") from err
    return reOutput

@app.post("/submit-doc-base64")
async def ProcessDocument(file: str = Form(...)):
    """Run the document pipeline on an image sent as a base64 form field.

    The field must have the form ``<header>,<base64 payload>`` with exactly
    one comma; only the part after the comma is decoded.

    Args:
        file: Form field holding the header and base64 payload.

    Returns:
        The dict produced by ``ExtractRelations``.

    Raises:
        HTTPException: 400 when the field cannot be split or decoded, when
            ``ApplyOCR`` rejects the content, when ``len(ocr_df)`` is below
            2, or when labelling / relation extraction fails.
    """
    # NOTE(review): this handler has the same function name as the
    # /submit-doc handler, so the module-level name refers to this one.
    try:
        _header, payload = file.split(',')
        content = b64decode(str.encode(payload))
    except ValueError as err:
        # ValueError covers both the unpack mismatch (zero or several commas)
        # and binascii.Error raised for malformed base64.
        raise HTTPException(status_code=400, detail="Invalid image") from err
    ocr_df, image = ApplyOCR(content)
    if len(ocr_df) < 2:
        raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
    try:
        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    except Exception as err:
        # Catch Exception rather than using a bare except so that task
        # cancellation and interpreter-exit signals are not turned into a 400.
        raise HTTPException(status_code=400, detail="Invalid Image") from err
    return reOutput

def ApplyOCR(content):
    """Open the image bytes and run the configured OCR client on them.

    Args:
        content: Raw bytes of the image.

    Returns:
        A tuple ``(ocr_df, image)``: the result of the vision client's
        ``ocr`` call and the opened PIL image.

    Raises:
        HTTPException: 400 "Invalid image" when the bytes cannot be opened
            as an image; 400 "OCR process failed" when the OCR call raises.
    """
    try:
        image = Image.open(io.BytesIO(content))
    except Exception as err:
        # Catch Exception, not a bare except, so interpreter-level signals
        # still propagate; chaining keeps the original error as __cause__.
        raise HTTPException(status_code=400, detail="Invalid image") from err
    try:
        ocr_df = config['vision_client'].ocr(content, image)
    except Exception as err:
        raise HTTPException(status_code=400, detail="OCR process failed") from err
    return ocr_df, image

def LabelTokens(ocr_df, image):
    """Preprocess the OCR output and classify its tokens.

    Args:
        ocr_df: OCR result as returned by ``ApplyOCR``.
        image: The opened image the OCR result belongs to.

    Returns:
        A tuple ``(output, size)`` where ``output`` is a dict with the keys
        ``token_labels``, ``input_ids``, ``bbox`` and ``attention_mask``, and
        ``size`` is ``image.size``.
    """
    processed = config['processor'].process(ocr_df, image = image)
    # The preprocessor yields six values; the token type ids and the actual
    # token boxes are not needed for classification here.
    (
        input_ids,
        attention_mask,
        _token_type_ids,
        bbox,
        _token_actual_boxes,
        offset_mapping,
    ) = processed

    labels = token_classification.classifyTokens(
        config['ser_model'], input_ids, attention_mask, bbox, offset_mapping
    )
    output = {
        "token_labels": labels,
        "input_ids": input_ids,
        "bbox": bbox,
        "attention_mask": attention_mask,
    }
    return output, image.size

def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
    """Merge labelled tokens into entities and predict relations between them.

    Args:
        tokenClassificationOutput: Dict as produced by ``LabelTokens`` (keys
            ``token_labels``, ``input_ids``, ``bbox``, ``attention_mask``).
        ocr_df: OCR result passed through to ``createEntities``.
        img_size: Image size passed through to ``createEntities``.

    Returns:
        A dict of JSON-encoded strings with the keys ``pred_relations``,
        ``entities``, ``input_ids``, ``bboxes``, ``token_labels``,
        ``decoded_entities`` and ``decoded_pred_relations``.
    """
    token_labels = tokenClassificationOutput['token_labels']
    ser_input_ids = tokenClassificationOutput['input_ids']
    # Looked up as in the original flow, but not used below.
    _ser_attention_mask = tokenClassificationOutput["attention_mask"]
    bbox_org = tokenClassificationOutput["bbox"]

    merged_output, _merged_words = token_classification.createEntities(
        config['ser_model'],
        token_labels,
        ser_input_ids,
        ocr_df,
        config['tokenizer'],
        img_size,
        bbox_org,
    )

    device = config['device']
    tokenizer = config['tokenizer']

    def _as_batch(values):
        # Wrap in an outer list (batch of one) and move to the target device.
        return torch.tensor([values]).to(device)

    def _decode_span(span):
        # Decode the tokens between a (start, end) pair of the merged ids.
        begin, stop = span
        return tokenizer.decode(input_ids[0][begin:stop])

    entities = merged_output['entities']
    input_ids = _as_batch(merged_output['input_ids'])
    bbox = _as_batch(merged_output['bbox'])
    attention_mask = _as_batch(merged_output['attention_mask'])

    # Maps the string label of an entity to the integer id stored back on it.
    label_to_id = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
    decoded_entities = []
    for ent in entities:
        text = tokenizer.decode(input_ids[0][ent['start']:ent['end']])
        decoded_entities.append((ent['label'], text))
        # Entities are updated in place: the label becomes its integer id.
        ent['label'] = label_to_id[ent['label']]

    re_model = config['re_model']
    re_model.to(device)
    entity_dict = {
        field: [ent[field] for ent in entities]
        for field in ('start', 'end', 'label')
    }
    # No gold relations at inference time; an empty placeholder is passed.
    relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]
    with torch.no_grad():
        outputs = re_model(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            entities=[entity_dict],
            relations=relations,
        )

    predicted = outputs.pred_relations[0]
    decoded_pred_relations = [
        (_decode_span(rel['head']), _decode_span(rel['tail']))
        for rel in predicted
    ]

    return {
        "pred_relations": json.dumps(predicted),
        "entities": json.dumps(entities),
        "input_ids": json.dumps(input_ids.tolist()),
        "bboxes": json.dumps(bbox_org.tolist()),
        "token_labels": json.dumps(token_labels),
        "decoded_entities": json.dumps(decoded_entities),
        "decoded_pred_relations": json.dumps(decoded_pred_relations),
    }