tl;dr
import pandas as pd
import mlflow
def get_metric_history(run_id, metric, client=None):
    """Return the full logged history of one metric of an MLflow run.

    Args:
        run_id: ID of the MLflow run to query.
        metric: Name (key) of the metric, e.g. ``"train_loss"``.
        client: Optional ``mlflow.tracking.MlflowClient`` to reuse across
            calls; a new client is created when omitted.

    Returns:
        pandas.DataFrame with columns ``key``, ``value``, ``timestamp`` and
        ``step``, sorted by ``step``. ``timestamp`` is converted from
        milliseconds to ``datetime64``. A metric with no logged values
        yields an empty frame with the same columns.
    """
    if client is None:
        client = mlflow.tracking.MlflowClient()
    columns = ["key", "value", "timestamp", "step"]
    rows = [
        (m.key, m.value, m.timestamp, m.step)
        for m in client.get_metric_history(run_id, metric)
    ]
    # Passing ``columns`` explicitly keeps the schema even when no rows were
    # returned; without it ``sort_values("step")`` raises KeyError.
    # A stable sort keeps the client's order among entries sharing a step.
    history = pd.DataFrame(rows, columns=columns).sort_values("step", kind="stable")
    history["timestamp"] = pd.to_datetime(history["timestamp"], unit="ms")
    return history
# Fetch the training and validation loss curves of the same run.
# NOTE(review): `run_id` is not defined in this snippet; assign the ID of the
# MLflow run to inspect before executing these lines.
train_loss, valid_loss = (
    get_metric_history(run_id, name) for name in ("train_loss", "valid_loss")
)
# Stack both long-format frames, then pivot so each metric key becomes its
# own column indexed by step, and plot one line per metric.
# NOTE(review): `pivot` raises ValueError if a (step, key) pair repeats,
# e.g. when a metric was logged more than once at the same step.
history = pd.concat([train_loss, valid_loss])
history.pivot(index="step", columns="key", values="value").plot()
How to

If you want to interact with the tracking server from Python, use `mlflow.tracking.MlflowClient`.
Its `MlflowClient.get_metric_history(run_id, key)` method returns the full history of one metric key in a run.
Here is a function that uses it to return the metric history as a `pandas.DataFrame`.
The timestamps are stored in milliseconds; converting them to datetime values makes them easier to work with.
def get_metric_history(run_id, metric, client=None):
    """Return the full logged history of one metric of an MLflow run.

    Args:
        run_id: ID of the MLflow run to query.
        metric: Name (key) of the metric, e.g. ``"train_loss"``.
        client: Optional ``mlflow.tracking.MlflowClient`` to reuse across
            calls; a new client is created when omitted.

    Returns:
        pandas.DataFrame with columns ``key``, ``value``, ``timestamp`` and
        ``step``, sorted by ``step``. ``timestamp`` is converted from
        milliseconds to ``datetime64``. A metric with no logged values
        yields an empty frame with the same columns.
    """
    if client is None:
        client = mlflow.tracking.MlflowClient()
    columns = ["key", "value", "timestamp", "step"]
    rows = [
        (m.key, m.value, m.timestamp, m.step)
        for m in client.get_metric_history(run_id, metric)
    ]
    # Passing ``columns`` explicitly keeps the schema even when no rows were
    # returned; without it ``sort_values("step")`` raises KeyError.
    # A stable sort keeps the client's order among entries sharing a step.
    history = pd.DataFrame(rows, columns=columns).sort_values("step", kind="stable")
    history["timestamp"] = pd.to_datetime(history["timestamp"], unit="ms")
    return history
# Example: load the training-loss history of one run.
# NOTE(review): `run_id` is not defined in this snippet; assign it first.
train_loss = get_metric_history(run_id=run_id, metric="train_loss")
As a side note, to compare and plot several metrics, concatenate their DataFrames vertically and then `pivot` the result so that each metric becomes its own column, as shown below.
# Fetch the training and validation loss curves of the same run.
# NOTE(review): `run_id` is not defined in this snippet; assign the ID of the
# MLflow run to inspect before executing these lines.
train_loss, valid_loss = (
    get_metric_history(run_id, name) for name in ("train_loss", "valid_loss")
)
# Stack both long-format frames, then pivot so each metric key becomes its
# own column indexed by step, and plot one line per metric.
# NOTE(review): `pivot` raises ValueError if a (step, key) pair repeats,
# e.g. when a metric was logged more than once at the same step.
history = pd.concat([train_loss, valid_loss])
history.pivot(index="step", columns="key", values="value").plot()
Recommended Posts