|
from src.RAG import RAG |
|
from src.model import Pixtral |
|
from src.prompts import GENERATE_INVOICE_PROMPT, GENERATE_BRIEF_DAMAGE_DESCRIPTION_PROMPT, \ |
|
GENERATE_DETAILED_DAMAGE_DESCRIPTION_PROMPT |
|
from md2pdf.core import md2pdf |
|
from rerankers import Reranker |
|
import re |
|
from fuzzywuzzy import fuzz |
|
|
|
|
|
class InvoiceGenerator:
    """Generate repair invoices from photos of car damage.

    ``generate_invoice`` drives the pipeline:

    1. A ``Pixtral`` model describes the damage.
    2. A ``RAG`` index retrieves a similar reference invoice.
    3. The model drafts a new invoice table.
    4. The table's arithmetic is recomputed.
    5. Optionally, each row is compared with catalog part ranges and the
       result is rendered to PDF.

    The table parsing below assumes a pipe-separated table without leading or
    trailing pipes (the part name is the first cell)::

        <header row>
        <separator row>
        part | cost | hours | rate | line total
        ...
        <total row>
    """

    # One number such as ``550``, ``1,234`` or ``1234.50``. Used to locate the
    # amount in the total row so that only that amount is rewritten.
    _NUMBER_RE = re.compile(r"\d+(?:,\d{3})*(?:\.\d+)?")

    def __init__(
        self,
        fais_index_path,
        image_invoice_index_path,
        path_to_invoices,
        path_to_images,
        reranker_model=None,
        device="cuda",
        max_model_len=4096, max_tokens=2048, gpu_memory_utilization=0.95,
    ):
        """Load the vision model, the optional reranker and the RAG index.

        Args:
            fais_index_path: Forwarded to ``RAG`` as ``fais_index_path``.
            image_invoice_index_path: Forwarded to ``RAG``.
            path_to_invoices: Forwarded to ``RAG`` and stored on the instance.
            path_to_images: Forwarded to ``RAG`` and stored on the instance.
            reranker_model: Model name for ``rerankers.Reranker``. When falsy,
                no reranker is created and ``None`` is passed to ``RAG``.
            device: Device the reranker is loaded on.
            max_model_len: Forwarded to ``Pixtral``.
            max_tokens: Forwarded to ``Pixtral``.
            gpu_memory_utilization: Forwarded to ``Pixtral``.
        """
        self.model = Pixtral(
            max_model_len=max_model_len,
            max_tokens=max_tokens,
            gpu_memory_utilization=gpu_memory_utilization,
        )
        # Always define the attribute. It is passed to RAG below, and leaving
        # it unset when no reranker model is given raised AttributeError.
        # NOTE(review): assumes RAG treats reranker=None as "no reranking".
        self.reranker = (
            Reranker(model_name=reranker_model, device=device)
            if reranker_model
            else None
        )
        self.device = device
        self.rag = RAG(
            fais_index_path=fais_index_path,
            image_invoice_index_path=image_invoice_index_path,
            path_to_invoices=path_to_invoices,
            path_to_images=path_to_images,
            reranker=self.reranker,
        )
        self.path_to_invoices = path_to_invoices
        self.path_to_images = path_to_images

    def format_invoice(self, generated_invoice, output_path, template_path="data/template.md"):
        """Render *generated_invoice* into the Markdown template and write a PDF.

        The literal placeholder ``<<table>>`` in the template is replaced by
        the invoice text. The resulting Markdown is passed to ``md2pdf``,
        which writes ``output_path``.
        """
        with open(template_path, "r", encoding="utf-8") as f:
            md_text = f.read()
        md_text = md_text.replace("<<table>>", generated_invoice)
        md2pdf(output_path, md_content=md_text)

    @staticmethod
    def _table_bounds(all_lines):
        """Return ``(first_cost_line, total_line)`` indices into *all_lines*.

        Blank or whitespace-only lines before the header and after the total
        row are skipped. The first non-blank line is the header, the next is
        the separator, and the last non-blank line is the total row. Cost
        rows are therefore ``all_lines[first_cost_line:total_line]``.

        Returns ``None`` when there are fewer than three non-blank lines,
        i.e. there is no header, separator and total row to work with.
        """
        non_blank = [i for i, line in enumerate(all_lines) if line.strip()]
        if len(non_blank) < 3:
            return None
        return non_blank[0] + 2, non_blank[-1]

    @staticmethod
    def _parse_number(text):
        """Parse a table cell as ``int`` when possible, otherwise ``float``.

        Raises ``ValueError`` if the cell is not numeric.
        """
        try:
            return int(text)
        except ValueError:
            return float(text)

    @staticmethod
    def _format_number(value):
        """Format an amount: integral values without a decimal part, else 2 dp max."""
        if isinstance(value, float):
            value = round(value, 2)
            if value.is_integer():
                return str(int(value))
        return str(value)

    @staticmethod
    def _replace_last_number(line, replacement):
        """Replace the last number in *line* with *replacement* (unchanged if none)."""
        matches = list(InvoiceGenerator._NUMBER_RE.finditer(line))
        if not matches:
            return line
        last = matches[-1]
        return line[:last.start()] + replacement + line[last.end():]

    @staticmethod
    def _match_part(part_name, car_parts, threshold=90):
        """Return the key of *car_parts* that best fuzzy-matches *part_name*.

        Matching is case-insensitive and uses ``fuzz.WRatio``. Only
        candidates scoring at least *threshold* are considered; the highest
        score wins, and the earliest candidate wins ties. Returns ``None``
        when nothing qualifies.
        """
        part_name = part_name.lower()
        best_part, best_ratio = None, 0
        for candidate in car_parts:
            ratio = fuzz.WRatio(part_name, candidate.lower())
            if ratio >= threshold and ratio > best_ratio:
                best_part, best_ratio = candidate, ratio
        return best_part

    @staticmethod
    def check_within_range(generated_invoice, car_parts, match_threshold=90):
        """Compare each cost row of the invoice against catalog part ranges.

        Args:
            generated_invoice: Invoice table text (see the class docstring).
            car_parts: Mapping of catalog part name to a dict with
                ``"cost_range"`` and ``"hours_range"`` (indexable as
                ``[low, high]``), ``"average_cost"`` and ``"average_hours"``.
            match_threshold: Minimum ``fuzz.WRatio`` score for a row's part
                name to be matched to a catalog entry.

        Returns:
            A dict keyed by the part name as written in the invoice. Each
            value is ``{}`` when no catalog entry matched. Otherwise it holds
            ``cost_within_range``, ``hours_within_range``, ``cost_diff``,
            ``hours_diff`` (differences from the catalog averages) and
            ``part_info`` (the matched catalog key). An invoice without a
            recognisable table yields ``{}``.
        """
        all_lines = generated_invoice.split("\n")
        bounds = InvoiceGenerator._table_bounds(all_lines)
        if bounds is None:
            return {}
        first_cost_line, total_line = bounds

        comparing_results = {}
        for row in all_lines[first_cost_line:total_line]:
            cells = [cell.strip() for cell in row.split("|")]
            part, cost, hours = cells[0], cells[1], cells[2]
            found_part = InvoiceGenerator._match_part(part, car_parts, match_threshold)
            if not found_part:
                # Unknown part: nothing to compare against. Cost and hours are
                # deliberately not parsed here.
                comparing_results[part] = {}
                continue
            info = car_parts[found_part]
            cost_value = float(cost)
            hours_value = float(hours)
            comparing_results[part] = {
                "cost_within_range": info["cost_range"][0] <= cost_value <= info["cost_range"][1],
                "hours_within_range": info["hours_range"][0] <= hours_value <= info["hours_range"][1],
                "cost_diff": cost_value - info["average_cost"],
                "hours_diff": hours_value - info["average_hours"],
                "part_info": found_part,
            }
        return comparing_results

    @staticmethod
    def check_calculations(generated_invoice):
        """Recompute line totals and the grand total of an invoice table.

        For every cost row, the fifth cell is set to
        ``cost + hours * rate`` (cells 2, 3 and 4). Any cells after the
        fourth are replaced by that single computed total. The last number
        in the total row is replaced with the sum of all line totals.
        Integral amounts are written without a decimal part.

        Surrounding blank lines are preserved in the returned text. Text
        without a recognisable table is returned unchanged.

        Raises:
            ValueError: If a cost, hours or rate cell is not numeric.
            IndexError: If a cost row has fewer than four cells.
        """
        all_lines = generated_invoice.split("\n")
        bounds = InvoiceGenerator._table_bounds(all_lines)
        if bounds is None:
            return generated_invoice
        first_cost_line, total_line = bounds

        parse = InvoiceGenerator._parse_number
        fmt = InvoiceGenerator._format_number
        cost_lines = [
            [cell.strip() for cell in row.split("|")]
            for row in all_lines[first_cost_line:total_line]
        ]
        costs = [
            parse(cells[1]) + parse(cells[2]) * parse(cells[3])
            for cells in cost_lines
        ]

        # Same number of rows goes back in, so total_line stays valid.
        all_lines[first_cost_line:total_line] = [
            " | ".join(cells[:4] + [fmt(cost)])
            for cells, cost in zip(cost_lines, costs)
        ]
        all_lines[total_line] = InvoiceGenerator._replace_last_number(
            all_lines[total_line], fmt(sum(costs))
        )
        return "\n".join(all_lines)

    def generate_invoice(self, image_path, output_path=None, template_path="data/template.md", car_parts=None):
        """Run the full invoice pipeline for one damage image.

        Args:
            image_path: Image passed to the model and to ``rag.find_invoice``.
            output_path: If given, the invoice is rendered to a PDF there.
            template_path: Markdown template used when rendering.
            car_parts: Optional catalog for ``check_within_range``.

        Returns:
            ``None`` when the brief description is ``"Irrelevant."`` or has
            fewer than five words. Otherwise a dict with the keys
            ``damage_description``, ``invoice_info``, ``invoice_path``,
            ``similar_image``, ``detailed_damage_description``,
            ``generated_invoice`` and, when *car_parts* is given,
            ``comparing_results``.
        """
        result = {}

        damage_description = self.model.generate_message_from_image(
            GENERATE_BRIEF_DAMAGE_DESCRIPTION_PROMPT, image_path
        )
        # Presumably the brief prompt makes the model answer "Irrelevant." for
        # images without car damage. Very short answers are rejected as well.
        if damage_description == "Irrelevant." or len(damage_description.split()) < 5:
            return None
        result["damage_description"] = damage_description
        print("Damage Description:", damage_description)

        invoice_infos, invoice_paths = self.rag.find_invoice(
            image_path=image_path, return_only_path=False, damage_description=damage_description, k=5
        )
        # Only the first retrieved invoice is used (presumably the best match).
        invoice_info = invoice_infos[0]
        invoice_path = invoice_paths[0]
        result["invoice_info"] = invoice_info
        result["invoice_path"] = invoice_path
        # The reference image is expected next to its invoice with the same
        # stem. Only a trailing ".pdf" is swapped, not every occurrence.
        result["similar_image"] = re.sub(r"\.pdf$", ".png", invoice_path)
        print("Invoice Path:", invoice_path)

        detailed_damage_description = self.model.generate_message_from_image(
            GENERATE_DETAILED_DAMAGE_DESCRIPTION_PROMPT, image_path
        )
        result["detailed_damage_description"] = detailed_damage_description
        print("Detailed Damage Description:", detailed_damage_description)

        # Strip Markdown code fences the model may wrap the table in.
        generated_invoice = self.model.generate_message_from_image(
            GENERATE_INVOICE_PROMPT(invoice_info, detailed_damage_description), image_path
        ).replace("```markdown", "").replace("```", "")
        generated_invoice = self.check_calculations(generated_invoice)
        result["generated_invoice"] = generated_invoice

        if car_parts:
            comparing_results = self.check_within_range(generated_invoice, car_parts)
            result["comparing_results"] = comparing_results
            print(comparing_results)

        if output_path:
            self.format_invoice(generated_invoice=generated_invoice, output_path=output_path,
                                template_path=template_path)

        return result
|
|