Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 45 additions & 10 deletions predicators/envs/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
class BlocksEnv(BaseEnv):
"""Blocks domain."""
# Parameters that aren't important enough to need to clog up settings.py
table_height: ClassVar[float] = 0.2
table_height: ClassVar[float] = 0.4
# The table x bounds are (1.1, 1.6), but the workspace is smaller.
# Make it narrow enough that blocks can be only horizontally arranged.
# Note that these boundaries are for the block positions, and that a
Expand All @@ -48,6 +48,8 @@ class BlocksEnv(BaseEnv):
pick_tol: ClassVar[float] = 0.0001
on_tol: ClassVar[float] = 0.01
collision_padding: ClassVar[float] = 2.0
open_fingers: ClassVar[float] = 0.04
closed_fingers: ClassVar[float] = 0.01

def __init__(self, use_gui: bool = True) -> None:
super().__init__(use_gui)
Expand All @@ -71,6 +73,8 @@ def __init__(self, use_gui: bool = True) -> None:
self._Clear = Predicate("Clear", [self._block_type], self._Clear_holds)
# Static objects (always exist no matter the settings).
self._robot = Object("robby", self._robot_type)
self._blocks: List[Object] = []
self._create_blocks()
# Hyperparameters from CFG.
self._block_size = CFG.blocks_block_size
self._num_blocks_train = CFG.blocks_num_blocks_train
Expand Down Expand Up @@ -107,7 +111,8 @@ def _transition_pick(self, state: State, x: float, y: float,
next_state.set(block, "pose_y", y)
next_state.set(block, "pose_z", self.pick_z)
next_state.set(block, "held", 1.0)
next_state.set(self._robot, "fingers", 0.0) # close fingers
next_state.set(self._robot, "fingers",
self.closed_fingers) # close fingers
if "clear" in self._block_type.feature_names:
# See BlocksEnvClear
next_state.set(block, "clear", 0)
Expand Down Expand Up @@ -139,7 +144,8 @@ def _transition_putontable(self, state: State, x: float, y: float,
next_state.set(block, "pose_y", y)
next_state.set(block, "pose_z", z)
next_state.set(block, "held", 0.0)
next_state.set(self._robot, "fingers", 1.0) # open fingers
next_state.set(self._robot, "fingers",
self.open_fingers) # open fingers
if "clear" in self._block_type.feature_names:
# See BlocksEnvClear
next_state.set(block, "clear", 1)
Expand Down Expand Up @@ -171,7 +177,8 @@ def _transition_stack(self, state: State, x: float, y: float,
next_state.set(block, "pose_y", cur_y)
next_state.set(block, "pose_z", cur_z + self._block_size)
next_state.set(block, "held", 0.0)
next_state.set(self._robot, "fingers", 1.0) # open fingers
next_state.set(self._robot, "fingers",
self.open_fingers) # open fingers
if "clear" in self._block_type.feature_names:
# See BlocksEnvClear
next_state.set(block, "clear", 1)
Expand Down Expand Up @@ -295,11 +302,19 @@ def _get_tasks(self, num_tasks: int, possible_num_blocks: List[int],
tasks.append(EnvironmentTask(init_state, goal))
return tasks

def _create_blocks(self) -> None:
for i in range(
max(max(CFG.blocks_num_blocks_train),
max(CFG.blocks_num_blocks_test))):
block = Object(f"block{i}", self._block_type)
self._blocks.append(block)

def _sample_initial_piles(self, num_blocks: int,
rng: np.random.Generator) -> List[List[Object]]:
piles: List[List[Object]] = []
for block_num in range(num_blocks):
block = Object(f"block{block_num}", self._block_type)
block = self._blocks[block_num]
# block = Object(f"block{block_num}", self._block_type)
# If coin flip, start new pile
if block_num == 0 or rng.uniform() < 0.2:
piles.append([])
Expand Down Expand Up @@ -340,7 +355,7 @@ def _sample_state_from_piles(self, piles: List[List[Object]],
# Note: the robot poses are not used in this environment (they are
# constant), but they change and get used in the PyBullet subclass.
rx, ry, rz = self.robot_init_x, self.robot_init_y, self.robot_init_z
rf = 1.0 # fingers start out open
rf = self.open_fingers # fingers start out open
data[self._robot] = np.array([rx, ry, rz, rf], dtype=np.float32)
return State(data)

Expand Down Expand Up @@ -407,6 +422,23 @@ def _On_holds(self, state: State, objects: Sequence[Object]) -> bool:
return np.allclose([x1, y1, z1], [x2, y2, z2 + self._block_size],
atol=self.on_tol)

def _count_block_height(self, state: State, block: Object) -> int:
"""Count the height of the block (number of blocks it's on)."""
height = 0
current_block = block
blocks = state.get_objects(self._block_type)

while True:
below_blocks = [
b for b in blocks if self._On_holds(state, [current_block, b])
]
if not below_blocks:
break
current_block = below_blocks[0]
height += 1

return height

def _OnTable_holds(self, state: State, objects: Sequence[Object]) -> bool:
block, = objects
z = state.get(block, "pose_z")
Expand All @@ -418,12 +450,15 @@ def _OnTable_holds(self, state: State, objects: Sequence[Object]) -> bool:
def _GripperOpen_holds(state: State, objects: Sequence[Object]) -> bool:
robot, = objects
rf = state.get(robot, "fingers")
assert rf in (0.0, 1.0)
return rf == 1.0
assert rf in (BlocksEnv.closed_fingers, BlocksEnv.open_fingers)
return rf == BlocksEnv.open_fingers

def _Holding_holds(self, state: State, objects: Sequence[Object]) -> bool:
block, = objects
return self._get_held_block(state) == block
held_block = self._get_held_block(state)
if held_block is None:
return False
return held_block == block

def _Clear_holds(self, state: State, objects: Sequence[Object]) -> bool:
if self._Holding_holds(state, objects):
Expand Down Expand Up @@ -510,7 +545,7 @@ def _load_task_from_json(self, json_file: Path) -> EnvironmentTask:
}
# Add the robot at a constant initial position.
rx, ry, rz = self.robot_init_x, self.robot_init_y, self.robot_init_z
rf = 1.0 # fingers start out open
rf = self.open_fingers # fingers start out open
state_dict[self._robot] = {
"pose_x": rx,
"pose_y": ry,
Expand Down
4 changes: 3 additions & 1 deletion predicators/envs/cluttered_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ def _check_collisions(cls,
colliding_can = None
colliding_can_max_dist = float("-inf")
for can in state:
if can == ignored_can or not cls._Untrashed_holds(state, [can]):
if ignored_can is not None and can == ignored_can or \
not cls._Untrashed_holds(
state, [can]):
continue
this_x = state.get(can, "pose_x")
this_y = state.get(can, "pose_y")
Expand Down
Loading
Loading