# Source code for albumentations.pytorch.transforms

from __future__ import absolute_import

import numpy as np
import torch
from torchvision.transforms import functional as F

from ..core.transforms_interface import BasicTransform

# Public names exported by ``from <this module> import *``.
__all__ = ["ToTensor"]

def img_to_tensor(im, normalize=None):
    """Convert an image array to a float32 ``torch.Tensor`` in channel-first layout.

    Args:
        im (numpy.ndarray): image of shape ``(H, W, C)`` or ``(H, W)``.
            ``uint8`` images are divided by 255; other dtypes are only cast
            to float32.
        normalize (dict, optional): keyword arguments (e.g. ``mean``, ``std``)
            forwarded to ``torchvision.transforms.functional.normalize``.

    Returns:
        torch.Tensor: float32 tensor of shape ``(C, H, W)``, or ``(1, H, W)``
        for a 2-D image; normalized when ``normalize`` is given.
    """
    scaled = im / (255. if im.dtype == np.uint8 else 1)
    if scaled.ndim == 2:
        # Grayscale image: add a leading channel axis. Calling moveaxis(-1, 0)
        # on a 2-D array would silently transpose H and W instead.
        chw = scaled[np.newaxis]
    else:
        # (H, W, C) -> (C, H, W)
        chw = np.moveaxis(scaled, -1, 0)
    tensor = torch.from_numpy(chw.astype(np.float32))
    if normalize is not None:
        return F.normalize(tensor, **normalize)
    return tensor

def mask_to_tensor(mask, num_classes, sigmoid):
    """Convert a segmentation mask array to a ``torch.Tensor``.

    Args:
        mask (numpy.ndarray): mask of shape ``(H, W)`` or ``(H, W, C)``.
        num_classes (int): number of segmentation classes.
        sigmoid (bool): only used when ``num_classes > 1``. If ``False`` the
            mask is collapsed to per-pixel class indices (for a softmax /
            cross-entropy loss); if ``True`` it is kept as float channels.

    Returns:
        torch.Tensor:
            * ``num_classes > 1`` and not ``sigmoid``: int64 tensor of shape
              ``(H, W)`` holding class indices.
            * otherwise: float32 tensor of shape ``(C, H, W)`` for a 3-D mask
              or ``(1, H, W)`` for a 2-D mask, divided by 255 when ``mask``
              is ``uint8``.
    """
    if num_classes > 1 and not sigmoid:
        # Softmax target: one int64 class index per pixel.
        long_mask = np.zeros(mask.shape[:2], dtype=np.int64)
        if mask.ndim == 3:
            # One channel per class; where channels overlap, the later one wins.
            for c in range(mask.shape[2]):
                long_mask[mask[..., c] > 0] = c
        else:
            # Single-channel mask: values above 127 are class 1, the rest class 0
            # (long_mask is already zero-initialised).
            long_mask[mask > 127] = 1
        return torch.from_numpy(long_mask)

    scaled = mask / (255. if mask.dtype == np.uint8 else 1)
    if scaled.ndim == 3:
        chw = np.moveaxis(scaled, -1, 0)  # (H, W, C) -> (C, H, W)
    else:
        chw = np.expand_dims(scaled, 0)  # (H, W) -> (1, H, W)
    return torch.from_numpy(chw.astype(np.float32))

class ToTensor(BasicTransform):
    """Convert the ``image`` and ``mask`` of a sample to ``torch.Tensor``.

    Per the original contract, ``uint8`` image/mask data is divided by 255.
    WARNING! Please use this with care and look into sources before usage.

    Args:
        num_classes (int): only for segmentation; number of mask classes.
        sigmoid (bool, optional): only for segmentation; whether the mask is
            transformed to a LongTensor or not.
        normalize (dict, optional): dict with keys ``mean`` and ``std`` that is
            passed to ``torchvision.transforms.functional.normalize``.
    """

    def __init__(self, num_classes=1, sigmoid=True, normalize=None):
        super(ToTensor, self).__init__(always_apply=True, p=1.)
        self.num_classes = num_classes
        self.sigmoid = sigmoid
        self.normalize = normalize

    def __call__(self, **kwargs):
        """Return ``kwargs`` with ``image`` (required) and ``mask`` (if present) converted."""
        kwargs['image'] = img_to_tensor(kwargs['image'], self.normalize)
        if 'mask' in kwargs:
            kwargs['mask'] = mask_to_tensor(kwargs['mask'], self.num_classes, sigmoid=self.sigmoid)
        return kwargs

    @property
    def targets(self):
        # No target mapping is provided; __call__ routes 'image' and 'mask' itself.
        raise NotImplementedError