Shortcuts

Source code for rising.transforms.format

from typing import Callable, Dict, Hashable, Mapping, Sequence, Tuple, Union

from rising.transforms.functional.utility import filter_keys, pop_keys

from .abstract import AbstractTransform

# Public API of this module: the format-conversion and key-manipulation transforms.
__all__ = [
    "MapToSeq",
    "SeqToMap",
    "PopKeys",
    "FilterKeys",
    "RenameKeys",
]


class MapToSeq(AbstractTransform):
    """
    Convert dict to sequence
    """

    def __init__(self, *keys, grad: bool = False, **kwargs):
        """
        Args:
            keys: keys which are mapped into sequence. Either passed as
                separate positional arguments or as a single list/tuple
                (if the first argument is a list/tuple, only it is used).
            grad: enable gradient computation inside transformation
            **kwargs: additional keyword arguments passed to superclass
        """
        super().__init__(grad=grad, **kwargs)
        # Accept both MapToSeq("a", "b") and MapToSeq(["a", "b"]).
        # The ``keys and`` guard keeps MapToSeq() from failing with a bare
        # IndexError; with no keys, ``forward`` simply returns an empty tuple.
        if keys and isinstance(keys[0], (list, tuple)):
            keys = keys[0]
        self.keys = keys

    def forward(self, **data) -> tuple:
        """
        Convert input

        Args:
            data: input dict

        Returns:
            tuple: the values of ``data`` for each entry of ``self.keys``,
                in the order of ``self.keys``

        Raises:
            KeyError: if one of ``self.keys`` is missing from ``data``
        """
        return tuple(data[_k] for _k in self.keys)
class SeqToMap(AbstractTransform):
    """
    Convert sequence to dict
    """

    def __init__(self, *keys, grad: bool = False, **kwargs):
        """
        Args:
            keys: keys which are mapped into dict. Either passed as separate
                positional arguments or as a single list/tuple (if the first
                argument is a list/tuple, only it is used).
            grad: enable gradient computation inside transformation
            **kwargs: additional keyword arguments passed to superclass
        """
        super().__init__(grad=grad, **kwargs)
        # Accept both SeqToMap("a", "b") and SeqToMap(["a", "b"]).
        # The ``keys and`` guard keeps SeqToMap() from failing with a bare
        # IndexError; with no keys, ``forward`` simply returns an empty dict.
        if keys and isinstance(keys[0], (list, tuple)):
            keys = keys[0]
        self.keys = keys

    def forward(self, *data, **kwargs) -> dict:
        """
        Convert input

        Args:
            data: input tuple; the i-th element is stored under the i-th key
            **kwargs: accepted but ignored

        Returns:
            dict: mapping from each entry of ``self.keys`` to the positional
                input at the same index

        Raises:
            IndexError: if fewer positional inputs than keys are given
        """
        return {_key: data[_idx] for _idx, _key in enumerate(self.keys)}
class PopKeys(AbstractTransform):
    """
    Remove selected keys from a data dict.
    """

    def __init__(self, keys: Union[Callable, Sequence], return_popped: bool = False):
        """
        Args:
            keys: selection of keys to pop. Either a callable that returns a
                boolean for each key (True means the key is popped from the
                dict), or a sequence of strings naming the keys to pop.
            return_popped: whether to also return the popped values
                (default: False)
        """
        super().__init__(grad=False)
        self.keys, self.return_popped = keys, return_popped

    def forward(self, **data) -> Union[dict, Tuple[dict, dict]]:
        """
        Pop the configured keys from the input.

        Args:
            data: input dict

        Returns:
            Whatever ``pop_keys`` returns for this configuration: per the
            annotation either a dict or a pair of dicts (presumably the
            remaining data plus the popped data when ``return_popped`` is
            True; confirm against ``pop_keys``).
        """
        result = pop_keys(data=data, keys=self.keys, return_popped=self.return_popped)
        return result
class FilterKeys(AbstractTransform):
    """
    Keep only selected keys of a data dict.
    """

    def __init__(self, keys: Union[Callable, Sequence], return_popped: bool = False):
        """
        Args:
            keys: selection of keys to retain. Either a callable that returns
                a boolean for each key (True means the key is kept in the
                dict), or a sequence of strings naming the keys to keep.
            return_popped: whether to also return the values that were
                removed (default: False)
        """
        super().__init__(grad=False)
        self.keys, self.return_popped = keys, return_popped

    def forward(self, **data) -> Union[dict, Tuple[dict, dict]]:
        """
        Filter the input down to the configured keys.

        Args:
            data: input dict

        Returns:
            Whatever ``filter_keys`` returns for this configuration: per the
            annotation either a dict or a pair of dicts (presumably the
            retained data plus the removed data when ``return_popped`` is
            True; confirm against ``filter_keys``).
        """
        result = filter_keys(data=data, keys=self.keys, return_popped=self.return_popped)
        return result
class RenameKeys(AbstractTransform):
    """
    Rename keys inside batch
    """

    def __init__(self, keys: Mapping[Hashable, Hashable]):
        """
        Args:
            keys: mapping whose keys are the current names and whose values
                are the new names
        """
        super().__init__(grad=False)
        self.keys = keys

    def forward(self, **data) -> Dict:
        """
        Rename keys of the input

        Args:
            data: input dict

        Returns:
            Dict: ``data`` with every key of ``self.keys`` replaced by its new
                name; keys not mentioned in ``self.keys`` are kept unchanged

        Raises:
            KeyError: if a key of ``self.keys`` is missing from ``data``
        """
        # Pop every old entry before inserting any new one. Renaming one key
        # at a time would overwrite values that are themselves still waiting
        # to be renamed, e.g. a swap {"a": "b", "b": "a"} or a chain
        # {"a": "b", "b": "c"} would silently lose the original "b".
        renamed = {new_key: data.pop(old_key) for old_key, new_key in self.keys.items()}
        data.update(renamed)
        return data

© Copyright (c) 2019-2020, Justus Schock, Michael Baumgartner. Revision b9cd7e8f.

Read the Docs v: stable
Versions
latest
stable
v0.3.0
v0.2.2
v0.2.1
v0.2.0post0
v0.2.0
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.