import collections
import itertools
from typing import Any, Callable, Dict, Union
import git
import yaml
def remove_arg_prefix(prefix: str, kwargs: dict) -> dict:
    """Select the args carrying ``prefix`` and strip it from their names.

    Keys that do not start with ``prefix`` are left out of the result. After
    stripping, a key equal to ``x_batch`` is renamed to ``batch``.

    Args:
        prefix (str): prefix to remove (`drug_`, `prot_` or `mlp_` usually)
        kwargs (dict): dict of arguments

    Returns:
        dict: Sub-dict of arguments, keyed by the un-prefixed names
    """
    cut = len(prefix)
    # Special-case rename applied after the prefix has been removed.
    renames = {"x_batch": "batch"}
    stripped = {}
    for name, val in kwargs.items():
        if not name.startswith(prefix):
            continue
        short = name[cut:]
        stripped[renames.get(short, short)] = val
    return stripped
def add_arg_prefix(prefix: str, kwargs: dict) -> dict:
    """Prepend ``prefix`` to every argument name.

    Entries whose value is ``None`` and the ``index_mapping`` entry are dropped.

    Args:
        prefix (str): prefix to add (`drug_`, `prot_` or `mlp_` usually)
        kwargs (dict): dict of arguments

    Returns:
        dict: Sub-dict of arguments with prefixed keys
    """
    prefixed = {}
    for name, val in kwargs.items():
        # Skip the mapping entry and unset (None) values.
        if name == "index_mapping" or val is None:
            continue
        prefixed[prefix + name] = val
    return prefixed
def read_config(filename: str) -> dict:
    """Read in a YAML config for training.

    Args:
        filename (str): path of the YAML file to read.

    Returns:
        dict: the parsed YAML document.
    """
    # Decode explicitly as UTF-8: without `encoding`, open() falls back to the
    # locale's preferred encoding, which mis-decodes non-ASCII UTF-8 configs on
    # platforms such as Windows (cp1252).
    # NOTE(review): FullLoader is intended for trusted input; prefer
    # yaml.safe_load if configs can ever come from an untrusted source.
    with open(filename, "r", encoding="utf-8") as file:
        return yaml.load(file, Loader=yaml.FullLoader)
def write_config(filename: str, config: dict) -> None:
    """Serialise ``config`` as YAML into ``filename``, overwriting any existing file.

    Args:
        filename (str): destination path.
        config (dict): the configuration to dump.
    """
    with open(filename, "w") as handle:
        yaml.dump(config, stream=handle)
def _tree():
    """Create an auto-vivifying mapping: each missing key yields another such mapping."""
    node = collections.defaultdict(_tree)
    return node
class IterDict:
    """Expand a nested config into every combination of its list-valued hyperparameters.

    Every leaf whose value is a ``list`` is treated as a set of alternatives. The
    result holds one config per element of the Cartesian product of those lists.
    Every non-list leaf is copied into each config unchanged.

    Example:
        ``IterDict()({"lr": [0.1, 0.01], "model": {"dim": 8}})`` returns two
        configs: one with ``lr=0.1`` and one with ``lr=0.01``. Both have
        ``model -> dim == 8``.

    Returned configs are nested ``_tree`` defaultdicts (auto-vivifying mappings).
    """

    def __init__(self):
        # Keys from the root down to the node _flatten is currently visiting.
        self.current_path = []
        # Maps each leaf's key path (a tuple of keys) to its value.
        self.flat = {}

    def _flatten(self, d: dict):
        """Record every leaf of ``d`` in ``self.flat`` under its full key path."""
        for k, v in d.items():
            self.current_path.append(k)
            if isinstance(v, dict) and v:
                self._flatten(v)
            else:
                # Tuple paths keep non-string keys and keys containing commas
                # intact. Empty dicts are stored as leaves so their key is not
                # dropped from the output.
                self.flat[tuple(self.current_path)] = v
            self.current_path.pop()

    def _get_variants(self):
        """Return one flat config per combination of the list-valued leaves."""
        hparams = {k: v for k, v in self.flat.items() if isinstance(v, list)}
        if not hparams:
            return [self.flat]
        keys, values = zip(*hparams.items())
        configs = []
        for combo in itertools.product(*values):
            config = self.flat.copy()
            config.update(zip(keys, combo))
            configs.append(config)
        return configs

    def _unflatten(self, d: dict):
        """Rebuild a nested tree from a mapping of key-path tuples to leaf values."""
        root = _tree()
        for path, v in d.items():
            *parents, leaf = path
            curr = root
            for part in parents:
                curr = curr[part]
            # Give each variant its own node for an empty-dict leaf, so variants
            # never share that mutable object.
            curr[leaf] = _tree() if isinstance(v, dict) and not v else v
        return root

    def __call__(self, d: dict):
        """Expand ``d`` into a list of nested configs, one per combination."""
        # Reset state so a reused instance does not carry over keys from a
        # previous call.
        self.current_path = []
        self.flat = {}
        self._flatten(d)
        return [self._unflatten(v) for v in self._get_variants()]
def recursive_apply(ob: Union[Dict, Any], func: Callable) -> Union[Dict, Any]:
    """Map ``func`` over every non-dict leaf of a (possibly nested) dict.

    Dicts are rebuilt as plain ``dict`` objects with the same keys. Anything that
    is not a dict, including lists, is treated as a leaf and passed whole to
    ``func``.
    """
    if not isinstance(ob, dict):
        return func(ob)
    mapped = {}
    for key, child in ob.items():
        mapped[key] = recursive_apply(child, func)
    return mapped
def get_git_hash():
    """Get the git hash of the current repository.

    Opens the enclosing repository (``search_parent_directories=True`` lets
    GitPython look in parent directories too) and returns the hex SHA of the
    object HEAD points to.

    Returns:
        str: hex SHA of HEAD.
    """
    repo = git.Repo(search_parent_directories=True)
    try:
        return repo.head.object.hexsha
    finally:
        # GitPython can keep helper git processes and file handles alive per
        # Repo. Release them deterministically instead of relying on the
        # garbage collector.
        repo.close()