pixtral-demo / src /pipelines.py
alexandraroze's picture
dockerfile
f11a85d
raw
history blame
6.31 kB
from src.RAG import RAG
from src.model import Pixtral
from src.prompts import GENERATE_INVOICE_PROMPT, GENERATE_BRIEF_DAMAGE_DESCRIPTION_PROMPT, \
GENERATE_DETAILED_DAMAGE_DESCRIPTION_PROMPT
from md2pdf.core import md2pdf
from rerankers import Reranker
import re
from fuzzywuzzy import fuzz
class InvoiceGenerator:
    """Generate a repair invoice for a photo and sanity-check its numbers.

    The pipeline (see :meth:`generate_invoice`) asks the ``Pixtral`` model for a
    damage description, looks up a similar invoice through ``RAG``, asks the
    model for a new invoice table, recomputes the table's arithmetic and
    optionally renders the result to PDF.

    The invoice table handled by the static helpers is plain text with
    ``|``-separated cells and no leading pipe: two header lines, then one row
    per part (``part | cost | hours | rate | total``), then a final total line.
    """

    # Matches one (optionally decimal) number; used to rewrite the total line.
    _NUMBER_PATTERN = r"\d+(?:\.\d+)?"

    def __init__(
        self,
        fais_index_path,
        image_invoice_index_path,
        path_to_invoices,
        path_to_images,
        reranker_model=None,
        device="cuda",
        max_model_len=4096,
        max_tokens=2048,
        gpu_memory_utilization=0.95,
    ):
        """Build the model, the optional reranker and the retrieval helper.

        Args:
            fais_index_path: Forwarded unchanged to ``RAG``.
            image_invoice_index_path: Forwarded unchanged to ``RAG``.
            path_to_invoices: Forwarded to ``RAG`` and stored on the instance.
            path_to_images: Forwarded to ``RAG`` and stored on the instance.
            reranker_model: Model name passed to ``Reranker``; when falsy no
                reranker is created and ``self.reranker`` is ``None``.
            device: Device passed to ``Reranker``; also stored on the instance.
            max_model_len: Forwarded to ``Pixtral``.
            max_tokens: Forwarded to ``Pixtral``.
            gpu_memory_utilization: Forwarded to ``Pixtral``.
        """
        self.model = Pixtral(
            max_model_len=max_model_len,
            max_tokens=max_tokens,
            gpu_memory_utilization=gpu_memory_utilization,
        )
        # Always define the attribute: it is read below even when no reranker
        # model was requested, which used to raise AttributeError.
        # NOTE(review): RAG now receives reranker=None in that case -- confirm
        # that RAG treats None as "no reranking".
        self.reranker = None
        if reranker_model:
            self.reranker = Reranker(model_name=reranker_model, device=device)
        self.device = device
        self.rag = RAG(
            fais_index_path=fais_index_path,
            image_invoice_index_path=image_invoice_index_path,
            path_to_invoices=path_to_invoices,
            path_to_images=path_to_images,
            reranker=self.reranker,
        )
        self.path_to_invoices = path_to_invoices
        self.path_to_images = path_to_images

    def format_invoice(self, generated_invoice, output_path, template_path="data/template.md"):
        """Insert the invoice into the markdown template and render it to PDF.

        Args:
            generated_invoice: Text substituted for the ``<<table>>`` placeholder.
            output_path: Destination passed to ``md2pdf``.
            template_path: Markdown template containing ``<<table>>``.
        """
        # Read as UTF-8 explicitly so the result does not depend on the locale.
        with open(template_path, "r", encoding="utf-8") as template_file:
            md_text = template_file.read()
        md_text = md_text.replace(r"<<table>>", generated_invoice)
        md2pdf(output_path, md_content=md_text)

    @staticmethod
    def _locate_table(all_lines):
        """Return ``(first_row, total_row)`` indices, or ``None`` if too short.

        Whitespace-only lines at either end are ignored. The first two
        non-blank lines are taken to be the table header, the last non-blank
        line the total line; part rows are ``all_lines[first_row:total_row]``.
        """
        content = [idx for idx, text in enumerate(all_lines) if text.strip()]
        if len(content) < 3:
            # Not even header + separator + total: nothing to work with.
            return None
        return content[0] + 2, content[-1]

    @staticmethod
    def _split_cells(row):
        """Split one table row on ``|`` and strip each cell."""
        return [cell.strip() for cell in row.split("|")]

    @staticmethod
    def _parse_number(text):
        """Parse *text* as ``int`` when possible, otherwise as ``float``.

        Raises:
            ValueError: If *text* is not a number.
        """
        try:
            return int(text)
        except ValueError:
            return float(text)

    @staticmethod
    def _format_number(value):
        """Render a number without a spurious ``.0`` or float noise."""
        if isinstance(value, int):
            return str(value)
        rounded = round(value, 2)
        if rounded.is_integer():
            return str(int(rounded))
        return f"{rounded:.2f}".rstrip("0").rstrip(".")

    @staticmethod
    def _match_part(part_name, car_parts):
        """Return the best fuzzy match for *part_name* among *car_parts* keys.

        Only candidates with a ``fuzz.WRatio`` score of at least 90 qualify;
        on ties the first candidate wins. Returns ``None`` when nothing
        qualifies.
        """
        wanted = part_name.lower()
        best_part, best_ratio = None, 0
        for candidate in car_parts:
            ratio = fuzz.WRatio(wanted, candidate.lower())
            if ratio >= 90 and ratio > best_ratio:
                best_part, best_ratio = candidate, ratio
        return best_part

    @staticmethod
    def check_within_range(generated_invoice, car_parts):
        """Compare each invoice row's cost and hours with reference data.

        Args:
            generated_invoice: Invoice table text (see the class docstring).
            car_parts: Mapping of part name to a dict providing
                ``cost_range``, ``hours_range`` (both ``[low, high]``),
                ``average_cost`` and ``average_hours``.

        Returns:
            Dict keyed by the part name as written in the invoice. Matched
            parts map to ``cost_within_range``, ``hours_within_range``,
            ``cost_diff``, ``hours_diff`` and ``part_info`` (the matched
            reference name); unmatched parts map to ``{}``. An invoice too
            short to contain a table yields ``{}``.

        Raises:
            ValueError: If a matched row has a non-numeric cost or hours cell.
            IndexError: If a matched row has fewer than three cells.
        """
        all_lines = generated_invoice.split("\n")
        bounds = InvoiceGenerator._locate_table(all_lines)
        if bounds is None:
            return {}
        first_row, total_row = bounds

        comparing_results = {}
        for row in all_lines[first_row:total_row]:
            if not row.strip():
                continue  # stray blank line inside the table
            cells = InvoiceGenerator._split_cells(row)
            part = cells[0]
            found_part = InvoiceGenerator._match_part(part, car_parts)
            if not found_part:
                comparing_results[part] = {}
                continue
            reference = car_parts[found_part]
            cost = float(cells[1])
            hours = float(cells[2])
            cost_low, cost_high = reference["cost_range"][0], reference["cost_range"][1]
            hours_low, hours_high = reference["hours_range"][0], reference["hours_range"][1]
            comparing_results[part] = {
                "cost_within_range": cost_low <= cost <= cost_high,
                "hours_within_range": hours_low <= hours <= hours_high,
                "cost_diff": cost - reference["average_cost"],
                "hours_diff": hours - reference["average_hours"],
                "part_info": found_part,
            }
        return comparing_results

    @staticmethod
    def check_calculations(generated_invoice):
        """Recompute every row total and the grand total of an invoice table.

        Each part row is rewritten as ``part | cost | hours | rate | total``
        with ``total = cost + hours * rate``; any further cells are dropped.
        Every number in the final total line is replaced with the sum of the
        row totals. Integer and decimal cell values are both accepted.

        Args:
            generated_invoice: Invoice table text (see the class docstring).

        Returns:
            The invoice text with corrected totals. Input too short to contain
            a table is returned unchanged.

        Raises:
            ValueError: If a cost, hours or rate cell is not a number.
            IndexError: If a part row has fewer than four cells.
        """
        all_lines = generated_invoice.split("\n")
        bounds = InvoiceGenerator._locate_table(all_lines)
        if bounds is None:
            return generated_invoice
        first_row, total_row = bounds

        parse = InvoiceGenerator._parse_number
        render = InvoiceGenerator._format_number
        row_totals = []
        for idx in range(first_row, total_row):
            if not all_lines[idx].strip():
                continue  # stray blank line inside the table: keep as is
            cells = InvoiceGenerator._split_cells(all_lines[idx])
            row_total = parse(cells[1]) + parse(cells[2]) * parse(cells[3])
            row_totals.append(row_total)
            all_lines[idx] = " | ".join(
                [cells[0], cells[1], cells[2], cells[3], render(row_total)]
            )

        total_text = render(sum(row_totals))
        # A callable replacement keeps the text literal (no escape processing).
        all_lines[total_row] = re.sub(
            InvoiceGenerator._NUMBER_PATTERN,
            lambda _match: total_text,
            all_lines[total_row],
        )
        return "\n".join(all_lines)

    def generate_invoice(self, image_path, output_path=None, template_path="data/template.md",
                         car_parts=None):
        """Run the full pipeline for one image.

        Args:
            image_path: Image handed to the model and to the retrieval step.
            output_path: When truthy, the invoice is also rendered to this PDF.
            template_path: Markdown template used for the PDF.
            car_parts: Optional reference data; when truthy the invoice is
                checked with :meth:`check_within_range`.

        Returns:
            ``None`` when the brief description is exactly ``"Irrelevant."`` or
            has fewer than five words. Otherwise a dict with
            ``damage_description``, ``invoice_info``, ``invoice_path``,
            ``similar_image``, ``detailed_damage_description``,
            ``generated_invoice`` and, if *car_parts* was given,
            ``comparing_results``.
        """
        result = {}

        damage_description = self.model.generate_message_from_image(
            GENERATE_BRIEF_DAMAGE_DESCRIPTION_PROMPT, image_path
        )
        if damage_description == "Irrelevant." or len(damage_description.split()) < 5:
            return None
        result["damage_description"] = damage_description
        print("Damage Description:", damage_description)

        invoice_info, invoice_path = self.rag.find_invoice(
            image_path=image_path, return_only_path=False, damage_description=damage_description, k=5
        )
        # Only the first retrieved candidate is used.
        # NOTE(review): assumes find_invoice never returns empty sequences -- confirm.
        invoice_info = invoice_info[0]
        invoice_path = invoice_path[0]
        result["invoice_info"] = invoice_info
        result["invoice_path"] = invoice_path
        result["similar_image"] = invoice_path.replace(".pdf", ".png")
        print("Invoice Path:", invoice_path)

        detailed_damage_description = self.model.generate_message_from_image(
            GENERATE_DETAILED_DAMAGE_DESCRIPTION_PROMPT, image_path
        )
        result["detailed_damage_description"] = detailed_damage_description
        print("Detailed Damage Description:", detailed_damage_description)

        # Strip markdown code fences from the model's answer before parsing.
        generated_invoice = self.model.generate_message_from_image(
            GENERATE_INVOICE_PROMPT(invoice_info, detailed_damage_description), image_path
        ).replace("```markdown", "").replace("```", "")
        generated_invoice = self.check_calculations(generated_invoice)
        result["generated_invoice"] = generated_invoice

        if car_parts:
            comparing_results = self.check_within_range(generated_invoice, car_parts)
            result["comparing_results"] = comparing_results
            print(comparing_results)

        if output_path:
            self.format_invoice(generated_invoice=generated_invoice, output_path=output_path,
                                template_path=template_path)
        return result