PyTorch Lightning

DVCLive allows you to add experiment tracking capabilities to your PyTorch Lightning projects.

Usage

Pass the DVCLiveLogger to your Trainer:

from dvclive.lightning import DVCLiveLogger

...
dvclive_logger = DVCLiveLogger()

trainer = Trainer(logger=dvclive_logger)
trainer.fit(model)

Each metric will be logged to:

{Live.plots_dir}/metrics/{split}/{iter_type}/{metric}.tsv

Where:

  • {Live.plots_dir} is defined in Live.
  • {split} can be either train or eval.
  • {iter_type} can be either epoch or step.
  • {metric} is the name provided by the framework.
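
For example, here is a minimal sketch (not taken from the DVCLive docs) of a LightningModule logging metrics with self.log(). The module, metric names, and the file paths in the comments are illustrative assumptions based on the pattern above and the default Live directory (dvclive):

import torch
from torch import nn
from pytorch_lightning import LightningModule  # or: from lightning.pytorch import LightningModule

class SketchModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        # With default settings this metric would land in files such as
        # dvclive/plots/metrics/train/step/loss.tsv and
        # dvclive/plots/metrics/train/epoch/loss.tsv (illustrative paths).
        self.log("loss", loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        val_loss = nn.functional.mse_loss(self.layer(x), y)
        # e.g. dvclive/plots/metrics/eval/epoch/val_loss.tsv (illustrative path)
        self.log("val_loss", val_loss, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)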

Parameters

  • run_name - (None by default) - Name of the run; PyTorch Lightning uses it as the logger's version.

  • prefix - (None by default) - string prepended to each metric name.

  • experiment - (None by default) - Live object to be used instead of initializing a new one.

  • **kwargs - Any additional arguments will be used to instantiate a new Live instance. If experiment is used, the arguments are ignored.
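
For example, a logger combining these parameters might look like the sketch below. The run name, prefix, and directory are arbitrary placeholder values, and dir is forwarded to the new Live instance through **kwargs:

from pytorch_lightning import Trainer  # or: from lightning.pytorch import Trainer
from dvclive.lightning import DVCLiveLogger

logger = DVCLiveLogger(
    run_name="my-run",     # reported to PyTorch Lightning as the version
    prefix="model_a",      # prepended to each metric name
    dir="my_logs_dir",     # extra kwarg, passed to the new Live instance
)
trainer = Trainer(logger=logger)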

Examples

  • Using experiment to pass an existing Live instance.
from dvclive import Live
from dvclive.lightning import DVCLiveLogger

with Live("custom_dir") as live:
    trainer = Trainer(
        logger=DVCLiveLogger(experiment=live))
    trainer.fit(model)
    # Log additional metrics after training
    live.log_metric("summary_metric", 1.0, plot=False)
  • Using **kwargs to customize Live.
from dvclive.lightning import DVCLiveLogger

trainer = Trainer(
    logger=DVCLiveLogger(dir='my_logs_dir'))
trainer.fit(model)
  • Using Live.log_artifact() to save the best checkpoint at the end of training.
with Live() as live:
    checkpoint = ModelCheckpoint(dirpath="mymodel")
    trainer = Trainer(
        logger=DVCLiveLogger(experiment=live),
        callbacks=checkpoint
    )
    trainer.fit(model)
    live.log_artifact(
        checkpoint.best_model_path,
        type="model",
        name="lightning-model"
    )
  • Using save_hyperparameters() so DVCLiveLogger logs the hyperparameters.
class LitModule(LightningModule):
    def __init__(self, layer_1_dim, learning_rate):
        super().__init__()
        # call this to save (layer_1_dim=128, learning_rate=1e-4)
        self.save_hyperparameters()

model = LitModule(layer_1_dim=128, learning_rate=1e-4)
trainer = Trainer(logger=DVCLiveLogger())
trainer.fit(model)

By default, PyTorch Lightning creates a directory to store checkpoints using the logger's name (DVCLiveLogger). You can change the checkpoint path or disable checkpointing entirely, as described in the PyTorch Lightning documentation.
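
For example, both options can be expressed with standard PyTorch Lightning arguments. The sketch below uses ModelCheckpoint and the enable_checkpointing flag (available in recent Lightning versions), with an arbitrary checkpoint directory:

from pytorch_lightning import Trainer  # or: from lightning.pytorch import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from dvclive.lightning import DVCLiveLogger

# Option 1: store checkpoints in an explicitly chosen directory.
trainer = Trainer(
    logger=DVCLiveLogger(),
    callbacks=[ModelCheckpoint(dirpath="checkpoints")],
)

# Option 2: disable checkpointing entirely.
trainer = Trainer(logger=DVCLiveLogger(), enable_checkpointing=False)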
