from typing import Tuple, Union
import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.optim import SGD, Adam, AdamW, RMSprop
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import (
AUROC,
Accuracy,
ExplainedVariance,
MatthewsCorrCoef,
MeanAbsoluteError,
MeanSquaredError,
MetricCollection,
R2Score,
)
from ..data import TwoGraphData
[docs]class BaseModel(LightningModule):
"""Base model, defines a lot of helper functions."""
def __init__(self, **kwargs):
super().__init__()
self.save_hyperparameters()
self.batch_size = kwargs["datamodule"]["batch_size"]
return kwargs["model"]
def _set_class_metrics(self, num_classes: int = 2):
metrics = MetricCollection(
[
Accuracy(num_classes=None if num_classes == 2 else num_classes),
AUROC(num_classes=None if num_classes == 2 else num_classes),
MatthewsCorrCoef(num_classes=num_classes),
]
)
self.train_metrics = metrics.clone(prefix="train_")
self.val_metrics = metrics.clone(prefix="val_")
self.test_metrics = metrics.clone(prefix="test_")
def _set_reg_metrics(self):
metrics = MetricCollection([MeanAbsoluteError(), MeanSquaredError(), ExplainedVariance()])
self.train_metrics = metrics.clone(prefix="train_")
self.val_metrics = metrics.clone(prefix="val_")
self.test_metrics = metrics.clone(prefix="test_")
def _determine_feat_method(
self,
feat_method: str,
drug_hidden_dim: int = None,
prot_hidden_dim: int = None,
**kwargs,
):
"""Which method to use for concatenating drug and protein representations."""
if feat_method == "concat":
self.merge_features = self._concat
self.embed_dim = drug_hidden_dim + prot_hidden_dim
elif feat_method == "element_l2":
assert drug_hidden_dim == prot_hidden_dim
self.merge_features = self._element_l2
self.embed_dim = drug_hidden_dim
elif feat_method == "element_l1":
assert drug_hidden_dim == prot_hidden_dim
self.merge_features = self._element_l1
self.embed_dim = drug_hidden_dim
elif feat_method == "mult":
assert drug_hidden_dim == prot_hidden_dim
self.merge_features = self._mult
self.embed_dim = drug_hidden_dim
else:
raise ValueError("unsupported feature method")
def _concat(self, drug_embed: Tensor, prot_embed: Tensor) -> Tensor:
"""Concatenation."""
return torch.cat((drug_embed, prot_embed), dim=1)
def _element_l2(self, drug_embed: Tensor, prot_embed: Tensor) -> Tensor:
"""L2 distance."""
return torch.sqrt(((drug_embed - prot_embed) ** 2) + 1e-6).float()
def _element_l1(self, drug_embed: Tensor, prot_embed: Tensor) -> Tensor:
"""L1 distance."""
return (drug_embed - prot_embed).abs()
def _mult(self, drug_embed: Tensor, prot_embed: Tensor) -> Tensor:
"""Multiplication."""
return drug_embed * prot_embed
[docs] def training_step(self, data: TwoGraphData, data_idx: int) -> dict:
"""What to do during training step."""
ss = self.shared_step(data)
self.train_metrics.update(ss["preds"], ss["labels"])
self.log("train_loss", ss["loss"], batch_size=self.batch_size)
return ss
[docs] def validation_step(self, data: TwoGraphData, data_idx: int) -> dict:
"""What to do during validation step. Also logs the values for various callbacks."""
ss = self.shared_step(data)
self.val_metrics.update(ss["preds"], ss["labels"])
self.log("val_loss", ss["loss"], batch_size=self.batch_size)
return ss
[docs] def test_step(self, data: TwoGraphData, data_idx: int) -> dict:
"""What to do during test step. Also logs the values for various callbacks."""
ss = self.shared_step(data)
self.test_metrics.update(ss["preds"], ss["labels"])
self.log("test_loss", ss["loss"], batch_size=self.batch_size)
return ss
[docs] def log_histograms(self):
"""Logs the histograms of all the available parameters."""
if self.logger:
for name, param in self.named_parameters():
self.logger.experiment.add_histogram(name, param, self.current_epoch)
[docs] def log_all(self, metrics: dict, hparams: bool = False):
"""Log all metrics."""
if self.logger:
for k, v in metrics.items():
self.logger.experiment.add_scalar(k, v, self.current_epoch)
if hparams:
self.logger.log_hyperparams(self.hparams, {k.split("_")[-1]: v for k, v in metrics.items()})
[docs] def training_epoch_end(self, outputs: dict):
"""What to do at the end of a training epoch. Logs everything."""
self.log_histograms()
metrics = self.train_metrics.compute()
self.train_metrics.reset()
self.log_all(metrics)
[docs] def validation_epoch_end(self, outputs: dict):
"""What to do at the end of a validation epoch. Logs everything."""
metrics = self.val_metrics.compute()
self.val_metrics.reset()
self.log_all(metrics, hparams=True)
[docs] def test_epoch_end(self, outputs: dict):
"""What to do at the end of a test epoch. Logs everything."""
metrics = self.test_metrics.compute()
self.test_metrics.reset()
self.log_all(metrics)