File size: 6,306 Bytes
f11a85d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
from src.RAG import RAG
from src.model import Pixtral
from src.prompts import GENERATE_INVOICE_PROMPT, GENERATE_BRIEF_DAMAGE_DESCRIPTION_PROMPT, \
GENERATE_DETAILED_DAMAGE_DESCRIPTION_PROMPT
from md2pdf.core import md2pdf
from rerankers import Reranker
import re
from fuzzywuzzy import fuzz
# One numeric token: optional minus sign, digits with optional thousands
# separators, optional decimal part (e.g. "150", "-20", "1,200.50").
_NUMBER_RE = re.compile(r"-?\d+(?:,\d{3})*(?:\.\d+)?")


def _parse_number(text):
    """Return the first number found in *text* as a float.

    Tolerates decoration around table values such as currency symbols,
    ``**bold**`` markers, thousands separators or units ("$1,200", "1.5 h").

    Raises:
        ValueError: if *text* contains no number.
    """
    match = _NUMBER_RE.search(text)
    if match is None:
        raise ValueError(f"no number found in invoice cell {text!r}")
    return float(match.group().replace(",", ""))


def _format_number(value):
    """Render *value* rounded to 2 decimals, without a decimal part when whole."""
    rounded = round(value, 2)
    if rounded.is_integer():
        return str(int(rounded))
    return f"{rounded:.2f}"


def _cost_row_bounds(all_lines):
    """Return ``(start, stop)`` slice bounds of the per-part rows of an invoice table.

    The table is expected to be: optional leading blank line, header row,
    separator row, one row per part, a total row, optional trailing blank
    line. ``stop`` is negative, so ``all_lines[stop]`` is the total row.
    """
    first_cost_line = 3 if all_lines[0] == '' else 2
    last_cost_line = -2 if all_lines[-1] == '' else -1
    return first_cost_line, last_cost_line


