Nano: quantization API with onnx renew design #3806

Open

@TheaperDeng

Description

Current status (without ONNX quantization)

For PyTorch Lightning users:

pl_model = Trainer.compile(model, loss, optim)  # skip if you have a pl model
trainer.fit(pl_model, dataloader)
pl_model_quantized = trainer.quantize(pl_model, dataloader)
pl_model_quantized(x)  # quantized inference

For PyTorch users (not supported yet, but it could be supported with little effort):

# >>>>>>>> start of pytorch training loop >>>>>>>>>>>
# ...
# <<<<<<<< end  of pytorch training loop <<<<<<<<<<<<
model_quantized = trainer.quantize(model, dataloader)
model_quantized(x)  # quantized inference

Issues

  1. pl_model_quantized and model_quantized are torch.fx.graph_module.GraphModule, a type most users are unfamiliar with (see the sketch after this list). Users cannot:
  • use onnx.export to trace the quantized model;

  • save or load the model in the normal way;

  • continue training on this quantized model.

  2. It is extremely hard to integrate ONNX quantization with INC into this API.
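
As a minimal illustration of the first issue, here is a sketch using plain PyTorch FX graph-mode quantization with a toy model (the model and names are placeholders, not Nano internals): the object coming out of quantization is a GraphModule rather than the user's original module class.

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# toy model standing in for the user's model
model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU()).eval()
example_input = torch.randn(1, 8)

qconfig_mapping = get_default_qconfig_mapping("fbgemm")
prepared = prepare_fx(model, qconfig_mapping, example_inputs=(example_input,))
prepared(example_input)                    # calibration pass
quantized = convert_fx(prepared)

print(isinstance(quantized, torch.fx.GraphModule))   # True
# Its state_dict keys no longer match the original module, so the usual
# torch.save(model.state_dict()) / load_state_dict round trip breaks, and
# further training on the converted int8 graph is not supported.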

Revised API usage (with ONNX quantization)

For PyTorch Lightning users:

pl_model = Trainer.compile(model, loss, optim, onnx=True/False)  # skip if you have a pl model
trainer.fit(pl_model, dataloader)
pl_model = trainer.quantize(pl_model, dataloader, onnx=True/False)

For PyTorch users:

model = Trainer.compile(model, onnx=True/False)
# >>>>>>>> start of pytorch training loop >>>>>>>>>>>
# ...
# <<<<<<<< end  of pytorch training loop <<<<<<<<<<<<
model = trainer.quantize(model, dataloader, onnx=True/False)

pl_model and model are still PyTorch Lightning models, so prediction can be done as follows:

# predict with pytorch fp32
pl_model.eval()
with torch.no_grad():
    pl_model(x)
# or
pl_model.inference(x, backend=None)

# predict with pytorch int8
pl_model.eval(quantize=True)
with torch.no_grad():
    pl_model(x)
# or
pl_model.inference(x, backend=None, quantize=True)

# predict with onnx fp32
pl_model.eval_onnx()
with torch.no_grad():
    pl_model(x)
# or
pl_model.inference(x, backend="onnx")

# predict with onnx int8
pl_model.eval_onnx(quantize=True)
with torch.no_grad():
    pl_model(x)
# or
pl_model.inference(x, backend="onnx", quantize=True)

We should also provide an option to return the "raw" result to users: ONNX quantization will return an ONNX model (onnx.onnx_ml_pb2.ModelProto) and PyTorch FX quantization will return an FX model (torch.fx.GraphModule).

model = trainer.quantize(..., raw_return=True)  # raw_return defaults to False
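
For instance, with onnx=True the raw result would be an onnx.onnx_ml_pb2.ModelProto, which the user could then persist or run directly. A sketch reusing the pl_model, dataloader and x names from above; the exact keyword combination follows the proposal here, not a finalized API:

import onnx
import onnxruntime as ort

onnx_model = trainer.quantize(pl_model, dataloader, onnx=True, raw_return=True)
onnx.save(onnx_model, "quantized_model.onnx")            # persist the raw ModelProto
session = ort.InferenceSession("quantized_model.onnx")   # run it outside of Nano
input_name = session.get_inputs()[0].name
outputs = session.run(None, {input_name: x.numpy()})     # x assumed to be a torch tensor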

Three PRs will be raised separately for this issue:

  • change Trainer.quantize(..., raw_return=False) to return a pl model and support the easy inference API (.eval(quantize=True)): Nano quantize inference API for pytorch #3866
  • add support for ONNX quantization
  • add support for save and load (a hypothetical sketch follows this list)
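
The save/load piece is not designed yet; one possible shape, with purely hypothetical method names, could mirror the quantize/inference API above:

# hypothetical API sketch -- method names and arguments are placeholders
trainer.save(pl_model, "./quantized_ckpt", quantize=True)   # persist int8 artifacts
pl_model = trainer.load("./quantized_ckpt")                 # restore the pl model
pl_model.inference(x, backend="onnx", quantize=True)        # same inference API as above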
