CenterCrop
https://pytorch.org/vision/stable/generated/torchvision.transforms.CenterCrop.html
Heard about this term at NVIDIA.
Implementing CenterCrop is actually kind of annoying: the offset math for centering the crop window is easy to get off by one.
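A quick sketch of that math (nothing torchvision-specific, just numpy): the crop window is the largest square that fits, and the offsets are the leftover border split evenly between the two sides.

import numpy as np

# Offset math for a centered square crop: take the largest square that
# fits, then split the leftover border evenly between the two sides.
h, w = 480, 640                          # example HWC image dimensions
m = min(h, w)                            # side of the largest centered square
top, left = (h - m) // 2, (w - m) // 2   # -> top = 0, left = 80

im = np.zeros((h, w, 3), dtype=np.uint8)
square = im[top : top + m, left : left + m]
assert square.shape[:2] == (m, m)        # 480 x 480, centered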
The CenterCrop implementation in Ultralytics' augment.py (kept for backward compatibility, NOT actually used):
import cv2


# NOTE: keep this class for backward compatibility
class CenterCrop:
    """YOLOv8 CenterCrop class for image preprocessing, designed to be part of a transformation pipeline, e.g.,
    T.Compose([CenterCrop(size), ToTensor()]).
    """

    def __init__(self, size=640):
        """Initializes the target crop size as (h, w); an int gives a square crop."""
        super().__init__()
        self.h, self.w = (size, size) if isinstance(size, int) else size

    def __call__(self, im):
        """
        Resizes and crops the center of the image using a letterbox method.

        Args:
            im (numpy.ndarray): The input image as a numpy array of shape HWC.

        Returns:
            (numpy.ndarray): The center-cropped and resized image as a numpy array.
        """
        imh, imw = im.shape[:2]
        m = min(imh, imw)  # min dimension
        top, left = (imh - m) // 2, (imw - m) // 2  # offsets of the centered square
        return cv2.resize(im[top : top + m, left : left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
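Usage sketch for the class above (the shapes here are just illustrative): crop the centered square out of a 480x640 image, then resize it to 224x224.

import numpy as np

# Feed an HWC numpy image through the class above; the output is the
# centered square, resized to the requested size.
im = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
out = CenterCrop(224)(im)
assert out.shape == (224, 224, 3)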
In practice, Ultralytics uses torchvision's CenterCrop instead:
tfl += [T.CenterCrop(size)]
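A minimal sketch of how that composes in a classification-style pipeline (the Resize(256)/CenterCrop(224) sizes here are illustrative, not Ultralytics' actual defaults). Note that T.CenterCrop does not resize first, so a Resize usually precedes it:

import torchvision.transforms as T
from PIL import Image

# Resize the short side first, then center-crop; T.CenterCrop alone would
# zero-pad any side smaller than the requested crop.
transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
img = Image.new("RGB", (640, 480))       # PIL size is (W, H)
x = transform(img)
assert tuple(x.shape) == (3, 224, 224)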
Source code: https://pytorch.org/vision/main/_modules/torchvision/transforms/transforms.html#CenterCrop
Its forward() calls the functional center_crop:
→ https://pytorch.org/vision/main/_modules/torchvision/transforms/functional.html#center_crop
def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
    """Crops the given image at the center.

    If the image is torch Tensor, it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.

    Args:
        img (PIL Image or Tensor): Image to be cropped.
        output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
            it is used for both directions.

    Returns:
        PIL Image or Tensor: Cropped image.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(center_crop)

    if isinstance(output_size, numbers.Number):
        output_size = (int(output_size), int(output_size))
    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
        output_size = (output_size[0], output_size[0])

    _, image_height, image_width = get_dimensions(img)
    crop_height, crop_width = output_size

    if crop_width > image_width or crop_height > image_height:
        padding_ltrb = [
            (crop_width - image_width) // 2 if crop_width > image_width else 0,
            (crop_height - image_height) // 2 if crop_height > image_height else 0,
            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
        ]
        img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
        _, image_height, image_width = get_dimensions(img)
        if crop_width == image_width and crop_height == image_height:
            return img

    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop(img, crop_top, crop_left, crop_height, crop_width)
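The pad-then-crop branch is the easy thing to miss: asking for a crop larger than the image zero-pads both sides first, so the output is always exactly the requested size. A small check of that behavior (sizes here are arbitrary):

import torch
import torchvision.transforms.functional as F

# A 100x100 image cropped to 120x120 gets 10 zero rows/columns of
# padding on each side before the crop, so corners land in the padding.
img = torch.ones(3, 100, 100)
out = F.center_crop(img, [120, 120])
assert tuple(out.shape) == (3, 120, 120)
assert out[0, 0, 0].item() == 0          # corner pixel comes from the padding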