class InvoiceGenerator:
    """Draft a repair invoice from a damage image.

    Combines a Pixtral vision-language model (damage descriptions and invoice
    drafting) with a RAG index of past invoices, then post-processes the
    markdown cost table the model produces.
    """

    def __init__(
        self,
        fais_index_path,
        image_invoice_index_path,
        path_to_invoices,
        path_to_images,
        reranker_model=None,
        device="cuda",
        max_model_len=4096,
        max_tokens=2048,
        gpu_memory_utilization=0.95,
    ):
        """Load the model, the optional reranker and the retrieval index.

        Args:
            fais_index_path: Index path forwarded to ``RAG`` (presumably a
                FAISS index — confirm).
            image_invoice_index_path: Image/invoice index path forwarded to ``RAG``.
            path_to_invoices: Directory of past invoices, forwarded to ``RAG``.
            path_to_images: Directory of past images, forwarded to ``RAG``.
            reranker_model: ``rerankers`` model name; when falsy no reranker
                is loaded and ``None`` is passed to ``RAG``.
            device: Device for the reranker.
            max_model_len, max_tokens, gpu_memory_utilization: Forwarded to
                ``Pixtral``.
        """
        self.model = Pixtral(
            max_model_len=max_model_len,
            max_tokens=max_tokens,
            gpu_memory_utilization=gpu_memory_utilization,
        )
        # Assigned unconditionally so ``reranker=self.reranker`` below does not
        # raise AttributeError when no reranker model is requested (the default).
        # NOTE(review): assumes RAG treats reranker=None as "no reranking" — confirm.
        self.reranker = (
            Reranker(model_name=reranker_model, device=device) if reranker_model else None
        )
        self.device = device
        self.rag = RAG(
            fais_index_path=fais_index_path,
            image_invoice_index_path=image_invoice_index_path,
            path_to_invoices=path_to_invoices,
            path_to_images=path_to_images,
            reranker=self.reranker,
        )
        self.path_to_invoices = path_to_invoices
        self.path_to_images = path_to_images

    def format_invoice(self, generated_invoice, output_path, template_path="data/template.md"):
        """Insert *generated_invoice* into a markdown template and write it as a PDF.

        The literal ``<<table>>`` placeholder in the template is replaced with
        the invoice markdown; ``md2pdf`` then writes the PDF to *output_path*.
        """
        # Explicit encoding: the platform default (e.g. cp1252) can mangle UTF-8 templates.
        with open(template_path, "r", encoding="utf-8") as f:
            md_text = f.read()
        md_text = md_text.replace("<<table>>", generated_invoice)
        md2pdf(output_path, md_content=md_text)

    @staticmethod
    def check_within_range(generated_invoice, car_parts):
        """Compare each invoice row against reference cost and hour ranges.

        Args:
            generated_invoice: Markdown table whose per-part rows are
                ``part | cost | hours | ...`` without outer pipes.
            car_parts: Mapping of reference part name to a dict with keys
                ``cost_range`` and ``hours_range`` (``[low, high]``),
                ``average_cost`` and ``average_hours``.

        Returns:
            Dict keyed by each row's part text. A row whose part fuzzy-matches
            a reference part (``fuzz.WRatio`` >= 90, best score wins, ties keep
            the first) maps to ``cost_within_range``, ``hours_within_range``,
            ``cost_diff``, ``hours_diff`` (value minus reference average) and
            ``part_info`` (the matched reference name); other rows map to ``{}``.

        Raises:
            ValueError: if a matched row's cost or hours cell has no number.
        """
        def best_match(part_name):
            part_name = part_name.lower()
            best_part, best_ratio = None, 0
            for candidate in car_parts:
                ratio = fuzz.WRatio(part_name, candidate.lower())
                if ratio >= 90 and ratio > best_ratio:
                    best_part, best_ratio = candidate, ratio
            return best_part

        all_lines = generated_invoice.split("\n")
        first_cost_line, last_cost_line = _cost_row_bounds(all_lines)
        comparing_results = {}
        for row in all_lines[first_cost_line:last_cost_line]:
            cells = [cell.strip() for cell in row.split("|")]
            part = cells[0]
            found_part = best_match(part)
            if not found_part:
                comparing_results[part] = {}
                continue
            reference = car_parts[found_part]
            # Same parser as check_calculations, so cells it accepted parse here too.
            cost = _parse_number(cells[1])
            hours = _parse_number(cells[2])
            comparing_results[part] = {
                "cost_within_range": reference["cost_range"][0] <= cost <= reference["cost_range"][1],
                "hours_within_range": reference["hours_range"][0] <= hours <= reference["hours_range"][1],
                "cost_diff": cost - reference["average_cost"],
                "hours_diff": hours - reference["average_hours"],
                "part_info": found_part,
            }
        return comparing_results

    @staticmethod
    def check_calculations(generated_invoice):
        """Recompute the row totals and grand total of a generated invoice table.

        The model's own arithmetic is overwritten. Each per-part row is read as
        ``part | cost | hours | rate`` (no outer pipes) and rewritten as
        ``part | cost | hours | rate | cost + hours * rate``; cells after the
        fourth are dropped. Only the last number on the total row is replaced
        with the sum of the row totals, so other numbers on that row (counts,
        labels) survive. Numeric cells may carry decimals, thousands separators
        or decoration such as currency symbols.

        Returns:
            The corrected invoice text. Input too short to contain a total row
            (e.g. an empty string) is returned unchanged.

        Raises:
            ValueError: if a cost, hours or rate cell contains no number.
            IndexError: if a per-part row has fewer than four cells.
        """
        all_lines = generated_invoice.split("\n")
        first_cost_line, last_cost_line = _cost_row_bounds(all_lines)
        if len(all_lines) < -last_cost_line:
            return generated_invoice

        cost_rows = [
            [cell.strip() for cell in row.split("|")]
            for row in all_lines[first_cost_line:last_cost_line]
        ]
        row_totals = [
            _parse_number(cells[1]) + _parse_number(cells[2]) * _parse_number(cells[3])
            for cells in cost_rows
        ]

        total_line = all_lines[last_cost_line]
        numbers = list(_NUMBER_RE.finditer(total_line))
        if numbers:
            last = numbers[-1]
            all_lines[last_cost_line] = (
                total_line[:last.start()] + _format_number(sum(row_totals)) + total_line[last.end():]
            )

        # Same number of rows in and out, so the negative total-row index stays valid.
        all_lines[first_cost_line:last_cost_line] = [
            " | ".join(cells[:4] + [_format_number(total)])
            for cells, total in zip(cost_rows, row_totals)
        ]
        return "\n".join(all_lines)

    def generate_invoice(self, image_path, output_path=None, template_path="data/template.md", car_parts=None):
        """Run the image-to-invoice pipeline.

        Steps:
            1. Write a brief damage description, and return early for irrelevant images.
            2. Retrieve a similar past invoice.
            3. Write a detailed damage description.
            4. Draft the invoice and recompute its arithmetic.
            5. Optionally compare it with *car_parts* ranges.
            6. Optionally render it as a PDF at *output_path*.

        Returns:
            ``None`` when the brief description is ``"Irrelevant."`` or has
            fewer than five words. Otherwise a dict with the keys
            ``damage_description``, ``invoice_info``, ``invoice_path``,
            ``similar_image``, ``detailed_damage_description`` and
            ``generated_invoice``, plus ``comparing_results`` when
            *car_parts* is given.
        """
        result = {}
        damage_description = self.model.generate_message_from_image(
            GENERATE_BRIEF_DAMAGE_DESCRIPTION_PROMPT, image_path
        )
        if damage_description == "Irrelevant." or len(damage_description.split()) < 5:
            return None
        result["damage_description"] = damage_description
        print("Damage Description:", damage_description)

        invoice_info, invoice_path = self.rag.find_invoice(
            image_path=image_path, return_only_path=False, damage_description=damage_description, k=5
        )
        # Only the first of the k candidates is used (presumably the best match —
        # confirm RAG ordering). An empty result raises IndexError here.
        invoice_info = invoice_info[0]
        invoice_path = invoice_path[0]
        result["invoice_info"] = invoice_info
        result["invoice_path"] = invoice_path
        result["similar_image"] = invoice_path.replace(".pdf", ".png")
        print("Invoice Path:", invoice_path)

        detailed_damage_description = self.model.generate_message_from_image(
            GENERATE_DETAILED_DAMAGE_DESCRIPTION_PROMPT, image_path
        )
        result["detailed_damage_description"] = detailed_damage_description
        print("Detailed Damage Description:", detailed_damage_description)

        # Strip markdown code fences the model may wrap around the table.
        generated_invoice = self.model.generate_message_from_image(
            GENERATE_INVOICE_PROMPT(invoice_info, detailed_damage_description), image_path
        ).replace("```markdown", "").replace("```", "")
        generated_invoice = self.check_calculations(generated_invoice)
        result["generated_invoice"] = generated_invoice

        if car_parts:
            comparing_results = self.check_within_range(generated_invoice, car_parts)
            result["comparing_results"] = comparing_results
            print(comparing_results)
        if output_path:
            self.format_invoice(generated_invoice=generated_invoice, output_path=output_path,
                                template_path=template_path)
        return result
|