# Source code for robotblockset.cameras.image_transform

"""Image transform abstractions and helpers."""

from abc import ABC
from typing import Tuple, Union, List
import cv2
import numpy as np

from robotblockset.rbs_typing import NumpyFloatImageType, NumpyIntImageType, OpenCVIntImageType

# Accepted image array types; all are indexed as (row, column[, channel]).
HWCImageType = Union[OpenCVIntImageType, NumpyFloatImageType, NumpyIntImageType]
"""an image with shape (H,W,C)"""

# (height, width, channels), or (height, width) for single-channel images.
ImageShapeType = Union[Tuple[int, int, int], Tuple[int, int]]
# An (x, y) pixel coordinate; may be sub-pixel (float) after a resize.
ImagePointType = Union[Tuple[int, int], Tuple[float, float]]


class ImageTransform(ABC):
    """Base interface for invertible image transforms.

    A transform maps images of a fixed input shape to new images, and can map
    point coordinates from the original image into the transformed image and
    back again.
    """

    def __init__(self, input_shape: ImageShapeType):
        self._input_shape = input_shape

    @property
    def _input_h(self) -> int:
        # Rows come first in an (H, W[, C]) shape.
        return self._input_shape[0]

    @property
    def _input_w(self) -> int:
        return self._input_shape[1]

    @property
    def shape(self) -> ImageShapeType:
        """The shape of the transformed image.

        Returns
        -------
        ImageShapeType:
            The shape of the transformed image.
        """
        raise NotImplementedError

    def transform_image(self, image: HWCImageType) -> HWCImageType:
        """Apply the image transform to an image to get a new image.

        Parameters
        ----------
        image : HWCImageType
            The original image, it will be unaffected by the transform.

        Raises
        ------
        NotImplementedError:
            Subclasses must implement this method.

        Returns
        -------
        HWCImageType:
            The new, transformed image.
        """
        raise NotImplementedError

    def transform_point(self, point: ImagePointType) -> ImagePointType:
        """Map a point into transformed-image coordinates.

        Transform the coordinates of a point from original image to
        transformed image.
        """
        raise NotImplementedError

    def reverse_transform_point(self, point: ImagePointType) -> ImagePointType:
        """Map a transformed-image point back to the source.

        Transform the coordinates of a point in the transformed image back to
        the original image.
        """
        raise NotImplementedError

    def __call__(self, image: HWCImageType) -> HWCImageType:
        """Shorthand to transform an image."""
        return self.transform_image(image)
class ComposedTransform(ImageTransform):
    """Chain several :class:`ImageTransform` instances into a single transform.

    Images and points are passed through ``transforms`` in order;
    ``reverse_transform_point`` walks the chain in reverse.
    """

    def __init__(self, transforms: List[ImageTransform]):
        """Create a composed transform.

        Parameters
        ----------
        transforms : List[ImageTransform]
            The transforms to apply, in order. Must be non-empty.

        Raises
        ------
        ValueError:
            If ``transforms`` is empty.
        """
        if len(transforms) == 0:
            raise ValueError("transforms must be a non-empty list.")
        # The composition accepts whatever the first transform accepts.
        super().__init__(transforms[0]._input_shape)
        self.transforms = transforms

    @property
    def shape(self) -> ImageShapeType:
        # The output shape is determined by the last transform in the chain.
        return self.transforms[-1].shape

    def transform_image(self, image: HWCImageType) -> HWCImageType:
        """Apply every transform in order and return the final image."""
        for transform in self.transforms:
            image = transform.transform_image(image)
        return image

    def transform_point(self, point: ImagePointType) -> ImagePointType:
        """Map a point through every transform in order."""
        # Fix: removed a stray debug `print(point)` that spammed stdout on
        # every call.
        for transform in self.transforms:
            point = transform.transform_point(point)
        return point

    def reverse_transform_point(self, point: ImagePointType) -> ImagePointType:
        """Map a transformed-image point back through the chain in reverse."""
        for transform in reversed(self.transforms):
            point = transform.reverse_transform_point(point)
        return point
