Chronos: Support TF2 in Chronos Forecaster #4023

Open
@TheaperDeng

Description

Import

Here we take LSTMForecaster as an example; in the first round the same changes will apply to LSTMForecaster, TCNForecaster, and Seq2SeqForecaster.

# Which backend this import uses depends on the environment variable
# CHRONOS_FORECASTER_BACKEND (defaults to 'torch').
# This is to support our legacy code.
from bigdl.chronos.forecaster import LSTMForecaster

# These two imports use the backend named in the import path.
# 'tf2' is the same name used in orca.learn.
from bigdl.chronos.forecaster.pytorch import LSTMForecaster
from bigdl.chronos.forecaster.tf2 import LSTMForecaster
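
As a minimal sketch of how the backend dispatch could work (the dispatch logic below is an assumption for illustration, not a committed implementation), the package `__init__` could read the environment variable at import time:

# Hypothetical sketch of bigdl/chronos/forecaster/__init__.py: select the
# backend from CHRONOS_FORECASTER_BACKEND, defaulting to 'torch' so that
# legacy imports keep working.
import os

_backend = os.environ.get("CHRONOS_FORECASTER_BACKEND", "torch")

if _backend == "torch":
    from bigdl.chronos.forecaster.pytorch import LSTMForecaster
elif _backend == "tf2":
    from bigdl.chronos.forecaster.tf2 import LSTMForecaster
else:
    raise ValueError(f"Unsupported CHRONOS_FORECASTER_BACKEND: {_backend!r}")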

Interface

I did a quick review of all the APIs we have for PyTorch and found none that cannot be supported in TF2. With the help of bigdl.nano.tf and bigdl.orca.learn.tf2, we should be able to support exactly the same API as PyTorch.

class BaseTF2Forecaster(Forecaster):
    def __init__(..., optimizer="Adam",
                 loss="mse",
                 lr=0.001,
                 metrics=["mse"],
                 seed=None,
                 distributed=False,
                 workers_per_node=1,
                 distributed_backend="tf2"):
        """
        No difference from the PyTorch version.
        distributed_backend will support "horovod", "tf2" or "spark".
        """
    def fit(self, data, epochs=1, batch_size=32):
        """
        No difference from the PyTorch version.
        For data we will support:
        1. a numpy ndarray tuple (x, y)
        2. an XShards item
        3. a tf.data.Dataset
        """
    def predict(self, data, batch_size=32, quantize=False):
        pass
    def predict_with_onnx(self, data, batch_size=32):
        pass
    def evaluate(self, data, batch_size=32, multioutput="raw_values", quantize=False):
        pass
    def evaluate_with_onnx(self, data, batch_size=32, multioutput="raw_values"):
        pass
    def save(self, checkpoint_file, quantize_checkpoint_file=None):
        pass
    def load(self, checkpoint_file, quantize_checkpoint_file=None):
        pass
    def to_local(self):
        pass
    def get_model(self):
        pass
    def build_onnx(self, thread_num=None, sess_options=None):
        pass
    def export_onnx_file(self, dirname="model.onnx"):
        pass
    def quantize(self, calib_data,
                 val_data=None,
                 metric=None,
                 conf=None,
                 framework='tensorflow',
                 approach='static',
                 tuning_strategy='bayesian',
                 relative_drop=None,
                 absolute_drop=None,
                 timeout=0,
                 max_trials=1):
        pass
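
To make the proposed interface concrete, below is a hedged usage sketch. The constructor arguments (past_seq_len, input_feature_num, output_feature_num) mirror the existing PyTorch LSTMForecaster and are assumptions for the TF2 version, as is the toy data.

# Hedged usage sketch of the proposed TF2 forecaster API; constructor
# arguments are assumed to mirror the PyTorch LSTMForecaster.
import numpy as np
import tensorflow as tf
from bigdl.chronos.forecaster.tf2 import LSTMForecaster

# Toy data: 100 samples, lookback of 24 steps, 2 input features, 1 output feature.
x = np.random.randn(100, 24, 2).astype(np.float32)
y = np.random.randn(100, 1, 1).astype(np.float32)

forecaster = LSTMForecaster(past_seq_len=24,
                            input_feature_num=2,
                            output_feature_num=1,
                            optimizer="Adam",
                            loss="mse",
                            lr=0.001)

# fit accepts a numpy ndarray tuple (x, y) ...
forecaster.fit((x, y), epochs=2, batch_size=32)

# ... or a tf.data.Dataset (an XShards item would follow the same pattern).
ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(32)
forecaster.fit(ds, epochs=2)

pred = forecaster.predict(x, batch_size=32)
metrics = forecaster.evaluate((x, y), multioutput="raw_values")

# ONNX-accelerated inference keeps the same API as the PyTorch forecasters.
forecaster.build_onnx(thread_num=1)
pred_onnx = forecaster.predict_with_onnx(x, batch_size=32)

Keeping fit/predict/evaluate identical to the PyTorch forecasters means existing tutorials and user code should only need the import line changed.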
