
Live.log_sklearn_plot()

Generates a scikit-learn plot and saves the data in {Live.dir}/plots/sklearn/{name}.json.

def log_sklearn_plot(
  self,
  kind: Literal['calibration', 'confusion_matrix', 'det', 'precision_recall', 'roc'],
  labels,
  predictions,
  name: Optional[str] = None,
  **kwargs):

Usage

from dvclive import Live

with Live() as live:
  y_true = [0, 0, 1, 1]
  y_pred = [1, 0, 1, 0]
  y_score = [0.1, 0.4, 0.35, 0.8]
  live.log_sklearn_plot("roc", y_true, y_score)
  live.log_sklearn_plot(
    "confusion_matrix", y_true, y_pred, name="cm.json")

Description

The method computes the kind plot (see Supported plots) and writes it to {Live.dir}/plots/sklearn/{name} in a format compatible with dvc plots.

It also stores the plot's properties so that they are included in the plots section written by Live.make_dvcyaml(). The usage example above would produce the following dvc.yaml in {Live.dir}/{Live.dvc_file}:

plots:
  - plots/sklearn/roc.json:
      template: simple
      x: fpr
      y: tpr
      title: Receiver operating characteristic (ROC)
      x_label: False Positive Rate
      y_label: True Positive Rate
  - plots/sklearn/cm.json:
      template: confusion
      x: actual
      y: predicted
      title: Confusion Matrix
      x_label: True Label
      y_label: Predicted Label
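The paths in this dvc.yaml follow a simple naming rule: the file name is name when given, otherwise kind, placed under {Live.dir}/plots/sklearn/. The helper below is a hypothetical sketch of that rule, not part of DVCLive. Appending .json when the name lacks it is an assumption inferred from the roc.json and cm.json entries above.

```python
from pathlib import PurePosixPath
from typing import Optional


def sklearn_plot_path(live_dir: str, kind: str, name: Optional[str] = None) -> str:
    """Hypothetical helper: where a plot of `kind` would be written.

    Assumption: the file name defaults to `kind`, and `.json` is
    appended when the chosen name does not already end with it.
    """
    file_name = name or kind
    if not file_name.endswith(".json"):
        file_name += ".json"
    return str(PurePosixPath(live_dir) / "plots" / "sklearn" / file_name)


print(sklearn_plot_path("dvclive", "roc"))
print(sklearn_plot_path("dvclive", "confusion_matrix", name="cm.json"))
```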

Supported plots

kind must be one of the supported plots:

calibration - generates a calibration curve plot.

y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
live.log_sklearn_plot("calibration", y_true, y_score)

[Figure: DVCLive calibration plot]
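A calibration curve groups predicted probabilities into bins and compares, per bin, the mean predicted probability with the observed fraction of positives. The pure-Python sketch below assumes 0/1 labels and uniform bins (n_bins=5, which is also the default of scikit-learn's calibration_curve). It illustrates the idea only and is not DVCLive's or scikit-learn's code.

```python
from bisect import bisect_left


def calibration_curve(y_true, y_prob, n_bins=5):
    """Uniform-bin calibration curve sketch (assumes 0/1 labels).

    Returns (prob_true, prob_pred), one value per non-empty bin.
    """
    # Interior bin edges, e.g. [0.2, 0.4, 0.6, 0.8] for n_bins=5.
    edges = [i / n_bins for i in range(1, n_bins)]
    sums = [0.0] * n_bins     # sum of predicted probabilities per bin
    positives = [0] * n_bins  # number of positive labels per bin
    totals = [0] * n_bins     # number of samples per bin
    for label, prob in zip(y_true, y_prob):
        b = bisect_left(edges, prob)  # bin index for this probability
        sums[b] += prob
        positives[b] += label
        totals[b] += 1
    prob_true = [positives[b] / totals[b] for b in range(n_bins) if totals[b]]
    prob_pred = [sums[b] / totals[b] for b in range(n_bins) if totals[b]]
    return prob_true, prob_pred


print(calibration_curve([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
```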

confusion_matrix - generates a confusion matrix plot.

y_true = [1, 1, 2, 2]
y_pred = [2, 1, 1, 2]
live.log_sklearn_plot("confusion_matrix", y_true, y_pred)

[Figure: DVCLive confusion matrix plot]
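A confusion matrix counts how often each actual label was predicted as each label. The sketch below is a plain-Python illustration (not DVCLive's implementation) that uses the data from the example above; rows are actual labels and columns are predicted labels.

```python
from collections import Counter


def confusion_counts(y_true, y_pred):
    """Return (labels, matrix), where matrix[i][j] counts samples whose
    actual label is labels[i] and whose predicted label is labels[j]."""
    labels = sorted(set(y_true) | set(y_pred))
    pairs = Counter(zip(y_true, y_pred))  # (actual, predicted) -> count
    return labels, [[pairs[(a, p)] for p in labels] for a in labels]


labels, matrix = confusion_counts([1, 1, 2, 2], [2, 1, 1, 2])
print(labels)  # [1, 2]
print(matrix)  # [[1, 1], [1, 1]]
```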

det - generates a detection error tradeoff (DET) plot.

y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
live.log_sklearn_plot("det", y_true, y_score)

[Figure: DVCLive DET plot]
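A DET curve shows the false negative rate against the false positive rate as the decision threshold varies. The sketch below assumes 0/1 labels and counts a sample as positive when its score is at least the threshold; it is an illustration only. scikit-learn's det_curve may trim some thresholds, so its output can contain fewer points than this sketch.

```python
def det_points(y_true, y_score):
    """For each distinct score threshold t (positive when score >= t),
    return (t, false positive rate, false negative rate)."""
    positives = sum(1 for y in y_true if y == 1)
    negatives = len(y_true) - positives
    points = []
    for t in sorted(set(y_score)):
        fp = sum(1 for y, s in zip(y_true, y_score) if s >= t and y == 0)
        fn = sum(1 for y, s in zip(y_true, y_score) if s < t and y == 1)
        points.append((t, fp / negatives, fn / positives))
    return points


for point in det_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]):
    print(point)
```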

precision_recall - generates a precision-recall curve plot.

y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
live.log_sklearn_plot("precision_recall", y_true, y_score)

[Figure: DVCLive precision-recall plot]
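Each point of a precision-recall curve is the precision and recall obtained at one score threshold. The sketch below assumes 0/1 labels and treats a sample as positive when its score is at least the threshold; it is illustrative, not DVCLive's code. scikit-learn's precision_recall_curve additionally appends an end point (precision 1, recall 0) that has no threshold.

```python
def precision_recall_points(y_true, y_score):
    """For each distinct threshold t, return (t, precision, recall)."""
    positives = sum(y_true)  # assumes labels are 0/1
    points = []
    for t in sorted(set(y_score)):
        # Labels of the samples predicted positive at this threshold;
        # never empty, because t is itself one of the scores.
        predicted = [y for y, s in zip(y_true, y_score) if s >= t]
        tp = sum(predicted)
        points.append((t, tp / len(predicted), tp / positives))
    return points


for point in precision_recall_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]):
    print(point)
```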

roc - generates a receiver operating characteristic (ROC) curve plot.

y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
live.log_sklearn_plot("roc", y_true, y_score)

[Figure: DVCLive ROC plot]
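The fpr and tpr fields in the dvc.yaml above are the ROC coordinates: the false and true positive rates as the threshold moves from above every score down to the lowest score. The pure-Python sketch below assumes 0/1 labels and is an illustration, not DVCLive's code. With drop_intermediate=True (mentioned under Parameters), scikit-learn's roc_curve may drop points that do not change the shape of the curve.

```python
def roc_points(y_true, y_score):
    """Return (threshold, fpr, tpr) tuples, highest threshold first."""
    positives = sum(y_true)  # assumes labels are 0/1
    negatives = len(y_true) - positives
    # A threshold above every score predicts everything negative.
    points = [(float("inf"), 0.0, 0.0)]
    for t in sorted(set(y_score), reverse=True):
        tp = sum(1 for y, s in zip(y_true, y_score) if s >= t and y == 1)
        fp = sum(1 for y, s in zip(y_true, y_score) if s >= t and y == 0)
        points.append((t, fp / negatives, tp / positives))
    return points


for point in roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]):
    print(point)
```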

Parameters

  • kind - a supported plot type.

  • labels - array of ground truth labels.

  • predictions - array of predicted labels (for confusion_matrix) or predicted probabilities (for other plots).

  • name - optional name of the output file. If not provided, kind is used as the name (for example, roc.json for the roc plot).

  • **kwargs - additional arguments to tune the result. They are passed to the underlying scikit-learn function (e.g. drop_intermediate=True for the roc type). In addition, some plot types support extra arguments:

    • normalized - default: False. If True, the confusion_matrix values are normalized to the [0, 1] range.
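The normalized option is documented only as scaling the confusion_matrix values to the [0, 1] range; which axis the normalization runs along is not stated here. One common convention, assumed in the sketch below, divides each row (one actual label) by its total so that each row sums to 1. The function name is hypothetical and not part of DVCLive.

```python
def normalize_rows(matrix):
    """Divide each row by its total; rows with no samples become all zeros."""
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([v / total if total else 0.0 for v in row])
    return normalized


print(normalize_rows([[3, 1], [0, 2]]))  # [[0.75, 0.25], [0.0, 1.0]]
```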

Exceptions

  • dvclive.error.InvalidPlotTypeError - raised if the provided kind does not correspond to any of the supported plots.
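The sketch below mimics this validation so it runs without DVCLive installed. InvalidPlotTypeError here is a local stand-in for dvclive.error.InvalidPlotTypeError, and check_kind is a hypothetical helper; DVCLive's actual check and error message may differ.

```python
SUPPORTED_KINDS = ("calibration", "confusion_matrix", "det", "precision_recall", "roc")


class InvalidPlotTypeError(Exception):
    """Local stand-in for dvclive.error.InvalidPlotTypeError."""


def check_kind(kind: str) -> str:
    """Return kind unchanged, or raise if it is not a supported plot type."""
    if kind not in SUPPORTED_KINDS:
        raise InvalidPlotTypeError(f"unsupported plot kind: {kind!r}")
    return kind


try:
    check_kind("histogram")
except InvalidPlotTypeError as err:
    print(err)
```

With DVCLive installed, the same try/except pattern would catch the real class imported from dvclive.error.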
