|
1 | 1 | mod store; |
2 | 2 |
|
3 | 3 | use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _}; |
| 4 | +use image::GenericImageView; |
4 | 5 | use keyring::Entry; |
5 | 6 | use serde::{Deserialize, Serialize}; |
6 | 7 | use serde_json::{json, Map, Value}; |
@@ -133,6 +134,35 @@ struct DownloadSystemModelPayload { |
133 | 134 | expected_size: Option<u64>, |
134 | 135 | } |
135 | 136 |
|
| 137 | +#[derive(Debug, Deserialize)] |
| 138 | +#[serde(rename_all = "camelCase")] |
| 139 | +struct RunModelInferencePayload { |
| 140 | + image_id: String, |
| 141 | + model_id: String, |
| 142 | + threshold: Option<f32>, |
| 143 | +} |
| 144 | + |
| 145 | +#[derive(Debug, Serialize)] |
| 146 | +#[serde(rename_all = "camelCase")] |
| 147 | +struct InferencePoint { |
| 148 | + x: f32, |
| 149 | + y: f32, |
| 150 | +} |
| 151 | + |
| 152 | +#[derive(Debug, Serialize)] |
| 153 | +#[serde(rename_all = "camelCase")] |
| 154 | +struct InferenceAnnotationDraft { |
| 155 | + name: String, |
| 156 | + #[serde(rename = "type")] |
| 157 | + annotation_type: String, |
| 158 | + coordinates: Vec<InferencePoint>, |
| 159 | + confidence: f32, |
| 160 | + label_id: Option<String>, |
| 161 | + label_name: Option<String>, |
| 162 | + label_color: Option<String>, |
| 163 | + is_ai_generated: bool, |
| 164 | +} |
| 165 | + |
136 | 166 | #[derive(Debug, Serialize)] |
137 | 167 | #[serde(rename_all = "camelCase")] |
138 | 168 | struct SystemInfo { |
@@ -477,6 +507,156 @@ fn decode_file_bytes(data: &str) -> Result<Vec<u8>, AppError> { |
477 | 507 | .map_err(|error| AppError::Message(error.to_string())) |
478 | 508 | } |
479 | 509 |
|
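| | +/// Reads a string field from a JSON value, accepting either the camelCase or
| | +/// snake_case spelling of the key.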
| 510 | +fn value_string(value: &Value, camel: &str, snake: &str) -> Option<String> { |
| 511 | + value |
| 512 | + .get(camel) |
| 513 | + .or_else(|| value.get(snake)) |
| 514 | + .and_then(Value::as_str) |
| 515 | + .map(ToString::to_string) |
| 516 | +} |
| 517 | + |
| 518 | +fn value_u32(value: &Value, key: &str) -> Option<u32> { |
| 519 | + value |
| 520 | + .get(key) |
| 521 | + .and_then(Value::as_u64) |
| 522 | + .and_then(|number| u32::try_from(number).ok()) |
| 523 | +} |
| 524 | + |
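| | +/// Centered box covering the middle 60% of the image, used when no salient
| | +/// region can be isolated.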
| 525 | +fn build_fallback_bbox(width: u32, height: u32) -> (u32, u32, u32, u32) { |
| 526 | + let left = (width as f32 * 0.2).round() as u32; |
| 527 | + let top = (height as f32 * 0.2).round() as u32; |
| 528 | + let right = (width as f32 * 0.8).round() as u32; |
| 529 | + let bottom = (height as f32 * 0.8).round() as u32; |
| 530 | + (left, top, right.max(left + 1), bottom.max(top + 1)) |
| 531 | +} |
| 532 | + |
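| | +/// Brightness-contrast heuristic: returns the bounding box of pixels whose
| | +/// luma deviates from the image mean by more than a bias-adjusted threshold,
| | +/// padded by roughly 8% of its own size on each side.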
| 533 | +fn detect_salient_region(image_bytes: &[u8], threshold_bias: f32) -> Result<(u32, u32, u32, u32), AppError> { |
| 534 | + let image = image::load_from_memory(image_bytes) |
| 535 | + .map_err(|error| AppError::Message(format!("Failed to decode image for AI annotation: {error}")))?; |
| 536 | + let grayscale = image.to_luma8(); |
| 537 | + let (width, height) = grayscale.dimensions(); |
| 538 | + |
| 539 | + if width == 0 || height == 0 { |
| 540 | + return Err(AppError::Message("Image has invalid dimensions".into())); |
| 541 | + } |
| 542 | + |
| 543 | + let mut brightness_total = 0u64; |
| 544 | + for pixel in grayscale.pixels() { |
| 545 | + brightness_total += u64::from(pixel.0[0]); |
| 546 | + } |
| 547 | + let pixel_count = u64::from(width) * u64::from(height); |
| 548 | + let average = brightness_total as f32 / pixel_count as f32; |
| 549 | + let threshold = (28.0 + threshold_bias * 64.0).clamp(16.0, 96.0); |
| 550 | + |
| 551 | + let mut min_x = width; |
| 552 | + let mut min_y = height; |
| 553 | + let mut max_x = 0u32; |
| 554 | + let mut max_y = 0u32; |
| 555 | + let mut hits = 0u32; |
| 556 | + |
| 557 | + for (x, y, pixel) in grayscale.enumerate_pixels() { |
| 558 | + let value = pixel.0[0] as f32; |
| 559 | + if (value - average).abs() >= threshold { |
| 560 | + min_x = min_x.min(x); |
| 561 | + min_y = min_y.min(y); |
| 562 | + max_x = max_x.max(x); |
| 563 | + max_y = max_y.max(y); |
| 564 | + hits += 1; |
| 565 | + } |
| 566 | + } |
| 567 | + |
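| | +    // Fall back to the centered box when nothing stands out, or when the
| | +    // detected region covers less than 1/40th of the image.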
| 568 | + if hits == 0 { |
| 569 | + return Ok(build_fallback_bbox(width, height)); |
| 570 | + } |
| 571 | + |
| 572 | + let area = (max_x.saturating_sub(min_x) + 1) * (max_y.saturating_sub(min_y) + 1); |
| 573 | + let image_area = width * height; |
| 574 | + if area < image_area / 40 { |
| 575 | + return Ok(build_fallback_bbox(width, height)); |
| 576 | + } |
| 577 | + |
| 578 | + let pad_x = ((max_x.saturating_sub(min_x) + 1) as f32 * 0.08).round() as u32; |
| 579 | + let pad_y = ((max_y.saturating_sub(min_y) + 1) as f32 * 0.08).round() as u32; |
| 580 | + |
| 581 | + Ok(( |
| 582 | + min_x.saturating_sub(pad_x), |
| 583 | + min_y.saturating_sub(pad_y), |
| 584 | + (max_x + pad_x).min(width.saturating_sub(1)), |
| 585 | + (max_y + pad_y).min(height.saturating_sub(1)), |
| 586 | + )) |
| 587 | +} |
| 588 | + |
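| | +/// Builds a single draft annotation from the detected region: a four-point
| | +/// polygon for segmentation models, otherwise a two-point box. The first
| | +/// project label is used when available; otherwise a category-based name and
| | +/// a default color are filled in.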
| 589 | +fn build_draft_annotations( |
| 590 | + image_value: &Value, |
| 591 | + model_value: &Value, |
| 592 | + labels: &[Value], |
| 593 | + threshold_bias: f32, |
| 594 | +) -> Result<Vec<InferenceAnnotationDraft>, AppError> { |
| 595 | + let image_data = value_string(image_value, "data", "data") |
| 596 | + .ok_or_else(|| AppError::Message("Image data is unavailable for AI annotation".into()))?; |
| 597 | + let image_bytes = decode_file_bytes(&image_data)?; |
| 598 | + let (image_width, image_height) = image::load_from_memory(&image_bytes) |
| 599 | + .map(|image| image.dimensions()) |
| 600 | + .unwrap_or(( |
| 601 | + value_u32(image_value, "width").unwrap_or(1), |
| 602 | + value_u32(image_value, "height").unwrap_or(1), |
| 603 | + )); |
| 604 | + let (left, top, right, bottom) = detect_salient_region(&image_bytes, threshold_bias)?; |
| 605 | + |
| 606 | + let category = value_string(model_value, "category", "category").unwrap_or_else(|| "detection".into()); |
| 607 | + let model_name = value_string(model_value, "name", "name").unwrap_or_else(|| "AI Model".into()); |
| 608 | + let label = labels.first(); |
| 609 | + let label_id = label.and_then(|entry| value_string(entry, "id", "id")); |
| 610 | + let label_name = label |
| 611 | + .and_then(|entry| value_string(entry, "name", "name")) |
| 612 | + .or_else(|| Some(match category.as_str() { |
| 613 | + "segmentation" => "AI Region".into(), |
| 614 | + "pose" => "AI Pose Subject".into(), |
| 615 | + "classification" => "AI Classification".into(), |
| 616 | + _ => "AI Detection".into(), |
| 617 | + })); |
| 618 | + let label_color = label |
| 619 | + .and_then(|entry| value_string(entry, "color", "color")) |
| 620 | + .or_else(|| Some("#22c55e".into())); |
| 621 | + |
| 622 | + let annotation_type = if category == "segmentation" { |
| 623 | + "polygon" |
| 624 | + } else { |
| 625 | + "box" |
| 626 | + }; |
| 627 | + |
| 628 | + let coordinates = if annotation_type == "polygon" { |
| 629 | + vec![ |
| 630 | + InferencePoint { x: left as f32, y: top as f32 }, |
| 631 | + InferencePoint { x: right as f32, y: top as f32 }, |
| 632 | + InferencePoint { x: right as f32, y: bottom as f32 }, |
| 633 | + InferencePoint { x: left as f32, y: bottom as f32 }, |
| 634 | + ] |
| 635 | + } else { |
| 636 | + vec![ |
| 637 | + InferencePoint { x: left as f32, y: top as f32 }, |
| 638 | + InferencePoint { x: right as f32, y: bottom as f32 }, |
| 639 | + ] |
| 640 | + }; |
| 641 | + |
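| | +    // The confidence is a stand-in score: the box-to-image area ratio,
| | +    // clamped to the 0.35..0.94 range.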
| 642 | + let bbox_area = ((right.saturating_sub(left) + 1) * (bottom.saturating_sub(top) + 1)) as f32; |
| 643 | + let image_area = (image_width.max(1) * image_height.max(1)) as f32; |
| 644 | + let confidence = (bbox_area / image_area).clamp(0.35, 0.94); |
| 645 | + |
| 646 | + Ok(vec![InferenceAnnotationDraft { |
| 647 | + name: label_name |
| 648 | + .clone() |
| 649 | + .unwrap_or_else(|| format!("{model_name} Draft")), |
| 650 | + annotation_type: annotation_type.into(), |
| 651 | + coordinates, |
| 652 | + confidence, |
| 653 | + label_id, |
| 654 | + label_name, |
| 655 | + label_color, |
| 656 | + is_ai_generated: true, |
| 657 | + }]) |
| 658 | +} |
| 659 | + |
480 | 660 | #[tauri::command] |
481 | 661 | fn fs_ensure_directory(payload: PathPayload) -> Result<(), AppError> { |
482 | 662 | fs::create_dir_all(payload.path)?; |
@@ -620,7 +800,7 @@ fn download_system_model( |
620 | 800 | .map(|value| format!(" ({value})")) |
621 | 801 | .unwrap_or_default(); |
622 | 802 | let model_id = format!( |
623 | | - "{}-{}", |
| 803 | + "{}:{}", |
624 | 804 | payload.system_id, |
625 | 805 | payload.variant_name.clone().unwrap_or_else(|| "default".into()) |
626 | 806 | ); |
@@ -656,6 +836,37 @@ fn download_system_model( |
656 | 836 | Ok(store.upsert_entity("ai_models", model)?) |
657 | 837 | } |
658 | 838 |
|
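| | +/// Tauri command: produces draft annotations for an image. The selected model
| | +/// file must exist on disk, but inference itself is approximated by the
| | +/// brightness heuristic above rather than by running the model.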
| 839 | +#[tauri::command] |
| 840 | +fn run_model_inference( |
| 841 | + state: tauri::State<AppState>, |
| 842 | + payload: RunModelInferencePayload, |
| 843 | +) -> Result<Vec<InferenceAnnotationDraft>, AppError> { |
| 844 | + let store = state_guard(&state)?; |
| 845 | + let image = store |
| 846 | + .get_entity("images", &payload.image_id)? |
| 847 | + .ok_or_else(|| AppError::Message("Image not found".into()))?; |
| 848 | + let model = store |
| 849 | + .get_entity("ai_models", &payload.model_id)? |
| 850 | + .ok_or_else(|| AppError::Message("Selected AI model was not found".into()))?; |
| 851 | + |
| 852 | + let model_path = value_string(&model, "modelPath", "model_path").unwrap_or_default(); |
| 853 | + if model_path.is_empty() { |
| 854 | + return Err(AppError::Message("Selected AI model does not have a local file path".into())); |
| 855 | + } |
| 856 | + if !Path::new(&model_path).exists() { |
| 857 | + return Err(AppError::Message("Selected AI model file could not be found on disk".into())); |
| 858 | + } |
| 859 | + |
| 860 | + let project_id = value_string(&image, "projectId", "project_id").unwrap_or_default(); |
| 861 | + let labels = if project_id.is_empty() { |
| 862 | + Vec::new() |
| 863 | + } else { |
| 864 | + store.list_by_field("labels", "project_id", &project_id)? |
| 865 | + }; |
| 866 | + |
| 867 | + build_draft_annotations(&image, &model, &labels, payload.threshold.unwrap_or(0.5)) |
| 868 | +} |
| 869 | + |
659 | 870 | pub fn run() { |
660 | 871 | tauri::Builder::default() |
661 | 872 | .setup(|app| { |
@@ -684,6 +895,7 @@ pub fn run() { |
684 | 895 | secret_list, |
685 | 896 | updater_status, |
686 | 897 | download_system_model, |
| 898 | + run_model_inference, |
687 | 899 | ]) |
688 | 900 | .run(tauri::generate_context!()) |
689 | 901 | .expect("error while running tauri application"); |
|