#![allow(clippy::type_complexity)] use anyhow::Result; use image::{DynamicImage, GenericImageView, ImageBuffer}; use ndarray::{s, Array, Axis, IxDyn}; use rand::{thread_rng, Rng}; use std::path::PathBuf; use crate::{ check_font, gen_time_string, non_max_suppression, Args, Batch, Bbox, Embedding, OrtBackend, OrtConfig, OrtEP, Point2, YOLOResult, YOLOTask, SKELETON, }; pub struct YOLOv8 { // YOLOv8 model for all yolo-tasks engine: OrtBackend, nc: u32, nk: u32, nm: u32, height: u32, width: u32, batch: u32, task: YOLOTask, conf: f32, kconf: f32, iou: f32, names: Vec, color_palette: Vec<(u8, u8, u8)>, profile: bool, plot: bool, } impl YOLOv8 { pub fn new(config: Args) -> Result { // execution provider let ep = if config.trt { OrtEP::Trt(config.device_id) } else if config.cuda { OrtEP::Cuda(config.device_id) } else { OrtEP::Cpu }; // batch let batch = Batch { opt: config.batch, min: config.batch_min, max: config.batch_max, }; // build ort engine let ort_args = OrtConfig { ep, batch, f: config.model, task: config.task, trt_fp16: config.fp16, image_size: (config.height, config.width), }; let engine = OrtBackend::build(ort_args)?; // get batch, height, width, tasks, nc, nk, nm let (batch, height, width, task) = ( engine.batch(), engine.height(), engine.width(), engine.task(), ); let nc = engine.nc().or(config.nc).unwrap_or_else(|| { panic!("Failed to get num_classes, make it explicit with `--nc`"); }); let (nk, nm) = match task { YOLOTask::Pose => { let nk = engine.nk().or(config.nk).unwrap_or_else(|| { panic!("Failed to get num_keypoints, make it explicit with `--nk`"); }); (nk, 0) } YOLOTask::Segment => { let nm = engine.nm().or(config.nm).unwrap_or_else(|| { panic!("Failed to get num_masks, make it explicit with `--nm`"); }); (0, nm) } _ => (0, 0), }; // class names let names = engine.names().unwrap_or(vec!["Unknown".to_string()]); // color palette let mut rng = thread_rng(); let color_palette: Vec<_> = names .iter() .map(|_| { ( rng.gen_range(0..=255), rng.gen_range(0..=255), rng.gen_range(0..=255), ) }) .collect(); Ok(Self { engine, names, conf: config.conf, kconf: config.kconf, iou: config.iou, color_palette, profile: config.profile, plot: config.plot, nc, nk, nm, height, width, batch, task, }) } pub fn scale_wh(&self, w0: f32, h0: f32, w1: f32, h1: f32) -> (f32, f32, f32) { let r = (w1 / w0).min(h1 / h0); (r, (w0 * r).round(), (h0 * r).round()) } pub fn preprocess(&mut self, xs: &Vec) -> Result> { let mut ys = Array::ones((xs.len(), 3, self.height() as usize, self.width() as usize)).into_dyn(); ys.fill(144.0 / 255.0); for (idx, x) in xs.iter().enumerate() { let img = match self.task() { YOLOTask::Classify => x.resize_exact( self.width(), self.height(), image::imageops::FilterType::Triangle, ), _ => { let (w0, h0) = x.dimensions(); let w0 = w0 as f32; let h0 = h0 as f32; let (_, w_new, h_new) = self.scale_wh(w0, h0, self.width() as f32, self.height() as f32); // f32 round x.resize_exact( w_new as u32, h_new as u32, if let YOLOTask::Segment = self.task() { image::imageops::FilterType::CatmullRom } else { image::imageops::FilterType::Triangle }, ) } }; for (x, y, rgb) in img.pixels() { let x = x as usize; let y = y as usize; let [r, g, b, _] = rgb.0; ys[[idx, 0, y, x]] = (r as f32) / 255.0; ys[[idx, 1, y, x]] = (g as f32) / 255.0; ys[[idx, 2, y, x]] = (b as f32) / 255.0; } } Ok(ys) } pub fn run(&mut self, xs: &Vec) -> Result> { // pre-process let t_pre = std::time::Instant::now(); let xs_ = self.preprocess(xs)?; if self.profile { println!("[Model Preprocess]: {:?}", t_pre.elapsed()); } // run let t_run = std::time::Instant::now(); let ys = self.engine.run(xs_, self.profile)?; if self.profile { println!("[Model Inference]: {:?}", t_run.elapsed()); } // post-process let t_post = std::time::Instant::now(); let ys = self.postprocess(ys, xs)?; if self.profile { println!("[Model Postprocess]: {:?}", t_post.elapsed()); } // plot and save if self.plot { self.plot_and_save(&ys, xs, Some(&SKELETON)); } Ok(ys) } pub fn postprocess( &self, xs: Vec>, xs0: &[DynamicImage], ) -> Result> { if let YOLOTask::Classify = self.task() { let mut ys = Vec::new(); let preds = &xs[0]; for batch in preds.axis_iter(Axis(0)) { ys.push(YOLOResult::new( Some(Embedding::new(batch.into_owned())), None, None, None, )); } Ok(ys) } else { const CXYWH_OFFSET: usize = 4; // cxcywh const KPT_STEP: usize = 3; // xyconf let preds = &xs[0]; let protos = { if xs.len() > 1 { Some(&xs[1]) } else { None } }; let mut ys = Vec::new(); for (idx, anchor) in preds.axis_iter(Axis(0)).enumerate() { // [bs, 4 + nc + nm, anchors] // input image let width_original = xs0[idx].width() as f32; let height_original = xs0[idx].height() as f32; let ratio = (self.width() as f32 / width_original) .min(self.height() as f32 / height_original); // save each result let mut data: Vec<(Bbox, Option>, Option>)> = Vec::new(); for pred in anchor.axis_iter(Axis(1)) { // split preds for different tasks let bbox = pred.slice(s![0..CXYWH_OFFSET]); let clss = pred.slice(s![CXYWH_OFFSET..CXYWH_OFFSET + self.nc() as usize]); let kpts = { if let YOLOTask::Pose = self.task() { Some(pred.slice(s![pred.len() - KPT_STEP * self.nk() as usize..])) } else { None } }; let coefs = { if let YOLOTask::Segment = self.task() { Some(pred.slice(s![pred.len() - self.nm() as usize..]).to_vec()) } else { None } }; // confidence and id let (id, &confidence) = clss .into_iter() .enumerate() .reduce(|max, x| if x.1 > max.1 { x } else { max }) .unwrap(); // definitely will not panic! // confidence filter if confidence < self.conf { continue; } // bbox re-scale let cx = bbox[0] / ratio; let cy = bbox[1] / ratio; let w = bbox[2] / ratio; let h = bbox[3] / ratio; let x = cx - w / 2.; let y = cy - h / 2.; let y_bbox = Bbox::new( x.max(0.0f32).min(width_original), y.max(0.0f32).min(height_original), w, h, id, confidence, ); // kpts let y_kpts = { if let Some(kpts) = kpts { let mut kpts_ = Vec::new(); // rescale for i in 0..self.nk() as usize { let kx = kpts[KPT_STEP * i] / ratio; let ky = kpts[KPT_STEP * i + 1] / ratio; let kconf = kpts[KPT_STEP * i + 2]; if kconf < self.kconf { kpts_.push(Point2::default()); } else { kpts_.push(Point2::new_with_conf( kx.max(0.0f32).min(width_original), ky.max(0.0f32).min(height_original), kconf, )); } } Some(kpts_) } else { None } }; // data merged data.push((y_bbox, y_kpts, coefs)); } // nms non_max_suppression(&mut data, self.iou); // decode let mut y_bboxes: Vec = Vec::new(); let mut y_kpts: Vec> = Vec::new(); let mut y_masks: Vec> = Vec::new(); for elem in data.into_iter() { if let Some(kpts) = elem.1 { y_kpts.push(kpts) } // decode masks if let Some(coefs) = elem.2 { let proto = protos.unwrap().slice(s![idx, .., .., ..]); let (nm, nh, nw) = proto.dim(); // coefs * proto -> mask let coefs = Array::from_shape_vec((1, nm), coefs)?; // (n, nm) let proto = proto.to_owned().into_shape((nm, nh * nw))?; // (nm, nh*nw) let mask = coefs.dot(&proto).into_shape((nh, nw, 1))?; // (nh, nw, n) // build image from ndarray let mask_im: ImageBuffer, Vec> = match ImageBuffer::from_raw(nw as u32, nh as u32, mask.into_raw_vec()) { Some(image) => image, None => panic!("can not create image from ndarray"), }; let mut mask_im = image::DynamicImage::from(mask_im); // -> dyn // rescale masks let (_, w_mask, h_mask) = self.scale_wh(width_original, height_original, nw as f32, nh as f32); let mask_cropped = mask_im.crop(0, 0, w_mask as u32, h_mask as u32); let mask_original = mask_cropped.resize_exact( // resize_to_fill width_original as u32, height_original as u32, match self.task() { YOLOTask::Segment => image::imageops::FilterType::CatmullRom, _ => image::imageops::FilterType::Triangle, }, ); // crop-mask with bbox let mut mask_original_cropped = mask_original.into_luma8(); for y in 0..height_original as usize { for x in 0..width_original as usize { if x < elem.0.xmin() as usize || x > elem.0.xmax() as usize || y < elem.0.ymin() as usize || y > elem.0.ymax() as usize { mask_original_cropped.put_pixel( x as u32, y as u32, image::Luma([0u8]), ); } } } y_masks.push(mask_original_cropped.into_raw()); } y_bboxes.push(elem.0); } // save each result let y = YOLOResult { probs: None, bboxes: if !y_bboxes.is_empty() { Some(y_bboxes) } else { None }, keypoints: if !y_kpts.is_empty() { Some(y_kpts) } else { None }, masks: if !y_masks.is_empty() { Some(y_masks) } else { None }, }; ys.push(y); } Ok(ys) } } pub fn plot_and_save( &self, ys: &[YOLOResult], xs0: &[DynamicImage], skeletons: Option<&[(usize, usize)]>, ) { // check font then load let font = check_font("Arial.ttf"); for (_idb, (img0, y)) in xs0.iter().zip(ys.iter()).enumerate() { let mut img = img0.to_rgb8(); // draw for classifier if let Some(probs) = y.probs() { for (i, k) in probs.topk(5).iter().enumerate() { let legend = format!("{} {:.2}%", self.names[k.0], k.1); let scale = 32; let legend_size = img.width().max(img.height()) / scale; let x = img.width() / 20; let y = img.height() / 20 + i as u32 * legend_size; imageproc::drawing::draw_text_mut( &mut img, image::Rgb([0, 255, 0]), x as i32, y as i32, rusttype::Scale::uniform(legend_size as f32 - 1.), &font, &legend, ); } } // draw bboxes & keypoints if let Some(bboxes) = y.bboxes() { for (_idx, bbox) in bboxes.iter().enumerate() { // rect imageproc::drawing::draw_hollow_rect_mut( &mut img, imageproc::rect::Rect::at(bbox.xmin() as i32, bbox.ymin() as i32) .of_size(bbox.width() as u32, bbox.height() as u32), image::Rgb(self.color_palette[bbox.id()].into()), ); // text let legend = format!("{} {:.2}%", self.names[bbox.id()], bbox.confidence()); let scale = 40; let legend_size = img.width().max(img.height()) / scale; imageproc::drawing::draw_text_mut( &mut img, image::Rgb(self.color_palette[bbox.id()].into()), bbox.xmin() as i32, (bbox.ymin() - legend_size as f32) as i32, rusttype::Scale::uniform(legend_size as f32 - 1.), &font, &legend, ); } } // draw kpts if let Some(keypoints) = y.keypoints() { for kpts in keypoints.iter() { for kpt in kpts.iter() { // filter if kpt.confidence() < self.kconf { continue; } // draw point imageproc::drawing::draw_filled_circle_mut( &mut img, (kpt.x() as i32, kpt.y() as i32), 2, image::Rgb([0, 255, 0]), ); } // draw skeleton if has if let Some(skeletons) = skeletons { for &(idx1, idx2) in skeletons.iter() { let kpt1 = &kpts[idx1]; let kpt2 = &kpts[idx2]; if kpt1.confidence() < self.kconf || kpt2.confidence() < self.kconf { continue; } imageproc::drawing::draw_line_segment_mut( &mut img, (kpt1.x(), kpt1.y()), (kpt2.x(), kpt2.y()), image::Rgb([233, 14, 57]), ); } } } } // draw mask if let Some(masks) = y.masks() { for (mask, _bbox) in masks.iter().zip(y.bboxes().unwrap().iter()) { let mask_nd: ImageBuffer, Vec> = match ImageBuffer::from_vec(img.width(), img.height(), mask.to_vec()) { Some(image) => image, None => panic!("can not crate image from ndarray"), }; for _x in 0..img.width() { for _y in 0..img.height() { let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_nd, _x, _y); if mask_p.0[0] > 0 { let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, _x, _y); // img_p.0[2] = self.color_palette[bbox.id()].2 / 2; // img_p.0[1] = self.color_palette[bbox.id()].1 / 2; // img_p.0[0] = self.color_palette[bbox.id()].0 / 2; img_p.0[2] /= 2; img_p.0[1] = 255 - (255 - img_p.0[2]) / 2; img_p.0[0] /= 2; imageproc::drawing::Canvas::draw_pixel(&mut img, _x, _y, img_p) } } } } } // mkdir and save let mut runs = PathBuf::from("runs"); if !runs.exists() { std::fs::create_dir_all(&runs).unwrap(); } runs.push(gen_time_string("-")); let saveout = format!("{}.jpg", runs.to_str().unwrap()); let _ = img.save(saveout); } } pub fn summary(&self) { println!( "\nSummary:\n\ > Task: {:?}{}\n\ > EP: {:?} {}\n\ > Dtype: {:?}\n\ > Batch: {} ({}), Height: {} ({}), Width: {} ({})\n\ > nc: {} nk: {}, nm: {}, conf: {}, kconf: {}, iou: {}\n\ ", self.task(), match self.engine.author().zip(self.engine.version()) { Some((author, ver)) => format!(" ({} {})", author, ver), None => String::from(""), }, self.engine.ep(), if let OrtEP::Cpu = self.engine.ep() { "" } else { "(May still fall back to CPU)" }, self.engine.dtype(), self.batch(), if self.engine.is_batch_dynamic() { "Dynamic" } else { "Const" }, self.height(), if self.engine.is_height_dynamic() { "Dynamic" } else { "Const" }, self.width(), if self.engine.is_width_dynamic() { "Dynamic" } else { "Const" }, self.nc(), self.nk(), self.nm(), self.conf, self.kconf, self.iou, ); } pub fn engine(&self) -> &OrtBackend { &self.engine } pub fn conf(&self) -> f32 { self.conf } pub fn set_conf(&mut self, val: f32) { self.conf = val; } pub fn conf_mut(&mut self) -> &mut f32 { &mut self.conf } pub fn kconf(&self) -> f32 { self.kconf } pub fn iou(&self) -> f32 { self.iou } pub fn task(&self) -> &YOLOTask { &self.task } pub fn batch(&self) -> u32 { self.batch } pub fn width(&self) -> u32 { self.width } pub fn height(&self) -> u32 { self.height } pub fn nc(&self) -> u32 { self.nc } pub fn nk(&self) -> u32 { self.nk } pub fn nm(&self) -> u32 { self.nm } pub fn names(&self) -> &Vec { &self.names } }