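"""Donut-based cheque parsing pipeline.

Loads the `shivi/donut-cheque-parser` VisionEncoderDecoder model, extracts the
payee name, bank name, cheque date, and the legal (words) and courtesy (figures)
amounts from a cheque image, cross-checks the two amounts, and flags cheques
issued more than three months ago as stale.
"""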
from transformers import DonutProcessor, VisionEncoderDecoderModel
import pkg_resources
from symspellpy import SymSpell
from word2number import w2n
from dateutil import relativedelta
from datetime import datetime
from PIL import Image
import torch
import re

CHEQUE_PARSER_MODEL = "shivi/donut-cheque-parser"
TASK_PROMPT = "<parse-cheque>"
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_donut_model_and_processor():
    donut_processor = DonutProcessor.from_pretrained(CHEQUE_PARSER_MODEL)
    model = VisionEncoderDecoderModel.from_pretrained(CHEQUE_PARSER_MODEL)
    model.to(device)
    return donut_processor, model

def prepare_data_using_processor(donut_processor,image_path):
    ## Pass image through donut processor's feature extractor and retrieve image tensor
    image = load_image(image_path)
    pixel_values = donut_processor(image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    ## Pass task prompt for document (cheque) parsing task to donut processor's tokenizer and retrieve the input_ids
    decoder_input_ids = donut_processor.tokenizer(TASK_PROMPT, add_special_tokens=False, return_tensors="pt")["input_ids"]
    decoder_input_ids = decoder_input_ids.to(device)

    return pixel_values, decoder_input_ids

def load_image(image_path):
    image = Image.open(image_path).convert("RGB")
    return image

def parse_cheque_with_donut(input_image_path):
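    """Run the Donut cheque parser end to end on a single cheque image.

    Returns the payee name, amount in words, amount in figures, bank name,
    cheque date, whether the two amounts match, and whether the cheque is stale.
    """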

    donut_processor, model = load_donut_model_and_processor()

    cheque_image_tensor, input_for_decoder = prepare_data_using_processor(donut_processor,input_image_path)
    
    outputs = model.generate(cheque_image_tensor,
                                decoder_input_ids=input_for_decoder,
                                max_length=model.decoder.config.max_position_embeddings,
                                early_stopping=True,
                                pad_token_id=donut_processor.tokenizer.pad_token_id,
                                eos_token_id=donut_processor.tokenizer.eos_token_id,
                                use_cache=True,
                                num_beams=1,
                                bad_words_ids=[[donut_processor.tokenizer.unk_token_id]],
                                return_dict_in_generate=True,
                                output_scores=True,)

    decoded_output_sequence = donut_processor.batch_decode(outputs.sequences)[0]
    
    extracted_cheque_details = decoded_output_sequence.replace(donut_processor.tokenizer.eos_token, "").replace(donut_processor.tokenizer.pad_token, "")
    ## remove task prompt from token sequence
    cleaned_cheque_details = re.sub(r"<.*?>", "", extracted_cheque_details, count=1).strip()  
    ## generate ordered json sequence from output token sequence
    cheque_details_json = donut_processor.token2json(cleaned_cheque_details)
    print("cheque_details_json:",cheque_details_json['cheque_details'])
    
    ## extract required fields from predicted json

    amt_in_words  = cheque_details_json['cheque_details'][0]['amt_in_words']
    amt_in_figures = cheque_details_json['cheque_details'][1]['amt_in_figures']
    matching_amts = match_legal_and_courtesy_amount(amt_in_words, amt_in_figures)
    
    payee_name = cheque_details_json['cheque_details'][2]['payee_name']

    bank_name = cheque_details_json['cheque_details'][3]['bank_name']
    cheque_date = cheque_details_json['cheque_details'][4]['cheque_date']

    stale_cheque = check_if_cheque_is_stale(cheque_date)

    return payee_name, amt_in_words, amt_in_figures, bank_name, cheque_date, matching_amts, stale_cheque

def spell_check(amt_in_words):
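    """Correct OCR spelling errors in the legal amount using SymSpell compound lookup."""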
    sym_spell = SymSpell(max_dictionary_edit_distance=2,prefix_length=7)
    dictionary_path = pkg_resources.resource_filename("symspellpy", "frequency_dictionary_82_765.txt")
    bigram_path = pkg_resources.resource_filename("symspellpy", "frequency_bigramdictionary_en_243_342.txt")

    sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1)
    sym_spell.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)

    suggestions = sym_spell.lookup_compound(amt_in_words, max_edit_distance=2)

    return suggestions[0].term

def match_legal_and_courtesy_amount(legal_amount, courtesy_amount):
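    """Check that the spell-corrected legal amount equals the courtesy amount."""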
    matching_amts = False
    if len(legal_amount) == 0:
        return matching_amts

    corrected_amt_in_words = spell_check(legal_amount)
    print("corrected_amt_in_words:",corrected_amt_in_words)
    
    numeric_legal_amt = w2n.word_to_num(corrected_amt_in_words)
    print("numeric_legal_amt:",numeric_legal_amt)
    if int(numeric_legal_amt) == int(courtesy_amount):
        matching_amts = True
    return matching_amts

def check_if_cheque_is_stale(cheque_issue_date):
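    """Return True if the cheque was issued more than three months before today (dd/mm/yy dates)."""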
    stale_check = False
    current_date = datetime.now().strftime('%d/%m/%y')
    current_date_ = datetime.strptime(current_date, "%d/%m/%y")
    cheque_issue_date_ = datetime.strptime(cheque_issue_date, "%d/%m/%y")
    relative_diff = relativedelta.relativedelta(current_date_, cheque_issue_date_)
    months_difference = (relative_diff.years * 12) + relative_diff.months
    print("months_difference:",months_difference)
    if months_difference > 3:
        stale_check = True
    return stale_check
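

## Illustrative usage sketch (not part of the parsing pipeline above): the image
## path below is hypothetical; point it at an actual cheque image to run the parser.
if __name__ == "__main__":
    sample_cheque_path = "sample_cheque.jpg"  # hypothetical path, for illustration only
    payee, amt_words, amt_figures, bank, date, amounts_match, is_stale = parse_cheque_with_donut(sample_cheque_path)
    print(f"Payee: {payee} | Bank: {bank} | Date: {date}")
    print(f"Amount (words): {amt_words} | Amount (figures): {amt_figures}")
    print(f"Amounts match: {amounts_match} | Stale cheque: {is_stale}")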