irl_benchmark.utils package¶
Submodules¶
irl_benchmark.utils.general module¶
Utils module containing general helper functions.
-
irl_benchmark.utils.general.to_one_hot(hot_vals: Union[int, List[int], np.ndarray], max_val: int, zeros_function: Callable = numpy.zeros) → Union[np.ndarray, torch.Tensor]¶
Convert an integer or a list of integers to a one-hot array.
- hot_vals: Union[int, List[int], np.ndarray]
- A single integer, or a list / vector of integers, corresponding to the hot values which will equal one in the returned array.
- max_val: int
- The number of possible values in hot_vals, and the length of each one-hot vector. All elements of hot_vals have to be smaller than max_val (since counting starts at 0).
- zeros_function: Callable
- Controls which function is used to create the array. It should be either numpy.zeros or torch.zeros.
- Union[np.ndarray, torch.Tensor]
- Either a numpy array or a torch tensor with the one-hot encoded values. The type of the returned data structure depends on the passed zeros_function; the default is a numpy array. The returned data structure has shape (1, max_val) if hot_vals is a single integer, and (len(hot_vals), max_val) otherwise.
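A minimal usage sketch (not taken from the library's docs), assuming the behaviour described above; the values in the comments follow from the one-hot definition, though the exact dtype may differ:

    import numpy as np
    from irl_benchmark.utils.general import to_one_hot

    # Single hot value: result has shape (1, max_val).
    to_one_hot(2, 5)
    # -> array([[0., 0., 1., 0., 0.]])

    # List of hot values: result has shape (len(hot_vals), max_val).
    to_one_hot([0, 3], 4)
    # -> array([[1., 0., 0., 0.],
    #           [0., 0., 0., 1.]])

    # Passing torch.zeros as zeros_function returns a torch tensor instead:
    # import torch; to_one_hot(2, 5, zeros_function=torch.zeros)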
irl_benchmark.utils.irl module¶
irl_benchmark.utils.rl module¶
Utils related to reinforcement learning.
-
irl_benchmark.utils.rl.true_reward_per_traj(trajs: List[Dict[str, list]]) → float¶
Return the (undiscounted) average sum of true rewards per trajectory.
- trajs: List[Dict[str, list]]
- A list of trajectories. Each trajectory is a dictionary with keys [‘states’, ‘actions’, ‘rewards’, ‘true_rewards’, ‘features’], and the values of each dictionary are lists. See irl_benchmark.irl.collect.collect_trajs().
- float
- The undiscounted average sum of true rewards per trajectory.
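A short sketch of the computation this function performs, under the assumption that it simply averages the summed ‘true_rewards’ lists; the helper name average_true_reward is illustrative and not part of the library:

    from typing import Dict, List

    def average_true_reward(trajs: List[Dict[str, list]]) -> float:
        """Undiscounted average of summed true rewards (illustrative)."""
        return sum(sum(traj['true_rewards']) for traj in trajs) / len(trajs)

    trajs = [
        {'states': [0, 1], 'actions': [1], 'rewards': [0.5],
         'true_rewards': [1.0], 'features': [None]},
        {'states': [0, 2], 'actions': [0], 'rewards': [0.0],
         'true_rewards': [3.0], 'features': [None]},
    ]
    assert average_true_reward(trajs) == 2.0
    # true_reward_per_traj(trajs) is expected to return the same value.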
irl_benchmark.utils.wrapper module¶
Utils module containing wrapper specific helper functions.
-
irl_benchmark.utils.wrapper.is_unwrappable_to(env: gym.Env, to_wrapper: Type[gym.Wrapper]) → bool¶
Check if env can be unwrapped to to_wrapper.
- env: gym.Env
- A gym environment (potentially wrapped).
- to_wrapper: Type[gym.Wrapper]
- A wrapper class extending gym.Wrapper.
- bool
- True if env can be unwrapped to the desired wrapper, False otherwise.
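A brief usage sketch, assuming a classic gym setup in which gym.make('FrozenLake-v0') wraps the raw environment in a TimeLimit wrapper; the environment id is only an example:

    import gym
    from gym.wrappers import TimeLimit
    from irl_benchmark.utils.wrapper import is_unwrappable_to

    env = gym.make('FrozenLake-v0')  # typically wrapped in TimeLimit by gym.make

    if is_unwrappable_to(env, TimeLimit):
        print('env contains a TimeLimit layer and can be unwrapped to it')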
-
irl_benchmark.utils.wrapper.unwrap_env(env: gym.Env, until_class: Union[None, gym.Wrapper] = None) → gym.Env¶
Unwrap a wrapped env until we get an instance of until_class.
If until_class is None, env will be unwrapped down to the lowest layer.
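A minimal sketch of the unwrapping logic, assuming the standard gym.Wrapper interface where each wrapper exposes the environment it wraps as .env; the function name unwrap_until is illustrative and not the library's implementation:

    import gym

    def unwrap_until(env: gym.Env, until_class=None) -> gym.Env:
        """Peel gym.Wrapper layers off env (illustrative re-implementation)."""
        if until_class is None:
            # Unwrap all the way down to the innermost environment.
            while isinstance(env, gym.Wrapper):
                env = env.env
            return env
        # Unwrap until an instance of until_class is found.
        while not isinstance(env, until_class):
            if not isinstance(env, gym.Wrapper):
                raise ValueError('env is not wrapped in the requested class')
            env = env.env
        return env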
Module contents¶
A module with useful functions for the irl_benchmark framework.