def crop(image: "HWCImageType", x: int, y: int, w: int, h: int) -> "HWCImageType":
    """
    Crop a rectangular region from an image.

    Parameters
    ----------
    image : HWCImageType
        Image to crop, shaped (H, W) or (H, W, C); it is left untouched.
    x : int
        X-coordinate of the top-left crop corner.
    y : int
        Y-coordinate of the top-left crop corner.
    w : int
        Crop width in pixels.
    h : int
        Crop height in pixels.

    Returns
    -------
    HWCImageType:
        A copy of the cropped region (not a view into ``image``).
    """
    # Note that the first index of the array is the y-coordinate, because this
    # indexes the rows of the image and the y-axis runs from top to bottom.
    # Basic slicing leaves any trailing channel axis intact, so one expression
    # handles both 2D and 3D images (the previous 2D/3D branch was redundant).
    return image[y : y + h, x : x + w].copy()
class Crop(ImageTransform):
    """Crop a fixed rectangular window out of every input image.

    The window is given by its top-left corner ``(x, y)`` and its size
    ``(w, h)`` in input-image pixel coordinates. (Fix: the class previously
    had an empty docstring.)
    """

    def __init__(self, input_shape: ImageShapeType, x: int, y: int, w: int, h: int):
        """Create a new Crop transform.

        Parameters
        ----------
        input_shape : ImageShapeType
            Shape of the images that will be cropped.
        x : int
            X-coordinate of the top-left crop corner.
        y : int
            Y-coordinate of the top-left crop corner.
        w : int
            Crop width in pixels.
        h : int
            Crop height in pixels.
        """
        super().__init__(input_shape)
        self.x = x
        self.y = y
        self.w = w
        self.h = h

    @property
    def shape(self) -> ImageShapeType:
        # Single-channel inputs have no channel axis; preserve that.
        if len(self._input_shape) == 2:
            return self.h, self.w
        c = self._input_shape[2]
        return self.h, self.w, c

    def transform_image(self, image: HWCImageType) -> HWCImageType:
        """Return a copy of the crop window taken from ``image``."""
        return crop(image, self.x, self.y, self.w, self.h)

    def transform_point(self, point: ImagePointType) -> ImagePointType:
        """Map an input-image point into cropped-image coordinates.

        Raises
        ------
        ValueError:
            If the point lies outside the crop window.
        """
        x, y = point
        if not (x >= self.x and x < self.x + self.w):
            raise ValueError(f"x-coordinate {x} is outside of the crop range [{self.x}, {self.x + self.w})")
        if not (y >= self.y and y < self.y + self.h):
            raise ValueError(f"y-coordinate {y} is outside of the crop range [{self.y}, {self.y + self.h})")
        return x - self.x, y - self.y

    def reverse_transform_point(self, point: ImagePointType) -> ImagePointType:
        """Map a cropped-image point back to input-image coordinates.

        Raises
        ------
        ValueError:
            If the point lies outside the cropped image.
        """
        x, y = point
        if not (x >= 0 and x < self.w):
            raise ValueError(f"x-coordinate {x} is outside of the crop range [0, {self.w})")
        if not (y >= 0 and y < self.h):
            raise ValueError(f"y-coordinate {y} is outside of the crop range [0, {self.h})")
        return x + self.x, y + self.y
class Resize(ImageTransform):
    """Resize images to a fixed target height and width."""

    def __init__(self, input_shape: ImageShapeType, h: int, w: int, round_transformed_points: bool = True):
        """Create a new Resize transform.

        Note: mapping a point to or from a resized image generally yields
        non-integer coordinates, while many consumers (e.g. the OpenCV draw
        functions) expect integer pixel coordinates. By default transformed
        points are therefore rounded to the nearest integer; pass
        ``round_transformed_points=False`` to get the exact float coordinates
        without rounding error.

        Parameters
        ----------
        input_shape : ImageShapeType
            Shape of the images that will be resized.
        h : int
            Height of the resized image.
        w : int
            Width of the resized image.
        round_transformed_points : bool, optional
            Whether to round transformed points to the nearest integer.
        """
        super().__init__(input_shape)
        self.h = h
        self.w = w
        self.round_transformed_points = round_transformed_points

    @property
    def shape(self) -> ImageShapeType:
        # Keep the channel axis only when the input has one.
        if len(self._input_shape) == 2:
            return self.h, self.w
        return self.h, self.w, self._input_shape[2]

    def transform_image(self, image: HWCImageType) -> HWCImageType:
        """Return ``image`` resampled to the target size."""
        # cv2.resize expects the target size as (width, height).
        return cv2.resize(image, (self.w, self.h))

    def _maybe_round(self, x: float, y: float) -> ImagePointType:
        # Helper shared by both point mappings: round unless exact floats
        # were requested at construction time.
        if self.round_transformed_points:
            return round(x), round(y)
        return x, y

    def transform_point(self, point: ImagePointType) -> ImagePointType:
        """Map an input-image point into resized-image coordinates."""
        x, y = point
        if not 0 <= x < self._input_w:
            raise ValueError(f"x-coordinate {x} is outside of the input image range [0, {self._input_w})")
        if not 0 <= y < self._input_h:
            raise ValueError(f"y-coordinate {y} is outside of the input image range [0, {self._input_h})")
        x_scale = self.w / self._input_w
        y_scale = self.h / self._input_h
        return self._maybe_round(x_scale * x, y_scale * y)

    def reverse_transform_point(self, point: ImagePointType) -> ImagePointType:
        """Map a resized-image point back to input-image coordinates."""
        x, y = point
        if not 0 <= x < self.w:
            raise ValueError(f"x-coordinate {x} is outside of the resized image range [0, {self.w})")
        if not 0 <= y < self.h:
            raise ValueError(f"y-coordinate {y} is outside of the resized image range [0, {self.h})")
        x_scale = self._input_w / self.w
        y_scale = self._input_h / self.h
        return self._maybe_round(x_scale * x, y_scale * y)
