Source code for rising.transforms.functional.channel

from typing import Optional

import torch

from rising.ops import torch_one_hot

__all__ = ["one_hot_batch"]


def one_hot_batch(
    target: torch.Tensor, num_classes: Optional[int] = None, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    """
    Compute one hot encoding for the input tensor (assumed to be a batch and thus saved
    into the first dimension -> input should only have one channel)

    Args:
        target: long tensor to be converted
        num_classes: number of classes. If :attr:`num_classes` is None,
            the maximum of target is used
        dtype: optionally changes the dtype of the onehot encoding

    Returns:
        torch.Tensor: one hot encoded tensor
    """
    if target.dtype != torch.long:
        raise TypeError(f"Target tensor needs to be of type torch.long, found {target.dtype}")

    if target.ndim in [0, 1]:
        # scalar or flat vector of class indices: defer to the generic helper
        return torch_one_hot(target, num_classes)
    else:
        if num_classes is None:
            num_classes = int(target.max().detach().item() + 1)
        _dtype, device, shape = target.dtype, target.device, target.shape
        if dtype is None:
            dtype = _dtype
        # allocate (N, num_classes, *spatial) and scatter ones along the channel dimension
        target_onehot = torch.zeros(shape[0], num_classes, *shape[2:], dtype=dtype, device=device)
        return target_onehot.scatter_(1, target, 1.0)
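A minimal usage sketch (assuming rising is installed; the tensor shapes below are purely illustrative): a batch of integer class maps with a single channel is expanded to one channel per class.

import torch

from rising.transforms.functional.channel import one_hot_batch

# hypothetical batch: 2 samples, 1 channel, 3x3 spatial grid with class indices 0..3
target = torch.randint(0, 4, (2, 1, 3, 3), dtype=torch.long)

onehot = one_hot_batch(target, num_classes=4)
print(onehot.shape)                # torch.Size([2, 4, 3, 3])
print(onehot.sum(dim=1).unique())  # tensor([1]) -> exactly one active class per position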
