# pixtral-demo / src/pipelines.py
# Author: alexandraroze
# Last commit: 77ecec5 ("updated comparing results"), 6.67 kB
from src.RAG import RAG
from src.model import Pixtral
from src.prompts import GENERATE_INVOICE_PROMPT, GENERATE_BRIEF_DAMAGE_DESCRIPTION_PROMPT, \
GENERATE_DETAILED_DAMAGE_DESCRIPTION_PROMPT
from rerankers import Reranker
import re
from md2pdf.core import md2pdf
from fuzzywuzzy import fuzz
class InvoiceGenerator:
    """Generate a repair invoice for an image of a damaged car.

    Pipeline (see ``generate_invoice``): describe the damage with the Pixtral
    model, retrieve a similar past invoice through ``RAG``, ask the model for a
    markdown invoice based on it, recompute the arithmetic in that invoice,
    optionally compare it against reference part data, and optionally render
    it to PDF through a markdown template.

    The invoice table is expected to have this layout (column titles are
    illustrative; leading/trailing pipes are optional)::

        | Part | Cost | Hours | Rate | Total |
        |------|------|-------|------|-------|
        | ...  | ...  | ...   | ...  | ...   |
        <last non-blank line carrying the grand total>

    i.e. per row: part name, cost, hours, a rate multiplied by the hours, and
    the row total (``cost + hours * rate``).
    """

    # A number on the grand-total line: digits, optional thousands commas and
    # optional decimals (e.g. "350", "1,250.00").
    _TOTAL_NUMBER_RE = re.compile(r"\d(?:[\d,]*\d)?(?:\.\d+)?")

    def __init__(
        self,
        fais_index_path,
        image_invoice_index_path,
        path_to_invoices,
        path_to_images,
        path_to_template,
        reranker_model=None,
        device="cuda",
        invoice_json_path=None,
        max_model_len=4096, max_tokens=2048, gpu_memory_utilization=0.95
    ):
        """Create the Pixtral model, the optional reranker and the RAG index.

        Args:
            fais_index_path: Forwarded to ``RAG`` as ``fais_index_path``.
            image_invoice_index_path: Forwarded to ``RAG``.
            path_to_invoices: Forwarded to ``RAG`` and stored on the instance.
            path_to_images: Forwarded to ``RAG`` and stored on the instance.
            path_to_template: Markdown template used by ``format_invoice``.
            reranker_model: Model name for ``rerankers.Reranker``; when falsy,
                no reranker is created and ``None`` is passed to ``RAG``.
            device: Device string for the reranker; also stored on the instance.
            invoice_json_path: Forwarded to ``RAG`` as ``path_to_invoice_json``.
            max_model_len, max_tokens, gpu_memory_utilization: Forwarded to
                ``Pixtral``.
        """
        self.model = Pixtral(max_model_len=max_model_len, max_tokens=max_tokens,
                             gpu_memory_utilization=gpu_memory_utilization)
        # Always define the attribute: it used to be set only when a reranker
        # model was given, so the default reranker_model=None crashed with an
        # AttributeError when building RAG below.
        # NOTE(review): assumes RAG accepts reranker=None -- confirm.
        self.reranker = Reranker(model_name=reranker_model, device=device) if reranker_model else None
        self.device = device
        self.rag = RAG(
            fais_index_path=fais_index_path,
            image_invoice_index_path=image_invoice_index_path,
            path_to_invoices=path_to_invoices,
            path_to_images=path_to_images,
            reranker=self.reranker,
            path_to_invoice_json=invoice_json_path
        )
        self.path_to_invoices = path_to_invoices
        self.path_to_images = path_to_images
        self.path_to_template = path_to_template

    def format_invoice(self, generated_invoice, output_path):
        """Insert ``generated_invoice`` into the template and write a PDF.

        Every ``<<table>>`` placeholder in ``self.path_to_template`` is replaced
        with the invoice markdown; the resulting markdown is handed to
        ``md2pdf``, which writes ``output_path``.
        """
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(self.path_to_template, "r", encoding="utf-8") as f:
            md_text = f.read()
        md_text = md_text.replace("<<table>>", generated_invoice)
        md2pdf(output_path, md_content=md_text)

    @staticmethod
    def _to_number(text):
        """Parse an invoice cell such as ``"200"``, ``"$1,200"`` or ``"**1.5**"``.

        Every character other than digits, ``.`` and ``-`` is dropped first so
        currency symbols, thousands separators and markdown emphasis do not
        break parsing. Returns an ``int`` for integral values, else a ``float``.

        Raises:
            ValueError: if no number can be parsed from ``text``.
        """
        cleaned = re.sub(r"[^0-9.\-]", "", str(text))
        try:
            value = float(cleaned)
        except ValueError:
            raise ValueError(f"Cannot parse a number from invoice cell {text!r}") from None
        return int(value) if value.is_integer() else value

    @staticmethod
    def _format_number(value):
        """Render an amount: integral values without decimals, others with two."""
        if isinstance(value, float):
            return str(int(value)) if value.is_integer() else f"{value:.2f}"
        return str(value)

    @staticmethod
    def _replace_last_number(text, value):
        """Replace the last number in ``text`` with ``value``; no-op if there is none.

        Only the last number is touched because in a table total row the
        amount sits in the last column; rewriting every digit run would also
        mangle thousands separators and decimals.
        """
        matches = list(InvoiceGenerator._TOTAL_NUMBER_RE.finditer(text))
        if not matches:
            return text
        last = matches[-1]
        return text[:last.start()] + InvoiceGenerator._format_number(value) + text[last.end():]

    @staticmethod
    def _parse_cost_table(generated_invoice):
        """Locate the cost rows and the grand-total line of an invoice.

        The first line containing ``|`` is taken as the table header, and cost
        rows start two lines below it (after the ``|---|`` separator). The last
        non-blank line is the grand-total line. Blank lines in between are
        skipped.

        Returns:
            ``None`` if there is no table or no cost row. Otherwise a tuple
            ``(all_lines, rows, total_index, part_index)``:

            - ``rows`` is a list of ``(line_index, cells)`` pairs, where
              ``cells`` are the stripped ``|``-separated cells.
            - ``part_index`` is the cell index of the part name: 1 when rows
              start with ``|``, else 0.
        """
        all_lines = generated_invoice.split("\n")
        header_index = next((i for i, line in enumerate(all_lines) if "|" in line), None)
        if header_index is None:
            return None
        first_row = header_index + 2  # skip the header and the separator row
        total_index = max(i for i, line in enumerate(all_lines) if line.strip())
        rows = [
            (i, [cell.strip() for cell in all_lines[i].split("|")])
            for i in range(first_row, total_index)
            if all_lines[i].strip()
        ]
        if not rows:
            return None
        part_index = 1 if rows[0][1][0] == "" else 0
        return all_lines, rows, total_index, part_index

    @staticmethod
    def check_within_range(generated_invoice, car_parts, min_ratio=90):
        """Compare each invoice row against reference part data.

        Args:
            generated_invoice: Markdown invoice (layout in the class docstring).
            car_parts: Mapping from reference part name to a dict with keys
                ``cost_min``, ``cost_max``, ``hours_min``, ``hours_max``,
                ``average_cost`` and ``average_hours``.
            min_ratio: Minimum ``fuzz.WRatio`` score for an invoice part name to
                match a reference part. The highest-scoring match wins.

        Returns:
            Dict keyed by each part name as written in the invoice.

            - A matched row maps to ``cost_within_range`` and
              ``hours_within_range`` (bools), ``cost_diff`` and ``hours_diff``
              (relative to the averages), and ``part_info`` (the matched
              reference name).
            - An unmatched row maps to ``{}``.
            - An invoice without cost rows yields ``{}``.
        """
        print("Comparing results")

        def get_part_info(part_name):
            # Best fuzzy match scoring at least min_ratio, or None.
            part_name = part_name.lower()
            best_part, best_ratio = None, 0
            for candidate in car_parts:
                ratio = fuzz.WRatio(part_name, candidate.lower())
                if ratio >= min_ratio and ratio > best_ratio:
                    best_part, best_ratio = candidate, ratio
            return best_part

        table = InvoiceGenerator._parse_cost_table(generated_invoice)
        if table is None:
            return {}
        _, rows, _, part_index = table
        to_number = InvoiceGenerator._to_number
        comparing_results = {}
        for _, cells in rows:
            part = cells[part_index]
            cost_text = cells[part_index + 1]
            hours_text = cells[part_index + 2]
            found_part = get_part_info(part)
            if not found_part:
                comparing_results[part] = {}
                continue
            # Numbers are parsed only for matched rows, so an unmatched row
            # with a non-numeric cell does not abort the comparison.
            cost = float(to_number(cost_text))
            hours = float(to_number(hours_text))
            reference = car_parts[found_part]
            comparing_results[part] = {
                "cost_within_range": reference["cost_min"] <= cost <= reference["cost_max"],
                "hours_within_range": reference["hours_min"] <= hours <= reference["hours_max"],
                "cost_diff": cost - reference["average_cost"],
                "hours_diff": hours - reference["average_hours"],
                "part_info": found_part
            }
        return comparing_results

    @staticmethod
    def check_calculations(generated_invoice):
        """Recompute the row totals and the grand total of a generated invoice.

        Each cost row's total cell is overwritten with ``cost + hours * rate``.
        If a row has no total cell, one is appended. All other cells are kept.
        The last number on the grand-total line is replaced with the sum of
        the row totals. Header, separator and blank lines are left as they
        are. An invoice without cost rows is returned unchanged.

        Returns:
            The corrected invoice markdown.

        Raises:
            ValueError: if a cost, hours or rate cell is not numeric.
        """
        table = InvoiceGenerator._parse_cost_table(generated_invoice)
        if table is None:
            return generated_invoice
        all_lines, rows, total_index, part_index = table
        print(f"Cost lines: \n{[cells for _, cells in rows]}\n")
        cost_index = part_index + 1
        total_cell = part_index + 4  # part, cost, hours, rate, total
        to_number = InvoiceGenerator._to_number
        grand_total = 0
        for line_index, cells in rows:
            row_total = (to_number(cells[cost_index])
                         + to_number(cells[cost_index + 1]) * to_number(cells[cost_index + 2]))
            grand_total += row_total
            new_cells = list(cells)
            formatted = InvoiceGenerator._format_number(row_total)
            if len(new_cells) > total_cell:
                new_cells[total_cell] = formatted
            else:
                new_cells.append(formatted)
            # Re-join every cell (the rate column used to be dropped here);
            # strip() turns " | a | ... | " back into "| a | ... |".
            all_lines[line_index] = " | ".join(new_cells).strip()
        all_lines[total_index] = InvoiceGenerator._replace_last_number(all_lines[total_index], grand_total)
        return "\n".join(all_lines)

    def generate_invoice(self, image_path, output_path=None, car_parts=None):
        """Run the whole pipeline for one damage image.

        Args:
            image_path: Image given to the model and to ``RAG.find_invoice``.
            output_path: When given, the invoice is also rendered to this PDF.
            car_parts: Optional reference data for ``check_within_range``.

        Returns:
            ``None`` when the brief damage description is exactly
            ``"Irrelevant."`` or has fewer than five words. Otherwise a dict
            with ``damage_description``, ``invoice_info``, ``invoice_path``,
            ``similar_image``, ``detailed_damage_description`` and
            ``generated_invoice``, plus ``comparing_results`` when
            ``car_parts`` is truthy.
        """
        result = {}
        damage_description = self.model.generate_message_from_image(
            GENERATE_BRIEF_DAMAGE_DESCRIPTION_PROMPT, image_path
        )
        # An "Irrelevant." verdict or a very short answer means nothing to invoice.
        if damage_description == "Irrelevant." or len(damage_description.split()) < 5:
            return None
        result["damage_description"] = damage_description
        print("Damage Description:", damage_description)

        invoice_info, invoice_path = self.rag.find_invoice(
            image_path=image_path, return_only_path=False, damage_description=damage_description, k=5
        )
        # Only the first retrieved invoice (presumably the best match) is used.
        invoice_info = invoice_info[0]
        invoice_path = invoice_path[0]
        result["invoice_info"] = invoice_info
        result["invoice_path"] = invoice_path
        # Swap only the trailing extension; str.replace also rewrote ".pdf"
        # inside directory names.
        # NOTE(review): assumes the image sits beside the PDF as .png -- confirm.
        result["similar_image"] = re.sub(r"\.pdf$", ".png", invoice_path)
        print("Invoice Path:", invoice_path)

        detailed_damage_description = self.model.generate_message_from_image(
            GENERATE_DETAILED_DAMAGE_DESCRIPTION_PROMPT, image_path
        )
        result["detailed_damage_description"] = detailed_damage_description
        print("Detailed Damage Description:", detailed_damage_description)

        # Strip markdown code fences the model may wrap the table in.
        generated_invoice = self.model.generate_message_from_image(
            GENERATE_INVOICE_PROMPT(invoice_info, detailed_damage_description), image_path
        ).replace("```markdown", "").replace("```", "")
        print(f"Generated invoice: \n{generated_invoice}\n")
        generated_invoice = self.check_calculations(generated_invoice)
        result["generated_invoice"] = generated_invoice

        if car_parts:
            comparing_results = self.check_within_range(generated_invoice, car_parts)
            result["comparing_results"] = comparing_results
            print(comparing_results)
        if output_path:
            self.format_invoice(generated_invoice=generated_invoice, output_path=output_path)
        return result