diff --git a/pixplot/pixplot.py b/pixplot/pixplot.py index d512b7cd..8c2103e4 100644 --- a/pixplot/pixplot.py +++ b/pixplot/pixplot.py @@ -580,13 +580,20 @@ def get_inception_vectors(**kwargs): vector_path = os.path.join(vector_dir, clean_filename(i.path) + '.npy') if os.path.exists(vector_path) and kwargs['use_cache']: vec = np.load(vector_path) + if len(vec) < 2: + vec = np.expand_dims(vec, 0) else: im = preprocess_input( img_to_array( i.original.resize((299,299)) ) ) vec = model.predict(np.expand_dims(im, 0)).squeeze() np.save(vector_path, vec) vecs.append(vec) progress_bar.update(1) - return np.array(vecs) + + if os.path.exists(vector_path) and kwargs['use_cache']: + vecs = np.stack(vecs, 0).squeeze() + return vecs + else: + return np.array(vecs) def get_umap_layout(**kwargs):