Module rnnbuilder.custom
Functions to extend this library with custom modules
Expand source code
"""Functions to extend this library with custom modules"""
from typing import Type
import torch as _torch
from ._factories import _NonRecurrentFactory, _RecurrentFactory
from abc import ABC as _ABC, abstractmethod as _abstractmethod
__pdoc__ = {'CustomModule.training' : False, 'CustomModule.dump_patches' : False}
class CustomModule(_torch.nn.Module, _ABC):
"""Abstract base class for custom recurrent modules. Inheriting classes must be registered with `register_recurrent`
to retrieve a corresponding factory class. The factory class will initialize with the same parameters as the
registered CustomModule.
"""
def enter_in_shape(self, in_shape):
"""Gets called immediately after initialization. Overwrite if internal parameters depend on the input shape.
"""
pass
@_abstractmethod
def get_out_shape(self, in_shape):
"""
Args:
in_shape: data shape of incoming tensor (excluding time and batch dimensions) before any flattening is
applied
In most cases, one of three values should be returned:
1. 'in_shape' if the module does not perform changes to the shape of the data (flattening does not count)
2. some_fixed_shape if the output shape is independent of the input e.g. (out_shape,) for Linear layers
3. 'None' in all other cases (output shape will be inferred automatically)
"""
pass
def get_initial_output(self, full_shape):
"""Returns the initial output used for `Placeholder`s. This defaults to zero and can be overwritten by
individual `Placeholder`s.
Args:
full_shape: shape of the output in the format (batch, data0, data1, ...).
"""
return _torch.zeros(full_shape, device=self.device)
@_abstractmethod
def forward(self, input, state):
"""Must be implemented with the following signature.
Args:
input: the input tensor, the expected shape needs to be reported when registering the module, see
`register_recurrent`
state: some state as any combination of dicts, lists, tuples and tensors (needs to have format consistent
with `get_initial_state`). The empty tuple is used to indicate no state.
Returns:
(output, new_state) where
output: the output tensor, format needs to match input but data dimensions can be different
new_state: the new state in the same format as the input state
"""
pass
def get_initial_state(self, batch_size):
"""Returns initial state (in the same format as `forward`, for a single step, first dimension being batch size).
"""
return ()
class _NonRecurrentGenerator:
def __init__(self, make_module, prepare_input, shape_change):
self._make_module = make_module
self._shape_change = shape_change
self._prepare_input = prepare_input
def __call__(self, *args, **kwargs):
return _NonRecurrentFactory(self._make_module, self._prepare_input, self._shape_change, *args, **kwargs)
def register_non_recurrent(*, module_class: Type[_torch.nn.Module],
flatten_input: bool,
shape_change: bool):
"""Register a (non-recurrent) `torch.nn.Module` to retrieve a factory class. The factory class will initialize with
the same parameters as the registered Module. If this interface is too restrictive, wrap the Module in a
`CustomModule` and use `register_recurrent` instead.
Args:
module_class: a class derived from `torch.nn.Module`. The forward method needs to conform to a fixed signature.
It must accept a single input tensor of format (batch, data0, data1, ...) and return a single output tensor.
flatten_input: If True, the input is flattened to shape (batch, data0*data1*...) before given to the module.
shape_change: Indicate whether the module changes the shape of the tensor going through, i.e. put False if input
(after the optional flatten) and output shapes of the module are identical otherwise True
"""
initializer = (lambda in_shape, *args, **kwargs: module_class(*args, **kwargs))\
if type(module_class) is type else module_class
return _NonRecurrentGenerator(initializer, 'flatten' if flatten_input else 'keep',
'auto' if shape_change else 'none')
class _RecurrentGenerator:
def __init__(self, module_class, prepare_input, single_step, unroll_full_state):
self._module_class = module_class
self._prepare_input = prepare_input
self._single_step = single_step
self._unroll_full_state = unroll_full_state
def __call__(self, *args, **kwargs):
return _RecurrentFactory(self._module_class, self._prepare_input, self._single_step, self._unroll_full_state,
*args, **kwargs)
def register_recurrent(*, module_class: Type[CustomModule], flatten_input: bool,
single_step: bool, unroll_full_state: bool = True):
"""Register a (possibly recurrent) `CustomModule` to retrieve a factory class. The factory class will initialize
with the same parameters as the registered `CustomModule`.
Args:
module_class: a class derived from `CustomModule`
flatten_input: If True, the input is flattened to a single data dimension before given to the module.
single_step: If True, for each time step the forward method of the module will be invoked with a tensor of
format (batch, data_shape). If False, it will be only invoked once with a tensor of format
(time, batch, data_shape).
unroll_full_state: (Optional) If True and the state for every time step is required (full_state mode), the
module will be invoked for each time step even if single_step=False. If False, the module has to check for
self._full_state and return a sequence of states appropriately
"""
if module_class.__abstractmethods__:
raise Exception("Can't register "+module_class.__name__+". All abstract methods need to be overwritten.")
return _RecurrentGenerator(module_class, 'flatten' if flatten_input else 'keep', single_step, unroll_full_state)
Functions
def register_non_recurrent(*, module_class: Type[torch.nn.modules.module.Module], flatten_input: bool, shape_change: bool)
-
Register a (non-recurrent)
torch.nn.Module
to retrieve a factory class. The factory class will initialize with the same parameters as the registered Module. If this interface is too restrictive, wrap the Module in aCustomModule
and useregister_recurrent()
instead.Args
module_class
- a class derived from
torch.nn.Module
. The forward method needs to conform to a fixed signature. It must accept a single input tensor of format (batch, data0, data1, …) and return a single output tensor. flatten_input
- If True, the input is flattened to shape (batch, data0data1…) before given to the module.
shape_change
- Indicate whether the module changes the shape of the tensor going through, i.e. put False if input (after the optional flatten) and output shapes of the module are identical otherwise True
Expand source code
def register_non_recurrent(*, module_class: Type[_torch.nn.Module], flatten_input: bool, shape_change: bool): """Register a (non-recurrent) `torch.nn.Module` to retrieve a factory class. The factory class will initialize with the same parameters as the registered Module. If this interface is too restrictive, wrap the Module in a `CustomModule` and use `register_recurrent` instead. Args: module_class: a class derived from `torch.nn.Module`. The forward method needs to conform to a fixed signature. It must accept a single input tensor of format (batch, data0, data1, ...) and return a single output tensor. flatten_input: If True, the input is flattened to shape (batch, data0*data1*...) before given to the module. shape_change: Indicate whether the module changes the shape of the tensor going through, i.e. put False if input (after the optional flatten) and output shapes of the module are identical otherwise True """ initializer = (lambda in_shape, *args, **kwargs: module_class(*args, **kwargs))\ if type(module_class) is type else module_class return _NonRecurrentGenerator(initializer, 'flatten' if flatten_input else 'keep', 'auto' if shape_change else 'none')
def register_recurrent(*, module_class: Type[CustomModule], flatten_input: bool, single_step: bool, unroll_full_state: bool = True)
-
Register a (possibly recurrent)
CustomModule
to retrieve a factory class. The factory class will initialize with the same parameters as the registeredCustomModule
.Args
module_class
- a class derived from
CustomModule
flatten_input
- If True, the input is flattened to a single data dimension before given to the module.
single_step
- If True, for each time step the forward method of the module will be invoked with a tensor of format (batch, data_shape). If False, it will be only invoked once with a tensor of format (time, batch, data_shape).
unroll_full_state
- (Optional) If True and the state for every time step is required (full_state mode), the module will be invoked for each time step even if single_step=False. If False, the module has to check for self._full_state and return a sequence of states appropriately
Expand source code
def register_recurrent(*, module_class: Type[CustomModule], flatten_input: bool, single_step: bool, unroll_full_state: bool = True): """Register a (possibly recurrent) `CustomModule` to retrieve a factory class. The factory class will initialize with the same parameters as the registered `CustomModule`. Args: module_class: a class derived from `CustomModule` flatten_input: If True, the input is flattened to a single data dimension before given to the module. single_step: If True, for each time step the forward method of the module will be invoked with a tensor of format (batch, data_shape). If False, it will be only invoked once with a tensor of format (time, batch, data_shape). unroll_full_state: (Optional) If True and the state for every time step is required (full_state mode), the module will be invoked for each time step even if single_step=False. If False, the module has to check for self._full_state and return a sequence of states appropriately """ if module_class.__abstractmethods__: raise Exception("Can't register "+module_class.__name__+". All abstract methods need to be overwritten.") return _RecurrentGenerator(module_class, 'flatten' if flatten_input else 'keep', single_step, unroll_full_state)
Classes
class CustomModule
-
Abstract base class for custom recurrent modules. Inheriting classes must be registered with
register_recurrent()
to retrieve a corresponding factory class. The factory class will initialize with the same parameters as the registered CustomModule.Initializes internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class CustomModule(_torch.nn.Module, _ABC): """Abstract base class for custom recurrent modules. Inheriting classes must be registered with `register_recurrent` to retrieve a corresponding factory class. The factory class will initialize with the same parameters as the registered CustomModule. """ def enter_in_shape(self, in_shape): """Gets called immediately after initialization. Overwrite if internal parameters depend on the input shape. """ pass @_abstractmethod def get_out_shape(self, in_shape): """ Args: in_shape: data shape of incoming tensor (excluding time and batch dimensions) before any flattening is applied In most cases, one of three values should be returned: 1. 'in_shape' if the module does not perform changes to the shape of the data (flattening does not count) 2. some_fixed_shape if the output shape is independent of the input e.g. (out_shape,) for Linear layers 3. 'None' in all other cases (output shape will be inferred automatically) """ pass def get_initial_output(self, full_shape): """Returns the initial output used for `Placeholder`s. This defaults to zero and can be overwritten by individual `Placeholder`s. Args: full_shape: shape of the output in the format (batch, data0, data1, ...). """ return _torch.zeros(full_shape, device=self.device) @_abstractmethod def forward(self, input, state): """Must be implemented with the following signature. Args: input: the input tensor, the expected shape needs to be reported when registering the module, see `register_recurrent` state: some state as any combination of dicts, lists, tuples and tensors (needs to have format consistent with `get_initial_state`). The empty tuple is used to indicate no state. Returns: (output, new_state) where output: the output tensor, format needs to match input but data dimensions can be different new_state: the new state in the same format as the input state """ pass def get_initial_state(self, batch_size): """Returns initial state (in the same format as `forward`, for a single step, first dimension being batch size). """ return ()
Ancestors
- torch.nn.modules.module.Module
- abc.ABC
Subclasses
- rnnbuilder.rnn._modules._LSTMModule
- rnnbuilder.rnn._modules._TempConvPlus2dModule
- rnnbuilder.snn._modules._NoResetNeuron
Methods
def enter_in_shape(self, in_shape)
-
Gets called immediately after initialization. Overwrite if internal parameters depend on the input shape.
Expand source code
def enter_in_shape(self, in_shape): """Gets called immediately after initialization. Overwrite if internal parameters depend on the input shape. """ pass
def forward(self, input, state) ‑> Callable[..., Any]
-
Must be implemented with the following signature.
Args
input
- the input tensor, the expected shape needs to be reported when registering the module, see
register_recurrent()
state
- some state as any combination of dicts, lists, tuples and tensors (needs to have format consistent
with
get_initial_state
). The empty tuple is used to indicate no state.
Returns
- (output, new_state) where
output
- the output tensor, format needs to match input but data dimensions can be different
new_state
- the new state in the same format as the input state
Expand source code
@_abstractmethod def forward(self, input, state): """Must be implemented with the following signature. Args: input: the input tensor, the expected shape needs to be reported when registering the module, see `register_recurrent` state: some state as any combination of dicts, lists, tuples and tensors (needs to have format consistent with `get_initial_state`). The empty tuple is used to indicate no state. Returns: (output, new_state) where output: the output tensor, format needs to match input but data dimensions can be different new_state: the new state in the same format as the input state """ pass
def get_initial_output(self, full_shape)
-
Returns the initial output used for
Placeholder
s. This defaults to zero and can be overwritten by individualPlaceholder
s.Args
full_shape
- shape of the output in the format (batch, data0, data1, …).
Expand source code
def get_initial_output(self, full_shape): """Returns the initial output used for `Placeholder`s. This defaults to zero and can be overwritten by individual `Placeholder`s. Args: full_shape: shape of the output in the format (batch, data0, data1, ...). """ return _torch.zeros(full_shape, device=self.device)
def get_initial_state(self, batch_size)
-
Returns initial state (in the same format as
forward
, for a single step, first dimension being batch size).Expand source code
def get_initial_state(self, batch_size): """Returns initial state (in the same format as `forward`, for a single step, first dimension being batch size). """ return ()
def get_out_shape(self, in_shape)
-
Args
in_shape
- data shape of incoming tensor (excluding time and batch dimensions) before any flattening is applied
In most cases, one of three values should be returned: 1. 'in_shape' if the module does not perform changes to the shape of the data (flattening does not count) 2. some_fixed_shape if the output shape is independent of the input e.g. (out_shape,) for Linear layers 3. 'None' in all other cases (output shape will be inferred automatically)
Expand source code
@_abstractmethod def get_out_shape(self, in_shape): """ Args: in_shape: data shape of incoming tensor (excluding time and batch dimensions) before any flattening is applied In most cases, one of three values should be returned: 1. 'in_shape' if the module does not perform changes to the shape of the data (flattening does not count) 2. some_fixed_shape if the output shape is independent of the input e.g. (out_shape,) for Linear layers 3. 'None' in all other cases (output shape will be inferred automatically) """ pass