-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinfer.py
More file actions
46 lines (39 loc) · 1.72 KB
/
infer.py
File metadata and controls
46 lines (39 loc) · 1.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from os import PathLike
from typing import override
import cv2
import numpy as np
import torch
from mipcandy import HasDevice, Device, convert_logits_to_ids, auto_device
from torchvision.transforms import Resize
from sort_screws import Camera, ResNetPredictor, cv2pt
class Predictor(Camera, HasDevice):
    """
    Camera application that classifies the region of interest of each frame and overlays the result.

    While not paused, every frame gets its bounding box outlined and the predicted class ID written onto it
    before being shown in the "Camera Preview" window. The space bar toggles the paused state and "q" makes
    `job` return True.
    """

    def __init__(self, experiment_folder: str | PathLike[str], *, device: Device = "cpu") -> None:
        """
        :param experiment_folder: folder handed to `ResNetPredictor` (converted to `str` first)
        :param device: device passed to both `HasDevice` and the underlying `ResNetPredictor`
        """
        # The two bases take different arguments, so each one is initialised explicitly.
        # NOTE(review): 512 is presumably the ROI size expected by `Camera` - confirm against its signature.
        Camera.__init__(self, 512)
        HasDevice.__init__(self, device)
        # NOTE(review): the predictor is given (3, 512, 512) while inputs go through `Resize(224)` below -
        # confirm that the two shapes are meant to differ.
        self.predictor: ResNetPredictor = ResNetPredictor(
            str(experiment_folder), (3, 512, 512), device=device
        )
        # Toggled by the space bar in `job`; starts unpaused.
        self.paused: bool = False
        # With an int size, torchvision's `Resize` scales the smaller edge to that many pixels.
        self.resize: Resize = Resize(224)

    @override
    def job(self, frame: np.ndarray, roi: np.ndarray, bbox: tuple[int, int, int, int]) -> bool:
        """
        Classify `roi`, annotate `frame` in place, show it, and react to the pressed key.

        :param frame: image that is drawn on and displayed
        :param roi: image region fed to the predictor
        :param bbox: `(x1, y1, x2, y2)` corners of the rectangle outlined on `frame`
        :return: True if "q" was pressed, False otherwise
        """
        green = (0, 255, 0)
        if not self.paused:
            x_start, y_start, x_end, y_end = bbox
            # NOTE(review): the outline is drawn before the ROI is converted; if `roi` is a view into
            # `frame`, the outline may end up in the model input - confirm against `Camera`.
            cv2.rectangle(frame, (x_start, y_start), (x_end, y_end), green, 2)
            # `cv2pt` presumably turns the OpenCV array into a tensor on the given device - TODO confirm.
            model_input = self.resize(cv2pt(roi, device=self._device))
            logits = self.predictor.predict_image(model_input)
            predicted = convert_logits_to_ids(logits, channel_dim=0).item()
            label = f"Class: {predicted}"
            cv2.putText(frame, label, (40, 80), cv2.FONT_HERSHEY_COMPLEX, 2, green, 2, cv2.LINE_AA)
            cv2.imshow("Camera Preview", frame)
        # Keys are polled even while paused, otherwise the pause could never be lifted.
        pressed = self.wait_key()
        if pressed == ord(" "):
            self.paused = not self.paused
        return bool(pressed == ord("q"))
if __name__ == "__main__":
    # Let mipcandy pick the device and echo the choice to the console.
    device = auto_device()
    print(device)
    # The experiment folder is named after the trainer class matching `ResNetPredictor`
    # (its class name with "Predictor" swapped for "Trainer").
    trainer_name = ResNetPredictor.__name__.replace("Predictor", "Trainer")
    experiment_folder = f"trainer/{trainer_name}/final"
    app = Predictor(experiment_folder, device=device)
    app.run()