PyTorch
DVCLive allows you to add experiment tracking capabilities to your PyTorch projects.
If you are using PyTorch Lightning, check the DVCLive - PyTorch Lightning page.
Usage
You need to create a Live instance and include calls to log data and update the step number.
This snippet is used inside the Colab Notebook linked above:
from dvclive import Live
...
with Live(report="notebook") as live:
    live.log_params(params)
    for _ in range(params["epochs"]):
        train_one_epoch(
            model, criterion, x_train, y_train, params["lr"], params["weight_decay"]
        )
        # Train Evaluation
        metrics_train, actual_train, predicted_train = evaluate(
            model, x_train, y_train)
        for k, v in metrics_train.items():
            live.log_metric(f"train/{k}", v)
        live.log_sklearn_plot(
            "confusion_matrix",
            actual_train, predicted_train,
            name="train/confusion_matrix"
        )
        # Test Evaluation
        metrics_test, actual, predicted = evaluate(
            model, x_test, y_test)
        for k, v in metrics_test.items():
            live.log_metric(f"test/{k}", v)
        live.log_sklearn_plot(
            "confusion_matrix", actual, predicted, name="test/confusion_matrix"
        )
        live.log_image(
            "misclassified.jpg",
            get_missclassified_image(actual, predicted, mnist_test)
        )
        # Save best model (best_test_acc is assumed to be initialized earlier)
        if metrics_test["acc"] > best_test_acc:
            best_test_acc = metrics_test["acc"]
            torch.save(model.state_dict(), "model.pt")
        live.next_step()
    live.log_artifact("model.pt", type="model", name="pytorch-model")
DistributedDataParallel
If you are using DistributedDataParallel (DDP) to parallelize training over multiple processes, call DVCLive only in the rank 0 process. The Lightning callback will do this automatically. You can also write your own code so that it only calls DVCLive in the rank 0 process:
from dvclive import Live
from torch.distributed import get_rank
...
# Call this after the process group has been initialized
rank = get_rank()
if rank == 0:
    # Train model and log with dvclive
    with Live() as live:
        train(...)
        live.log_metric(...)
else:
    # Train model without dvclive
    train(...)
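The rank check above can also be folded into a small helper so that the training loop is written only once. The following is a minimal sketch, not part of DVCLive: the names `is_main_process` and `live_on_rank_zero` are hypothetical, and it assumes the launcher (for example `torchrun`) exports a `RANK` environment variable for each process. On non-zero ranks the helper returns a null context that yields `None`, so logging calls need a guard.

```python
import contextlib
import os


def is_main_process() -> bool:
    """Return True for rank 0, or when no launcher has set RANK."""
    # Assumption: the launcher (e.g. torchrun) exports RANK per process.
    return int(os.environ.get("RANK", "0")) == 0


def live_on_rank_zero(live_factory):
    """Return live_factory() on rank 0, else a context manager yielding None."""
    if is_main_process():
        return live_factory()
    return contextlib.nullcontext()


# Intended use (hypothetical train_one_epoch helper):
#
#     from dvclive import Live
#
#     with live_on_rank_zero(Live) as live:
#         for _ in range(epochs):
#             loss = train_one_epoch(...)
#             if live is not None:
#                 live.log_metric("train/loss", loss)
#                 live.next_step()
```

Passing a factory rather than a Live instance means that no Live object is ever constructed in the non-zero rank processes.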