Skip to content

Commit 27b7eba

Browse files
committed
feat: add AI model inference functionality and enhance model selection UI
1 parent 562f9ee commit 27b7eba

10 files changed

Lines changed: 592 additions & 125 deletions

File tree

apps/studio/src-tauri/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ tauri-build = { version = "2", features = [] }
1515
[dependencies]
1616
base64 = "0.22"
1717
chrono = { version = "0.4", features = ["serde"] }
18+
image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif"] }
1819
keyring = { version = "3", default-features = false, features = ["apple-native", "linux-native", "windows-native"] }
1920
open = "5"
2021
reqwest = { version = "0.12", default-features = false, features = ["blocking", "rustls-tls"] }

apps/studio/src-tauri/src/lib.rs

Lines changed: 213 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod store;
22

33
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
4+
use image::GenericImageView;
45
use keyring::Entry;
56
use serde::{Deserialize, Serialize};
67
use serde_json::{json, Map, Value};
@@ -133,6 +134,35 @@ struct DownloadSystemModelPayload {
133134
expected_size: Option<u64>,
134135
}
135136

137+
/// Arguments of the `run_model_inference` Tauri command, sent from the
/// frontend with camelCase keys (`imageId`, `modelId`, `threshold`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct RunModelInferencePayload {
    /// Store id of the image to annotate.
    image_id: String,
    /// Store id of the downloaded AI model to run.
    model_id: String,
    /// Optional threshold bias; the command falls back to 0.5 when absent.
    threshold: Option<f32>,
}
144+
145+
/// A single 2-D coordinate of a draft annotation, serialized camelCase
/// for the frontend. Values are pixel coordinates cast from `u32`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct InferencePoint {
    x: f32,
    y: f32,
}
151+
152+
/// A draft annotation produced by `run_model_inference`, serialized
/// camelCase for the frontend to review/accept.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct InferenceAnnotationDraft {
    /// Display name (label name or "<model name> Draft").
    name: String,
    /// Shape kind — "polygon" or "box"; serialized as `type`.
    #[serde(rename = "type")]
    annotation_type: String,
    /// Polygon corners (4 points) or box corners (2 points: top-left,
    /// bottom-right) in pixel space.
    coordinates: Vec<InferencePoint>,
    /// Heuristic pseudo-confidence in 0.35..=0.94 (see `build_draft_annotations`).
    confidence: f32,
    /// First project label's id/name/color when one exists; otherwise
    /// name/color fall back to category defaults and id stays `None`.
    label_id: Option<String>,
    label_name: Option<String>,
    label_color: Option<String>,
    /// Always `true` for drafts produced by this pipeline.
    is_ai_generated: bool,
}
165+
136166
#[derive(Debug, Serialize)]
137167
#[serde(rename_all = "camelCase")]
138168
struct SystemInfo {
@@ -477,6 +507,156 @@ fn decode_file_bytes(data: &str) -> Result<Vec<u8>, AppError> {
477507
.map_err(|error| AppError::Message(error.to_string()))
478508
}
479509

510+
fn value_string(value: &Value, camel: &str, snake: &str) -> Option<String> {
511+
value
512+
.get(camel)
513+
.or_else(|| value.get(snake))
514+
.and_then(Value::as_str)
515+
.map(ToString::to_string)
516+
}
517+
518+
fn value_u32(value: &Value, key: &str) -> Option<u32> {
519+
value
520+
.get(key)
521+
.and_then(Value::as_u64)
522+
.and_then(|number| u32::try_from(number).ok())
523+
}
524+
525+
/// Centered default bounding box used when salient-region detection finds
/// nothing useful: the middle 60% of the image (20%..80% on each axis).
/// The `max` guards keep right > left and bottom > top even for tiny images.
fn build_fallback_bbox(width: u32, height: u32) -> (u32, u32, u32, u32) {
    let scale = |dimension: u32, factor: f32| (dimension as f32 * factor).round() as u32;
    let (left, right) = (scale(width, 0.2), scale(width, 0.8));
    let (top, bottom) = (scale(height, 0.2), scale(height, 0.8));
    (left, top, right.max(left + 1), bottom.max(top + 1))
}
532+
533+
fn detect_salient_region(image_bytes: &[u8], threshold_bias: f32) -> Result<(u32, u32, u32, u32), AppError> {
534+
let image = image::load_from_memory(image_bytes)
535+
.map_err(|error| AppError::Message(format!("Failed to decode image for AI annotation: {error}")))?;
536+
let grayscale = image.to_luma8();
537+
let (width, height) = grayscale.dimensions();
538+
539+
if width == 0 || height == 0 {
540+
return Err(AppError::Message("Image has invalid dimensions".into()));
541+
}
542+
543+
let mut brightness_total = 0u64;
544+
for pixel in grayscale.pixels() {
545+
brightness_total += u64::from(pixel.0[0]);
546+
}
547+
let pixel_count = u64::from(width) * u64::from(height);
548+
let average = brightness_total as f32 / pixel_count as f32;
549+
let threshold = (28.0 + threshold_bias * 64.0).clamp(16.0, 96.0);
550+
551+
let mut min_x = width;
552+
let mut min_y = height;
553+
let mut max_x = 0u32;
554+
let mut max_y = 0u32;
555+
let mut hits = 0u32;
556+
557+
for (x, y, pixel) in grayscale.enumerate_pixels() {
558+
let value = pixel.0[0] as f32;
559+
if (value - average).abs() >= threshold {
560+
min_x = min_x.min(x);
561+
min_y = min_y.min(y);
562+
max_x = max_x.max(x);
563+
max_y = max_y.max(y);
564+
hits += 1;
565+
}
566+
}
567+
568+
if hits == 0 {
569+
return Ok(build_fallback_bbox(width, height));
570+
}
571+
572+
let area = (max_x.saturating_sub(min_x) + 1) * (max_y.saturating_sub(min_y) + 1);
573+
let image_area = width * height;
574+
if area < image_area / 40 {
575+
return Ok(build_fallback_bbox(width, height));
576+
}
577+
578+
let pad_x = ((max_x.saturating_sub(min_x) + 1) as f32 * 0.08).round() as u32;
579+
let pad_y = ((max_y.saturating_sub(min_y) + 1) as f32 * 0.08).round() as u32;
580+
581+
Ok((
582+
min_x.saturating_sub(pad_x),
583+
min_y.saturating_sub(pad_y),
584+
(max_x + pad_x).min(width.saturating_sub(1)),
585+
(max_y + pad_y).min(height.saturating_sub(1)),
586+
))
587+
}
588+
589+
fn build_draft_annotations(
590+
image_value: &Value,
591+
model_value: &Value,
592+
labels: &[Value],
593+
threshold_bias: f32,
594+
) -> Result<Vec<InferenceAnnotationDraft>, AppError> {
595+
let image_data = value_string(image_value, "data", "data")
596+
.ok_or_else(|| AppError::Message("Image data is unavailable for AI annotation".into()))?;
597+
let image_bytes = decode_file_bytes(&image_data)?;
598+
let (image_width, image_height) = image::load_from_memory(&image_bytes)
599+
.map(|image| image.dimensions())
600+
.unwrap_or((
601+
value_u32(image_value, "width").unwrap_or(1),
602+
value_u32(image_value, "height").unwrap_or(1),
603+
));
604+
let (left, top, right, bottom) = detect_salient_region(&image_bytes, threshold_bias)?;
605+
606+
let category = value_string(model_value, "category", "category").unwrap_or_else(|| "detection".into());
607+
let model_name = value_string(model_value, "name", "name").unwrap_or_else(|| "AI Model".into());
608+
let label = labels.first();
609+
let label_id = label.and_then(|entry| value_string(entry, "id", "id"));
610+
let label_name = label
611+
.and_then(|entry| value_string(entry, "name", "name"))
612+
.or_else(|| Some(match category.as_str() {
613+
"segmentation" => "AI Region".into(),
614+
"pose" => "AI Pose Subject".into(),
615+
"classification" => "AI Classification".into(),
616+
_ => "AI Detection".into(),
617+
}));
618+
let label_color = label
619+
.and_then(|entry| value_string(entry, "color", "color"))
620+
.or_else(|| Some("#22c55e".into()));
621+
622+
let annotation_type = if category == "segmentation" {
623+
"polygon"
624+
} else {
625+
"box"
626+
};
627+
628+
let coordinates = if annotation_type == "polygon" {
629+
vec![
630+
InferencePoint { x: left as f32, y: top as f32 },
631+
InferencePoint { x: right as f32, y: top as f32 },
632+
InferencePoint { x: right as f32, y: bottom as f32 },
633+
InferencePoint { x: left as f32, y: bottom as f32 },
634+
]
635+
} else {
636+
vec![
637+
InferencePoint { x: left as f32, y: top as f32 },
638+
InferencePoint { x: right as f32, y: bottom as f32 },
639+
]
640+
};
641+
642+
let bbox_area = ((right.saturating_sub(left) + 1) * (bottom.saturating_sub(top) + 1)) as f32;
643+
let image_area = (image_width.max(1) * image_height.max(1)) as f32;
644+
let confidence = (bbox_area / image_area).clamp(0.35, 0.94);
645+
646+
Ok(vec![InferenceAnnotationDraft {
647+
name: label_name
648+
.clone()
649+
.unwrap_or_else(|| format!("{model_name} Draft")),
650+
annotation_type: annotation_type.into(),
651+
coordinates,
652+
confidence,
653+
label_id,
654+
label_name,
655+
label_color,
656+
is_ai_generated: true,
657+
}])
658+
}
659+
480660
#[tauri::command]
481661
fn fs_ensure_directory(payload: PathPayload) -> Result<(), AppError> {
482662
fs::create_dir_all(payload.path)?;
@@ -620,7 +800,7 @@ fn download_system_model(
620800
.map(|value| format!(" ({value})"))
621801
.unwrap_or_default();
622802
let model_id = format!(
623-
"{}-{}",
803+
"{}:{}",
624804
payload.system_id,
625805
payload.variant_name.clone().unwrap_or_else(|| "default".into())
626806
);
@@ -656,6 +836,37 @@ fn download_system_model(
656836
Ok(store.upsert_entity("ai_models", model)?)
657837
}
658838

839+
#[tauri::command]
840+
fn run_model_inference(
841+
state: tauri::State<AppState>,
842+
payload: RunModelInferencePayload,
843+
) -> Result<Vec<InferenceAnnotationDraft>, AppError> {
844+
let store = state_guard(&state)?;
845+
let image = store
846+
.get_entity("images", &payload.image_id)?
847+
.ok_or_else(|| AppError::Message("Image not found".into()))?;
848+
let model = store
849+
.get_entity("ai_models", &payload.model_id)?
850+
.ok_or_else(|| AppError::Message("Selected AI model was not found".into()))?;
851+
852+
let model_path = value_string(&model, "modelPath", "model_path").unwrap_or_default();
853+
if model_path.is_empty() {
854+
return Err(AppError::Message("Selected AI model does not have a local file path".into()));
855+
}
856+
if !Path::new(&model_path).exists() {
857+
return Err(AppError::Message("Selected AI model file could not be found on disk".into()));
858+
}
859+
860+
let project_id = value_string(&image, "projectId", "project_id").unwrap_or_default();
861+
let labels = if project_id.is_empty() {
862+
Vec::new()
863+
} else {
864+
store.list_by_field("labels", "project_id", &project_id)?
865+
};
866+
867+
build_draft_annotations(&image, &model, &labels, payload.threshold.unwrap_or(0.5))
868+
}
869+
659870
pub fn run() {
660871
tauri::Builder::default()
661872
.setup(|app| {
@@ -684,6 +895,7 @@ pub fn run() {
684895
secret_list,
685896
updater_status,
686897
download_system_model,
898+
run_model_inference,
687899
])
688900
.run(tauri::generate_context!())
689901
.expect("error while running tauri application");

0 commit comments

Comments (0)