Source code for rising.random.continuous

from typing import Union

import torch
from torch.distributions import Distribution as TorchDistribution

from rising.random.abstract import AbstractParameter

# Public API of this module: the parameter classes defined below.
__all__ = ["ContinuousParameter", "NormalParameter", "UniformParameter"]

class ContinuousParameter(AbstractParameter):
    """Parameter whose values are drawn from a torch distribution."""

    def __init__(self, distribution: TorchDistribution):
        """
        Args:
            distribution: the torch distribution that values are drawn from
        """
        super().__init__()
        self.dist = distribution

    def sample(self, n_samples: int) -> torch.Tensor:
        """
        Draw values from the wrapped distribution.

        Args:
            n_samples: the number of elements to draw

        Returns:
            torch.Tensor: the drawn values, requested from the distribution
            with the sample shape ``(n_samples,)``
        """
        # The distribution expects a shape, so wrap the count in a 1-tuple.
        sample_shape = (n_samples,)
        return self.dist.sample(sample_shape)
class NormalParameter(ContinuousParameter):
    """
    Parameter whose values are drawn from a normal distribution.

    See :class:`torch.distributions.Normal` for details.
    """

    def __init__(self, mu: Union[float, torch.Tensor], sigma: Union[float, torch.Tensor]):
        """
        Args:
            mu: mean of the distribution
            sigma: standard deviation of the distribution
        """
        normal = torch.distributions.Normal(loc=mu, scale=sigma)
        super().__init__(normal)
class UniformParameter(ContinuousParameter):
    """
    Parameter whose values are drawn from a uniform distribution.

    See :class:`torch.distributions.Uniform` for details.
    """

    def __init__(self, low: Union[float, torch.Tensor], high: Union[float, torch.Tensor]):
        """
        Args:
            low: lower bound of the range (inclusive)
            high: upper bound of the range (exclusive)
        """
        uniform = torch.distributions.Uniform(low=low, high=high)
        super().__init__(uniform)

© Copyright (c) 2019-2020, Justus Schock, Michael Baumgartner. Revision 871c839e.

Read the Docs v: v0.2.0
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.