Skip to content

Commit 9e80dfd

Browse files
committed
Speeds up the data frame interpolation function.
The gait/interpolate function is only used in the GaitData.split_at() method, and the way I originally coded it using Pandas paradigms was very slow. This change speeds things up by about 40X. I had to change the test because the function no longer supports NaNs in the DataFrame. Since it is only used in the split_at() method, this seemed OK. So far I've never had NaNs in the data frame by the time it gets to the split_at() method. GaitData isn't built to handle NaNs; they should be fixed before creating a GaitData object.
1 parent f6831f5 commit 9e80dfd

File tree

2 files changed

+16
-22
lines changed

2 files changed

+16
-22
lines changed

gaitanalysis/gait.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# external libraries
99
import numpy as np
1010
from scipy.integrate import simps
11+
from scipy.interpolate import interp1d
1112
from scipy.signal import firwin, filtfilt
1213
import matplotlib.pyplot as plt
1314
import pandas
@@ -80,7 +81,7 @@ def find_constant_speed(time, speed, plot=False, filter_cutoff=1.0):
8081

8182

8283
def interpolate(data_frame, time):
83-
"""Returns a data frame with a index based on the provided time
84+
"""Returns a new data frame with a index based on the provided time
8485
array and linear interpolation.
8586
8687
Parameters
@@ -100,21 +101,14 @@ def interpolate(data_frame, time):
100101
101102
"""
102103

103-
total_index = np.sort(np.hstack((data_frame.index.values, time)))
104-
reindexed_data_frame = data_frame.reindex(total_index)
105-
interpolated_data_frame = \
106-
reindexed_data_frame.apply(pandas.Series.interpolate,
107-
method='values').loc[time]
108-
109-
# If the first or last value of a series is NA then the interpolate
110-
# function leaves it as an NA value, so use backfill to take care of
111-
# those.
112-
interpolated_data_frame = \
113-
interpolated_data_frame.fillna(method='backfill')
114-
# Because the time vector may have matching indices as the original
115-
# index (i.e. always the zero indice), drop any duplicates so the len()
116-
# stays consistent
117-
return interpolated_data_frame.drop_duplicates()
104+
column_names = data_frame.columns
105+
old_time = data_frame.index.values
106+
vals = data_frame.values
107+
108+
f = interp1d(old_time, vals, axis=0)
109+
new_vals = f(time)
110+
111+
return pandas.DataFrame(new_vals, index=time, columns=column_names)
118112

119113

120114
class GaitData(object):

gaitanalysis/tests/test_gait.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ def test_find_constant_speed():
3434

3535
def test_interpolate():
3636

37-
df = pandas.DataFrame({'a': [np.nan, 3.0, 5.0, 7.0],
38-
'b': [5.0, np.nan, 9.0, 11.0],
37+
df = pandas.DataFrame({'a': [2.0, 3.0, 5.0, 7.0],
38+
'b': [5.0, 8.0, 9.0, 11.0],
3939
'c': [2.0, 4.0, 6.0, 8.0],
40-
'd': [0.5, 1.0, 1.5, np.nan]},
40+
'd': [0.5, 1.0, 1.5, 2.5]},
4141
index=[0.0, 2.0, 4.0, 6.0])
4242

4343
time = [0.0, 1.0, 3.0, 5.0]
@@ -47,10 +47,10 @@ def test_interpolate():
4747
# NOTE : pandas.Series.interpolate does not extrapolate (because
4848
# np.interp doesn't.
4949

50-
df_expected = pandas.DataFrame({'a': [4.0, 4.0, 4.0, 6.0],
51-
'b': [5.0, 6.0, 8.0, 10.0],
50+
df_expected = pandas.DataFrame({'a': [2.0, 2.5, 4.0, 6.0],
51+
'b': [5.0, 6.5, 8.5, 10.0],
5252
'c': [2.0, 3.0, 5.0, 7.0],
53-
'd': [0.5, 0.75, 1.25, 1.5]},
53+
'd': [0.5, 0.75, 1.25, 2.0]},
5454
index=time)
5555

5656
testing.assert_allclose(interpolated.values, df_expected.values)

0 commit comments

Comments
 (0)