Open
Description
Import
Here we take `LSTMForecaster` as an example; the same changes will apply to `LSTMForecaster`, `TCNForecaster` and `Seq2SeqForecaster` in the first round.
# Legacy import path: which backend is used depends on the environment
# variable CHRONOS_FORECASTER_BACKEND (defaults to 'torch').
# This is kept to support existing code.
from bigdl.chronos.forecaster import LSTMForecaster

# Explicit-backend import paths: the backend is named in the import path.
# These are alternatives -- use one of them (each binds the same name, so in a
# single module the last import wins).
# "tf2" follows the naming already used by bigdl.orca.learn (bigdl.orca.learn.tf2).
from bigdl.chronos.forecaster.pytorch import LSTMForecaster
from bigdl.chronos.forecaster.tf2 import LSTMForecaster
Interface
I quickly reviewed all the APIs we have for PyTorch and did not find any that cannot be supported in TF2. With the help of `bigdl.nano.tf` and `bigdl.orca.learn.tf2`, we should be able to support exactly the same API as PyTorch.
class BaseTF2Forecaster(Forecaster):
    """Proposed base class for the TF2 backend of Chronos forecasters.

    The goal is to expose exactly the same public API as the PyTorch
    forecasters, built with the help of ``bigdl.nano.tf`` and
    ``bigdl.orca.learn.tf2``. Every method below is an interface stub in this
    proposal and has no implementation yet.
    """

    def __init__(self,
                 # NOTE: in the proposal, model-specific arguments (written as
                 # ``...``) come before the shared arguments below; concrete
                 # forecasters are expected to add them.
                 optimizer="Adam",
                 loss="mse",
                 lr=0.001,
                 metrics=None,
                 seed=None,
                 distributed=False,
                 workers_per_node=1,
                 distributed_backend="tf2"):
        """Shared training configuration; no difference from PyTorch.

        :param metrics: list of metric names. ``None`` (the default) stands for
            ``["mse"]``, which matches the PyTorch default. A ``None`` sentinel
            is used instead of a list literal so the default list is not shared
            between calls. TODO: resolve it in the implementation.
        :param distributed_backend: will support ``"horovod"``, ``"tf2"`` or
            ``"spark"``.
        """

    def fit(self, data, epochs=1, batch_size=32):
        """Fit the forecaster on ``data``; no difference from PyTorch.

        ``data`` will be supported as:

        1. a tuple of numpy ndarrays ``(x, y)``;
        2. an XShards item;
        3. a ``tf.data.Dataset``.
        """

    def predict(self, data, batch_size=32, quantize=False):
        """Counterpart of the PyTorch forecaster's ``predict`` (stub)."""

    def predict_with_onnx(self, data, batch_size=32):
        """Counterpart of the PyTorch forecaster's ``predict_with_onnx`` (stub)."""

    def evaluate(self, data, batch_size=32, multioutput="raw_values", quantize=False):
        """Counterpart of the PyTorch forecaster's ``evaluate`` (stub)."""

    def evaluate_with_onnx(self, data, batch_size=32, multioutput="raw_values"):
        """Counterpart of the PyTorch forecaster's ``evaluate_with_onnx`` (stub)."""

    def save(self, checkpoint_file, quantize_checkpoint_file=None):
        """Counterpart of the PyTorch forecaster's ``save`` (stub)."""

    def load(self, checkpoint_file, quantize_checkpoint_file=None):
        """Counterpart of the PyTorch forecaster's ``load`` (stub)."""

    def to_local(self):
        """Counterpart of the PyTorch forecaster's ``to_local`` (stub)."""

    def get_model(self):
        """Counterpart of the PyTorch forecaster's ``get_model`` (stub)."""

    def build_onnx(self, thread_num=None, sess_options=None):
        """Counterpart of the PyTorch forecaster's ``build_onnx`` (stub)."""

    def export_onnx_file(self, dirname="model.onnx"):
        """Counterpart of the PyTorch forecaster's ``export_onnx_file`` (stub)."""

    def quantize(self, calib_data,
                 val_data=None,
                 metric=None,
                 conf=None,
                 framework='tensorflow',
                 approach='static',
                 tuning_strategy='bayesian',
                 relative_drop=None,
                 absolute_drop=None,
                 timeout=0,
                 max_trials=1):
        """Counterpart of the PyTorch forecaster's ``quantize`` (stub).

        ``framework`` defaults to ``'tensorflow'`` for this backend.
        """