Torch ORT Callback
Torch ORT converts your model into an optimized ONNX graph, speeding up training and inference when using NVIDIA or AMD GPUs. See the installation instructions here.
This is primarily useful when training Transformer models. The ORT callback works when a single model is assigned to self.model within the LightningModule, as shown below.
Note
Not all Transformer models are supported. See this table for the supported models, along with branches containing fixes for certain models.
from pytorch_lightning import LightningModule, Trainer
from transformers import AutoModel

from pl_bolts.callbacks import ORTCallback


class MyTransformerModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = AutoModel.from_pretrained('bert-base-cased')

    ...


model = MyTransformerModel()
trainer = Trainer(gpus=1, callbacks=ORTCallback())
trainer.fit(model)
For even easier setup and integration, have a look at our Lightning Flash integration for Text Classification, Translation and Summarization.