您好,当我下载了seed4数据集时,我按照remade的要求运行代码,但是出现了维度不匹配的问题,
Traceback (most recent call last):
File "f:\CNN\MS-MDA-mai_1\msmdaer.py", line 345, in <module>
data, label = utils.load_data(dataset_name)
File "f:\CNN\MS-MDA-mai_1\utils.py", line 320, in load_data
return np.array(data), np.array(label)
ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (3, 15) + inhomogeneous part.
这是出现的错误,具体是在这个函数
def load_data(dataset_name):
path, allmats = get_allmats_name(dataset_name)
data = [([0] * 15) for i in range(3)]
label = [([0] * 15) for i in range(3)]
for i in range(len(allmats)):
for j in range(len(allmats[0])):
mat_path = path + "/" + str(i + 1) + "/" + allmats[i][j]
one_data, one_label = get_data_label_frommat(mat_path, dataset_name, i)
print(one_data.shape, one_label.shape)
data[i][j] = one_data.copy()
label[i][j] = one_label.copy()
return np.array(data), np.array(label)
我把data和lable的形状打印出来后,三个session的输出如下,
(851, 310) (851, 1)
(832, 310) (832, 1)
(822, 310) (822, 1),
这导致数组无法转换为numpy
请问我该如何解决,可以直接截断多余的部分吗
您好,当我下载了seed4数据集时,我按照remade的要求运行代码,但是出现了维度不匹配的问题,
这是出现的错误,具体是在这个函数
我把data和lable的形状打印出来后,三个session的输出如下,
(851, 310) (851, 1)
(832, 310) (832, 1)
(822, 310) (822, 1),
这导致数组无法转换为numpy
请问我该如何解决,可以直接截断多余的部分吗