[BUG] Robust indexing & safe dtype handling in tsai’s TSDataLoaders / TfmdLists #947
Fixes sktime/sktime#7885 <[ENH] interface to tsai package>
What does this implement / fix?
While wiring tsai into sktime I hit two subtle indexing / dtype issues that can also bite vanilla tsai users:
1. NumPy scalar vs. Python int in TfmdLists + Subset
When self._splits is a NumPy array (dtype=int8/int32/…), the value returned by
idx = self._splits[it] is a NumPy scalar. Indexing self.items[idx] with that object
fails, breaking Learner.predict, the dev notebooks, and downstream libraries.
Patch: convert a NumPy scalar to a Python int via .item() and a tiny NumPy array to a list via .tolist() before the final lookup (sketched below).
There is zero behavioural change for plain Python ints/lists.
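A minimal sketch of the coercion, using a hypothetical helper _coerce_index to stand in for the inline logic in TfmdLists.__getitem__:

```python
import numpy as np

def _coerce_index(idx):
    """Normalise NumPy index types to plain Python ones before the item lookup.

    Hypothetical helper; the patch applies the same logic inline in __getitem__.
    """
    if isinstance(idx, np.generic):    # NumPy scalar (np.int8, np.int32, ...) -> Python int
        return idx.item()
    if isinstance(idx, np.ndarray):    # tiny index array -> plain Python list
        return idx.tolist()
    return idx                         # plain Python ints / lists pass through unchanged

splits = np.array([2, 0, 3], dtype=np.int8)
print(type(splits[1]))                 # <class 'numpy.int8'> -- a NumPy scalar, not int
print(type(_coerce_index(splits[1])))  # <class 'int'>
print(_coerce_index(splits[:2]))       # [2, 0]
```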
2. Unsafe dtype when casting NumPy arrays to tensors in TfmdLists.__init__
Inside the in-place branch (inplace=True, tfms=None) we call typ(tl.items).
If tl.items is an integer array, torch.as_tensor produces a LongTensor, which later collides with models that expect FloatTensors.
Patch: after the existing cast, explicitly re-cast any NumPy array to torch.float32 (sketched below).
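A minimal sketch of the dtype issue and the re-cast, assuming the cast target typ behaves like torch.as_tensor as described above:

```python
import numpy as np
import torch

items = np.array([[0, 1, 2], [3, 4, 5]], dtype=np.int64)  # integer time-series data
t = torch.as_tensor(items)                                 # NumPy dtype is preserved
print(t.dtype)                                             # torch.int64 (LongTensor)

# Re-cast applied after the existing typ(...) call: force float32 for NumPy inputs
if isinstance(items, np.ndarray):
    t = t.to(torch.float32)
print(t.dtype)                                             # torch.float32
```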
Dependency impact
None—pure Python changes, no new packages.
Focus for reviewers
Sanity-check the scalar/array coercion logic in __getitem__.
Confirm that forcing float32 won't interfere with edge cases where integer tensors are explicitly desired (I could not find any).
I’ve run the basic_motions smoke-tests plus a small custom dataset; everything trains and predicts fine on CPU and GPU.
Thanks for taking a look!