class Rotate90(ImageTransform):
    """Rotate an image by multiples of 90 degrees."""

    def __init__(
        self,
        input_shape: ImageShapeType,
        num_rotations: int = 1,
    ):
        """Create a new Rotate transform.

        Parameters
        ----------
        input_shape : ImageShapeType
            Shape of the images that will be rotated.
        num_rotations : int, optional
            the number of 90-degree rotations to apply. Positive values rotate
            counter-clockwise.

        Raises
        ------
        TypeError:
            If ``num_rotations`` is not an int.
        """
        super().__init__(input_shape)
        if not isinstance(num_rotations, int):
            raise TypeError("num_rotations must be an int")
        # Only the rotation modulo a full turn matters.
        self._num_rotations = num_rotations % 4

    @property
    def shape(self) -> ImageShapeType:
        # An odd number of quarter turns swaps height and width.
        if self._num_rotations % 2 == 0:
            h, w = self._input_shape[:2]
        else:
            w, h = self._input_shape[:2]
        if len(self._input_shape) == 2:
            return h, w
        c = self._input_shape[2]
        return h, w, c

    def transform_image(self, image: HWCImageType) -> HWCImageType:
        """Return ``image`` rotated counter-clockwise by the configured quarter turns."""
        # The copy here ensures the result is not a view into the original image.
        return np.rot90(image, self._num_rotations).copy()

    def transform_point(self, point: ImagePointType) -> ImagePointType:
        """Map an input-image point into rotated-image coordinates.

        Raises
        ------
        ValueError:
            If the point lies outside the input image.
        """
        x, y = point
        if not (x >= 0 and x < self._input_w):
            raise ValueError(f"x-coordinate {x} is outside of the input image range [0, {self._input_w})")
        if not (y >= 0 and y < self._input_h):
            raise ValueError(f"y-coordinate {y} is outside of the input image range [0, {self._input_h})")
        # Index algebra for np.rot90 (counter-clockwise quarter turns).
        if self._num_rotations == 1:
            return y, self._input_w - x - 1
        elif self._num_rotations == 2:
            return self._input_w - x - 1, self._input_h - y - 1
        elif self._num_rotations == 3:
            return self._input_h - y - 1, x
        return x, y

    def reverse_transform_point(self, point: ImagePointType) -> ImagePointType:
        """Map a rotated-image point back to input-image coordinates.

        Raises
        ------
        ValueError:
            If the point lies outside the rotated image.
        """
        x, y = point
        # Fix: validate the input like every other transform in this module
        # does — previously an out-of-range point was silently mapped to
        # wrong (possibly negative) input coordinates.
        out_h, out_w = self.shape[:2]
        if not (x >= 0 and x < out_w):
            raise ValueError(f"x-coordinate {x} is outside of the rotated image range [0, {out_w})")
        if not (y >= 0 and y < out_h):
            raise ValueError(f"y-coordinate {y} is outside of the rotated image range [0, {out_h})")
        if self._num_rotations == 1:
            return self._input_w - y - 1, x
        elif self._num_rotations == 2:
            return self._input_w - x - 1, self._input_h - y - 1
        elif self._num_rotations == 3:
            return y, self._input_h - x - 1
        return x, y