Source code for rindti.utils.vis
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
[docs]def plot_loss_count_dist(losses: dict) -> Figure:
"""Plot distribution of times sampled vs avg loss of families."""
fig = plt.figure()
plt.xlabel("Times sampled")
plt.ylabel("Avg loss")
plt.title("Prot statistics")
count = [len(x) for x in losses.values()]
mean = [np.mean(x) for x in losses.values()]
plt.scatter(x=count, y=mean)
return fig