diff --git a/rlgym/rocket_league/action_parsers/lookup_table_action.py b/rlgym/rocket_league/action_parsers/lookup_table_action.py index 8e2e2a5..c93d0b4 100644 --- a/rlgym/rocket_league/action_parsers/lookup_table_action.py +++ b/rlgym/rocket_league/action_parsers/lookup_table_action.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, Tuple, List +from typing import Dict, Any, Tuple, List, Union import numpy as np @@ -6,7 +6,7 @@ from rlgym.rocket_league.api import GameState -class LookupTableAction(ActionParser[AgentID, np.ndarray, np.ndarray, GameState, Tuple[str, int]]): +class LookupTableAction(ActionParser[AgentID, Union[np.ndarray, int], np.ndarray, GameState, Tuple[str, int]]): """ World-famous discrete action parser which uses a lookup table to reduce the number of possible actions from 1944 to 90 """ @@ -21,17 +21,12 @@ def get_action_space(self, agent: AgentID) -> Tuple[str, int]: def reset(self, agents: List[AgentID], initial_state: GameState, shared_info: Dict[str, Any]) -> None: pass - def parse_actions(self, actions: Dict[AgentID, np.ndarray], state: GameState, shared_info: Dict[str, Any]) -> Dict[AgentID, np.ndarray]: - parsed_actions = {} - for agent, action in actions.items(): - # Action can have shape (Ticks, 1) or (Ticks) - assert len(action.shape) == 1 or (len(action.shape) == 2 and action.shape[1] == 1) - - if len(action.shape) == 2: - action = action.squeeze(1) - - parsed_actions[agent] = self._lookup_table[action] - + def parse_actions(self, actions: Dict[AgentID, Union[np.ndarray, int]], state: GameState, shared_info: Dict[str, Any]) -> Dict[AgentID, np.ndarray]: + parsed_actions = { + agent: self._lookup_table[action] + for agent, action in actions.items() + } + return parsed_actions @staticmethod