
axis permutation wrong in utility.pack_into_tensor and utility.unpack_into_tensorarray #8

@JeffOwOSun

Description

def pack_into_tensor(array, axis):
    """
    packs a given TensorArray into a tensor along a given axis
    Parameters:
    ----------
    array: TensorArray
        the tensor array to pack
    axis: int
        the axis to pack the array along
    Returns: Tensor
        the packed tensor
    """

    packed_tensor = array.pack()
    shape = packed_tensor.get_shape()
    rank = len(shape)

    dim_permutation = [axis] + range(1, axis) + [0]  + range(axis + 1, rank)
    correct_shape_tensor = tf.transpose(packed_tensor, dim_permutation)

    return correct_shape_tensor

Imagine I want to stack a TensorArray with 6 elements of shape [3,4,5] into a tensor of shape [3,4,5,6], so axis is 3. After array.pack(), the tensor has shape [6,3,4,5], and dim_permutation is [3] + [1,2] + [0] + [] = [3,1,2,0]. After the transpose, the output shape is [5,3,4,6] instead of [3,4,5,6], which is wrong.
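The arithmetic above can be reproduced without TensorFlow by applying the permutation to a shape list, following the tf.transpose rule that output dimension i takes input dimension perm[i]. This is a minimal sketch; transposed_shape and buggy_pack_perm are hypothetical helpers, and range is wrapped in list() because the original formula relies on Python 2, where range returns a list.

```python
def transposed_shape(shape, perm):
    # Same rule as tf.transpose: output dim i is input dim perm[i].
    return [shape[p] for p in perm]


def buggy_pack_perm(axis, rank):
    # Formula currently in pack_into_tensor, with list() added for Python 3.
    return [axis] + list(range(1, axis)) + [0] + list(range(axis + 1, rank))


packed_shape = [6, 3, 4, 5]  # shape after array.pack(): 6 elements of [3, 4, 5]
perm = buggy_pack_perm(axis=3, rank=len(packed_shape))
print(perm)                                  # [3, 1, 2, 0]
print(transposed_shape(packed_shape, perm))  # [5, 3, 4, 6], not [3, 4, 5, 6]
```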

The correct formula should be

dim_permutation = range(1, axis+1) + [0] + range(axis+1, rank)

With this formula, dim_permutation is [1,2,3] + [0] + [] = [1,2,3,0], which transposes [6,3,4,5] into the expected [3,4,5,6].
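One way to sanity-check the proposed formula is to compare it, for every axis, against an independent construction of "move dimension 0 to position axis". The reference below mirrors how NumPy's np.moveaxis builds its axis order for a single axis; the helper names are hypothetical and range is wrapped in list() for Python 3.

```python
def pack_perm(axis, rank):
    # Proposed fix: dims 1..axis shift left by one, the stacked dim 0 lands at `axis`.
    return list(range(1, axis + 1)) + [0] + list(range(axis + 1, rank))


def moveaxis_perm(src, dst, rank):
    # Independent reference: drop `src`, then reinsert it at position `dst`.
    order = [n for n in range(rank) if n != src]
    order.insert(dst, src)
    return order


def transposed_shape(shape, perm):
    # Same rule as tf.transpose: output dim i is input dim perm[i].
    return [shape[p] for p in perm]


print(transposed_shape([6, 3, 4, 5], pack_perm(3, 4)))  # [3, 4, 5, 6]

# The fixed formula agrees with the reference for every valid axis of a rank-4 tensor.
for axis in range(4):
    assert pack_perm(axis, 4) == moveaxis_perm(0, axis, 4)
```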

Similarly, the formula in unpack_into_tensorarray is also wrong. It should move axis to the front and keep the other dimensions in their original order, so the correct code should be

dim_permutation = [axis] + range(0, axis) + range(axis+1, rank)
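Because unpacking should undo packing, the proposed unpack permutation ought to be the inverse of the proposed pack permutation. The sketch below checks that under this assumption: transposing by p and then by u maps output dim j to input dim p[u[j]], so the round trip is the identity exactly when p[u[j]] == j for all j. Helper names are hypothetical, and range is wrapped in list() for Python 3.

```python
def pack_perm(axis, rank):
    # Proposed fix for pack_into_tensor.
    return list(range(1, axis + 1)) + [0] + list(range(axis + 1, rank))


def unpack_perm(axis, rank):
    # Proposed fix for unpack_into_tensorarray: bring `axis` to the front.
    return [axis] + list(range(0, axis)) + list(range(axis + 1, rank))


def transposed_shape(shape, perm):
    # Same rule as tf.transpose: output dim i is input dim perm[i].
    return [shape[p] for p in perm]


print(transposed_shape([3, 4, 5, 6], unpack_perm(3, 4)))  # [6, 3, 4, 5]

# Packing then unpacking should restore the original layout for every rank and axis.
for rank in range(1, 6):
    for axis in range(rank):
        p, u = pack_perm(axis, rank), unpack_perm(axis, rank)
        assert [p[i] for i in u] == list(range(rank))
```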

Please take a look. Thanks

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests