PyTorch Lightning
DVCLive allows you to add experiment tracking capabilities to your PyTorch Lightning projects.
Usage
Pass the DVCLiveLogger to your Trainer:
from dvclive.lightning import DVCLiveLogger
...
dvclive_logger = DVCLiveLogger()
trainer = Trainer(logger=dvclive_logger)
trainer.fit(model)
Each metric will be logged to:
{Live.plots_dir}/metrics/{split}/{iter_type}/{metric}.tsv
Where:
- {Live.plots_dir} is defined in Live.
- {split} can be either train or eval.
- {iter_type} can be either epoch or step.
- {metric} is the name provided by the framework.
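The path template above can be sketched as a small helper. This is only an illustration of the layout, not part of DVCLive: the function name metric_tsv_path is hypothetical, and the value "dvclive/plots" passed for Live.plots_dir is an assumed example value.

```python
from pathlib import PurePosixPath

# Values the template allows for {split} and {iter_type}, per the list above.
VALID_SPLITS = {"train", "eval"}
VALID_ITER_TYPES = {"epoch", "step"}


def metric_tsv_path(plots_dir, split, iter_type, metric):
    """Build {plots_dir}/metrics/{split}/{iter_type}/{metric}.tsv (hypothetical helper)."""
    if split not in VALID_SPLITS:
        raise ValueError(f"split must be one of {sorted(VALID_SPLITS)}")
    if iter_type not in VALID_ITER_TYPES:
        raise ValueError(f"iter_type must be one of {sorted(VALID_ITER_TYPES)}")
    return PurePosixPath(plots_dir) / "metrics" / split / iter_type / f"{metric}.tsv"


# "dvclive/plots" is an assumed value for Live.plots_dir, used for illustration only.
print(metric_tsv_path("dvclive/plots", "train", "epoch", "loss"))
# prints: dvclive/plots/metrics/train/epoch/loss.tsv
```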
Parameters
- run_name - (None by default) - Name of the run, used in PyTorch Lightning to get version.
- prefix - (None by default) - String that is added to each metric name.
- experiment - (None by default) - Live object to be used instead of initializing a new one.
- **kwargs - Any additional arguments will be used to instantiate a new Live instance. If experiment is used, the arguments are ignored.
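The interplay between experiment and **kwargs described above can be illustrated with a toy model. This is a sketch of the documented behavior only, not DVCLive's actual code: StubLive and resolve_live are hypothetical stand-ins, and no DVCLive import is involved.

```python
class StubLive:
    """Hypothetical stand-in for Live; it only records the arguments it was given."""

    def __init__(self, dir=None, **kwargs):
        self.dir = dir
        self.kwargs = kwargs


def resolve_live(experiment=None, **kwargs):
    """Return `experiment` if one is passed; otherwise build a new StubLive from kwargs."""
    if experiment is not None:
        # An existing instance wins, so the extra arguments are ignored.
        return experiment
    return StubLive(**kwargs)


existing = StubLive("custom_dir")

# With `experiment`, the same object comes back and `dir` has no effect.
reused = resolve_live(experiment=existing, dir="ignored_dir")

# Without `experiment`, the keyword arguments configure a new instance.
created = resolve_live(dir="my_logs_dir")
```

In this sketch, reused is the same object as existing and keeps dir "custom_dir", while created is a fresh object with dir "my_logs_dir".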
Examples
- Using experiment to pass an existing Live instance.
from dvclive import Live
from dvclive.lightning import DVCLiveLogger

with Live("custom_dir") as live:
    trainer = Trainer(
        logger=DVCLiveLogger(experiment=live))
    trainer.fit(model)
    # Log additional metrics after training
    live.log_metric("summary_metric", 1.0, plot=False)
- Using **kwargs to customize Live.
from dvclive.lightning import DVCLiveLogger
trainer = Trainer(
    logger=DVCLiveLogger(dir='my_logs_dir'))
trainer.fit(model)
- Using live.log_artifact() to save the best checkpoint.
from dvclive import Live
from dvclive.lightning import DVCLiveLogger

with Live() as live:
    checkpoint = ModelCheckpoint(dirpath="mymodel")
    trainer = Trainer(
        logger=DVCLiveLogger(experiment=live),
        callbacks=checkpoint
    )
    trainer.fit(model)
    # Track the best checkpoint as a model artifact
    live.log_artifact(
        checkpoint.best_model_path,
        type="model",
        name="lightning-model"
    )
- Logging hyperparameters.
class LitModule(LightningModule):
    def __init__(self, layer_1_dim, learning_rate):
        super().__init__()
        # call this to save (layer_1_dim=128, learning_rate=1e-4)
        self.save_hyperparameters()

model = LitModule(layer_1_dim=128, learning_rate=1e-4)
trainer = Trainer(logger=DVCLiveLogger())
trainer.fit(model)
By default, PyTorch Lightning creates a directory to store checkpoints using the
logger's name (DVCLiveLogger). You can change the checkpoint path or disable
checkpointing altogether, as described in the PyTorch Lightning documentation.