Shortcuts

Source code for rising.ops.tensor

from typing import Optional

import numpy as np
import torch


def torch_one_hot(target: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor:
    """
    Compute one hot encoding of input tensor

    The class dimension is appended as the trailing dimension: a target of
    shape ``(*S)`` yields an encoding of shape ``(*S, num_classes)``. The
    input tensor itself is left unmodified.

    Args:
        target: tensor of class indices to be converted. It is used as the
            index argument of ``Tensor.scatter_``, so it must be an int64
            (``torch.long``) tensor.
        num_classes: number of classes. If :attr:`num_classes` is None,
            the maximum of target plus one is used

    Returns:
        torch.Tensor: one hot encoded tensor with the same dtype and device
        as :attr:`target`
    """
    if num_classes is None:
        num_classes = int(target.max().detach().item() + 1)
    dtype, device = target.dtype, target.device
    target_onehot = torch.zeros(*target.shape, num_classes, dtype=dtype, device=device)
    # The class axis was appended last, so scatter along dim -1 (scattering
    # along dim 1 is only right for 1-D targets). Use the out-of-place
    # ``unsqueeze`` so the caller's tensor keeps its original shape.
    return target_onehot.scatter_(-1, target.unsqueeze(-1), 1)
def np_one_hot(target: np.ndarray, num_classes: Optional[int] = None) -> np.ndarray:
    """
    One hot encode an array of class indices

    Args:
        target: array holding a class index for every element
        num_classes: total number of classes; when omitted it is derived
            as ``target.max() + 1``

    Returns:
        numpy.ndarray: array of shape ``(*target.shape, num_classes)`` with
        the same dtype as :attr:`target`; entry ``[..., c]`` is 1 where
        ``target == c`` and 0 everywhere else
    """
    n_classes = num_classes
    if n_classes is None:
        n_classes = int(target.max().item() + 1)

    encoded = np.zeros((*target.shape, n_classes), dtype=target.dtype)
    # Broadcast each element against every class id in one comparison; the
    # boolean result is cast into the pre-allocated buffer's dtype.
    encoded[...] = target[..., np.newaxis] == np.arange(n_classes)
    return encoded

© Copyright 2019-2020, Justus Schock, Michael Baumgartner. Revision 871c839e.

Read the Docs v: v0.2.0
Versions
latest
stable
v0.2.0
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.