diff --git a/models/utils/detect_face.py b/models/utils/detect_face.py index 5d144864..67b8a469 100644 --- a/models/utils/detect_face.py +++ b/models/utils/detect_face.py @@ -307,22 +307,72 @@ def imresample(img, sz): def crop_resize(img, box, image_size): + """ + box: (x1, y1, x2, y2) in pixel coords, x2/y2 exclusive-style is fine too (we resize anyway). + img: numpy HWC, torch HWC or CHW, or PIL Image + """ + + x1, y1, x2, y2 = map(int, box) + w = max(1, x2 - x1) + h = max(1, y2 - y1) + + s = max(w, h) + cx = x1 + w / 2.0 + cy = y1 + h / 2.0 + + # square window [x0, x0+s), [y0, y0+s) + x0 = int(round(cx - s / 2.0)) + y0 = int(round(cy - s / 2.0)) + if isinstance(img, np.ndarray): - img = img[box[1]:box[3], box[0]:box[2]] - out = cv2.resize( - img, - (image_size, image_size), - interpolation=cv2.INTER_AREA - ).copy() + H, W = img.shape[:2] elif isinstance(img, torch.Tensor): - img = img[box[1]:box[3], box[0]:box[2]] - out = imresample( - img.permute(2, 0, 1).unsqueeze(0).float(), - (image_size, image_size) - ).byte().squeeze(0).permute(1, 2, 0) + # accept HWC or CHW + if img.ndim != 3: + raise ValueError("torch img must be 3D (HWC or CHW)") + if img.shape[0] in (1, 3, 4) and img.shape[2] not in (1, 3, 4): + # CHW + C, H, W = img.shape + chw = True + else: + # HWC + H, W, C = img.shape + chw = False else: - out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR) - return out + # PIL + W, H = img.size + + # shift window to stay inside image (keeps square) + x0 = min(max(0, x0), max(0, W - s)) + y0 = min(max(0, y0), max(0, H - s)) + x1n, y1n = x0 + s, y0 + s + + if isinstance(img, np.ndarray): + crop = img[y0:y1n, x0:x1n] + return cv2.resize(crop, (image_size, image_size), interpolation=cv2.INTER_AREA).copy() + + if isinstance(img, torch.Tensor): + if chw: + crop = img[:, y0:y1n, x0:x1n] + else: + crop = img[y0:y1n, x0:x1n, :] + + # simplest: use torch.nn.functional.interpolate on float + import torch.nn.functional as F + if chw: + crop_f = crop.unsqueeze(0).float() + else: + crop_f = crop.permute(2, 0, 1).unsqueeze(0).float() + + out = F.interpolate(crop_f, size=(image_size, image_size), mode="area") + out = out.squeeze(0) + if not chw: + out = out.permute(1, 2, 0) + return out.byte() + + # PIL + crop = img.crop((x0, y0, x1n, y1n)) + return crop.resize((image_size, image_size), Image.BILINEAR) def save_img(img, path):