
PyTorch Lightning

DVCLive allows you to add experiment tracking capabilities to your PyTorch Lightning projects.

If you are using Lightning Fabric, check the DVCLive - Lightning Fabric page.

Usage

If you pass DVCLiveLogger to your Trainer, DVCLive will automatically log the metrics (recorded with self.log) and the hyperparameters (saved with self.save_hyperparameters) of your LightningModule.

import lightning.pytorch as pl
from dvclive.lightning import DVCLiveLogger

...
class LitModule(pl.LightningModule):
    def __init__(self, layer_1_dim=128, learning_rate=1e-2):
        super().__init__()
        # layer_1_dim and learning_rate will be logged by DVCLive
        self.save_hyperparameters()

    def training_step(self, batch, batch_idx):
        metric = ...
        # See Output format below
        self.log("train_metric", metric, on_step=False, on_epoch=True)

dvclive_logger = DVCLiveLogger()

model = LitModule()
trainer = pl.Trainer(logger=dvclive_logger)
trainer.fit(model)

By default, PyTorch Lightning creates a directory to store checkpoints named after the logger (DVCLiveLogger). You can change the checkpoint path or disable checkpointing entirely, as described in the PyTorch Lightning documentation.

Parameters

  • run_name - (None by default) - Name of the run, used by PyTorch Lightning to determine the version.

  • prefix - (None by default) - String added as a prefix to each metric name.

  • log_model - (False by default) - Use live.log_artifact() to log checkpoints created by ModelCheckpoint. See Log model checkpoints.

    • if log_model == False (default), no checkpoint is logged.

    • if log_model == True, checkpoints are logged at the end of training, except when save_top_k == -1, in which case every checkpoint is logged during training.

    • if log_model == 'all', checkpoints are logged during training.

  • experiment - (None by default) - Live object to be used instead of initializing a new one.

  • **kwargs - Any additional arguments will be used to instantiate a new Live instance. If experiment is passed, these arguments are ignored.
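The three log_model cases listed above can be summarized as a small decision function. This is a hypothetical helper written only to illustrate the documented rules: checkpoint_logging_moment is not part of DVCLive, and its return labels are made up for this sketch.

```python
from typing import Literal, Union

# log_model accepts False, True or the string "all" (see the list above).
LogModel = Union[bool, Literal["all"]]


def checkpoint_logging_moment(log_model: LogModel, save_top_k: int) -> str:
    """Hypothetical helper (not part of DVCLive) restating the documented rules.

    Returns "never", "end_of_training" or "during_training".
    """
    if log_model is False:
        return "never"
    if log_model == "all":
        return "during_training"
    if log_model is True:
        # save_top_k == -1 keeps every checkpoint, so each one is logged as it is saved.
        return "during_training" if save_top_k == -1 else "end_of_training"
    raise ValueError(f"unsupported log_model value: {log_model!r}")


print(checkpoint_logging_moment(True, save_top_k=1))   # end_of_training
print(checkpoint_logging_moment(True, save_top_k=-1))  # during_training
```

save_top_k=1 in the first call matches ModelCheckpoint's default, so log_model=True on its own defers logging until training ends.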

Examples

Log model checkpoints

Use log_model to save the checkpoints (DVCLive uses Live.log_artifact() internally to save them). At the end of training, DVCLive will copy the best_model_path to the dvclive/artifacts directory and annotate it with the name best (for example, to be consumed in the DVC Studio model registry or in automation scenarios).

  • Save updates to the checkpoints directory at the end of training:
from lightning.pytorch import Trainer
from dvclive.lightning import DVCLiveLogger

logger = DVCLiveLogger(log_model=True)
trainer = Trainer(logger=logger)
trainer.fit(model)
  • Save updates to the checkpoints directory whenever a new checkpoint is saved:
from lightning.pytorch import Trainer
from dvclive.lightning import DVCLiveLogger

logger = DVCLiveLogger(log_model="all")
trainer = Trainer(logger=logger)
trainer.fit(model)
  • Use a custom ModelCheckpoint:
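from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from dvclive.lightning import DVCLiveLogger

logger = DVCLiveLogger(log_model=True)
# Keep the checkpoint with the highest validation accuracy in ./model
checkpoint_callback = ModelCheckpoint(
    dirpath="model",
    monitor="val_acc",
    mode="max",
)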
trainer = Trainer(logger=logger, callbacks=[checkpoint_callback])
trainer.fit(model)
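In the custom ModelCheckpoint example, monitor="val_acc" and mode="max" decide which file ends up as best_model_path, the checkpoint DVCLive copies to dvclive/artifacts at the end of training. The sketch below illustrates that selection rule with a plain dictionary of monitored values. best_checkpoint and the checkpoint filenames are hypothetical, and this is a simplification rather than ModelCheckpoint's actual logic.

```python
from typing import Dict


def best_checkpoint(scores: Dict[str, float], mode: str = "min") -> str:
    """Hypothetical helper: pick the checkpoint path with the best monitored value.

    `scores` maps checkpoint paths to the monitored metric (e.g. val_acc).
    mode="max" keeps the highest value, mode="min" the lowest.
    """
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    if not scores:
        raise ValueError("no checkpoints to choose from")
    chooser = max if mode == "max" else min
    return chooser(scores, key=scores.__getitem__)


# Made-up filenames and val_acc values, for illustration only.
val_acc = {"model/a.ckpt": 0.81, "model/b.ckpt": 0.88, "model/c.ckpt": 0.85}
print(best_checkpoint(val_acc, mode="max"))  # model/b.ckpt
```

With mode="min" (ModelCheckpoint's default), the same scores would select the lowest value instead, which suits monitored losses rather than accuracies.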

Passing additional DVCLive arguments

  • Using experiment to pass an existing Live instance.
from lightning.pytorch import Trainer
from dvclive import Live
from dvclive.lightning import DVCLiveLogger

with Live("custom_dir") as live:
    trainer = Trainer(
        logger=DVCLiveLogger(experiment=live))
    trainer.fit(model)
    # Log additional metrics after training
    live.log_metric("summary_metric", 1.0, plot=False)
  • Using **kwargs to customize Live.
from lightning.pytorch import Trainer
from dvclive.lightning import DVCLiveLogger

trainer = Trainer(
    logger=DVCLiveLogger(dir='my_logs_dir'))
trainer.fit(model)
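The precedence between experiment and **kwargs described in the parameters list can be illustrated with a stand-in class, so the sketch runs without dvclive installed. StubLive and resolve_experiment are hypothetical names used only here; the sketch mirrors the documented rule that an existing Live instance is used as-is and any extra arguments are ignored.

```python
from typing import Any, Dict, Optional


class StubLive:
    """Hypothetical stand-in for dvclive.Live that only records its arguments."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs: Dict[str, Any] = kwargs


def resolve_experiment(experiment: Optional[StubLive] = None, **kwargs: Any) -> StubLive:
    """Reuse `experiment` when given (ignoring kwargs), otherwise build a new instance."""
    if experiment is not None:
        return experiment
    return StubLive(**kwargs)


existing = StubLive(dir="custom_dir")
# dir="my_logs_dir" is ignored because an instance was passed.
print(resolve_experiment(existing, dir="my_logs_dir").kwargs)  # {'dir': 'custom_dir'}
print(resolve_experiment(dir="my_logs_dir").kwargs)            # {'dir': 'my_logs_dir'}
```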

Output format

Each metric will be logged to:

{Live.plots_dir}/metrics/{split_prefix}/{iter_type}/{metric_name}.tsv

Where:

  • {Live.plots_dir} is defined in Live.
  • {iter_type} can be either epoch or step. This is inferred from the on_step and on_epoch arguments used in the log call.
  • {split_prefix}_{metric_name} is the full string passed to the log call. split_prefix can be either train, val or test.

In the example above, the metric logged as:

self.log("train_metric", metric, on_step=False, on_epoch=True)

will be stored in:

dvclive/plots/metrics/train/epoch/metric.tsv
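The mapping from a self.log call to a .tsv file can be sketched as a small path builder. metric_tsv_path is a hypothetical helper that follows the {Live.plots_dir}/metrics/{split_prefix}/{iter_type}/{metric_name}.tsv layout. It assumes exactly one of on_step and on_epoch is True, that the logged name starts with train, val or test followed by an underscore, and that Live.plots_dir is dvclive/plots (assumed here to be the value for Live(dir="dvclive")). Other flag combinations are rejected rather than guessed at.

```python
import posixpath

# Split prefixes recognized by this sketch (train, val or test).
SPLITS = ("train", "val", "test")


def metric_tsv_path(plots_dir: str, logged_name: str, on_step: bool, on_epoch: bool) -> str:
    """Hypothetical helper (not part of DVCLive).

    Returns {plots_dir}/metrics/{split_prefix}/{iter_type}/{metric_name}.tsv.
    """
    # Only the unambiguous case of exactly one flag set is handled.
    if on_step == on_epoch:
        raise ValueError("expected exactly one of on_step/on_epoch to be True")
    # "train_metric" -> split_prefix "train", metric_name "metric"
    split_prefix, sep, metric_name = logged_name.partition("_")
    if not sep or not metric_name or split_prefix not in SPLITS:
        raise ValueError(f"{logged_name!r} should look like '<split>_<metric>' with split in {SPLITS}")
    iter_type = "step" if on_step else "epoch"
    return posixpath.join(plots_dir, "metrics", split_prefix, iter_type, f"{metric_name}.tsv")


# Assumption: plots_dir is "dvclive/plots".
print(metric_tsv_path("dvclive/plots", "train_metric", on_step=False, on_epoch=True))
# dvclive/plots/metrics/train/epoch/metric.tsv
```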