Source code for albumentations.core.transforms_interface

import random

import cv2

__all__ = ['to_tuple', 'BasicTransform', 'DualTransform', 'ImageOnlyTransform', 'NoOp']


def to_tuple(param, low=None):
    if isinstance(param, (list, tuple)):
        return tuple(param)
    elif param is not None:
        if low is None:
            return -param, param
        return (low, param) if low < param else (param, low)
    else:
        return param


class BasicTransform(object):
    def __init__(self, always_apply=False, p=0.5):
        self.p = p
        self.always_apply = always_apply
        self._additional_targets = {}

    def __call__(self, force_apply=False, **kwargs):
        if (random.random() < self.p) or self.always_apply or force_apply:
            params = self.get_params()
            params = self.update_params(params, **kwargs)
            if self.targets_as_params:
                targets_as_params = {k: kwargs[k] for k in self.targets_as_params}
                params_dependent_on_targets = self.get_params_dependent_on_targets(targets_as_params)
                params.update(params_dependent_on_targets)
            res = {}
            for key, arg in kwargs.items():
                if arg is not None:
                    target_function = self._get_target_function(key)
                    target_dependencies = {k: kwargs[k] for k in self.target_dependence.get(key, [])}
                    res[key] = target_function(arg, **dict(params, **target_dependencies))
                else:
                    res[key] = None
            return res
        return kwargs

    def _get_target_function(self, key):
        transform_key = key
        if key in self._additional_targets:
            transform_key = self._additional_targets.get(key, None)

        target_function = self.targets.get(transform_key, lambda x, **p: x)
        return target_function

    def apply(self, img, **params):
        raise NotImplementedError

    def get_params(self):
        return {}

    @property
    def targets(self):
        # you must specify targets in subclass
        # for example: ('image', 'mask')
        #              ('image', 'boxes')
        raise NotImplementedError

    def update_params(self, params, **kwargs):
        if hasattr(self, 'interpolation'):
            params['interpolation'] = self.interpolation
        params.update({'cols': kwargs['image'].shape[1], 'rows': kwargs['image'].shape[0]})
        return params

    @property
    def target_dependence(self):
        return {}

    def add_targets(self, additional_targets):
        """Add targets to transform them the same way as one of existing targets
        ex: {'target_image': 'image'}
        ex: {'obj1_mask': 'mask', 'obj2_mask': 'mask'}
        by the way you must have at least one object with key 'image'

        Args:
            additional_targets (dict): keys - new target name, values - old target name. ex: {'image2': 'image'}
        """
        self._additional_targets = additional_targets

    @property
    def targets_as_params(self):
        return []

    def get_params_dependent_on_targets(self, params):
        raise NotImplementedError('Method  get_params_dependent_on_targets is not implemented in class ' +
                                  self.__class__.__name__)


[docs]class DualTransform(BasicTransform): """Transform for segmentation task.""" @property def targets(self): return {'image': self.apply, 'mask': self.apply_to_mask, 'masks': self.apply_to_masks, 'bboxes': self.apply_to_bboxes, 'keypoints': self.apply_to_keypoints} def apply_to_bbox(self, bbox, **params): raise NotImplementedError('Method apply_to_bbox is not implemented in class ' + self.__class__.__name__) def apply_to_keypoint(self, keypoint, **params): raise NotImplementedError('Method apply_to_keypoint is not implemented in class ' + self.__class__.__name__) def apply_to_bboxes(self, bboxes, **params): bboxes = [list(bbox) for bbox in bboxes] return [self.apply_to_bbox(bbox[:4], **params) + bbox[4:] for bbox in bboxes] def apply_to_keypoints(self, keypoints, **params): keypoints = [list(keypoint) for keypoint in keypoints] return [self.apply_to_keypoint(keypoint[:4], **params) + keypoint[4:] for keypoint in keypoints] def apply_to_mask(self, img, **params): return self.apply(img, **{k: cv2.INTER_NEAREST if k == 'interpolation' else v for k, v in params.items()}) def apply_to_masks(self, masks, **params): return [self.apply_to_mask(mask, **params) for mask in masks]
[docs]class ImageOnlyTransform(BasicTransform): """Transform applied to image only.""" @property def targets(self): return {'image': self.apply}
[docs]class NoOp(DualTransform): """Does nothing""" def apply_to_keypoint(self, keypoint, **params): return keypoint def apply_to_bbox(self, bbox, **params): return bbox def apply(self, img, **params): return img def apply_to_mask(self, img, **params): return img