Source code for tf_al.wrapper.model

import os, math
import uuid
import numpy as np
from enum import Enum
import tensorflow as tf

from . import Checkpoint
from ..utils import setup_logger, ProblemUtils
from ..stats import get


[docs]class Model: """ Base wrapper for deep learning models to interface with the active learning environment. Attributes: _model (tf.Model): Tensorflow or pytorch module. _config (Config): Model configuration _mode (Mode): The mode the model is in 'train' or 'test'/'eval'. _model_type (str): The model type _checkpoints (Checkpoint): Created checkpoints. Parameters: model (tf.Model): The tensorflow model to be used. config (Config): Configuration object for the model. (default=None) is_binary (bool): classification (bool): """ def __init__( self, model, config=None, name=None, model_type=None, checkpoint=None, verbose=False, checkpoint_path=None, **kwargs ): self.__verbose = verbose self.logger = setup_logger(verbose, "Model Logger") self.__id = uuid.uuid1() self._model = model self._config = config self._model_type = model_type self._name = name self._compile_params = None self._problem = ProblemUtils( kwargs.get("classification", True), kwargs.get("is_binary", False) ) self.eval_metrics = [] # Checkpoints path set? if checkpoint_path is None: checkpoint_path = os.getcwd() self._checkpoints = Checkpoint(checkpoint_path) if checkpoint is None else checkpoint def __call__(self, inputs, **kwargs): return self._model(inputs, **kwargs)
[docs] def predict(self, inputs, **kwargs): """ Approximate predictive distribution. Parameter: inputs (numpy.ndarray): The inputs for the approximation """ return self._model.predict(inputs, **kwargs)
[docs] def evaluate(self, inputs, targets, **kwargs): """ Evaluate a model on given input data and targets. Parameters: inputs (numpy.ndarray): targets (numpy.ndarray): Returns: (list) A list with two values. [loss, accuracy] """ pred_targets = self.predict(inputs, **kwargs) output_metrics = {} for metric in self.eval_metrics: if isinstance(metric, str): metric = tf.keras.metrics.get(metric) metric_name = None if hasattr(metric, "__name__"): metric_name = metric.__name__ else: metric_name = metric.name output_metrics[metric_name] = metric(targets, pred_targets) return output_metrics
[docs] def fit(self, *args, **kwargs): """ Fit the model to the given data. Args: x (numpy.ndarray): The inputs to train the model on. (default=None) y (numpy.ndarray): The targets to fit the model to. (default=None) batch_size (int): The size of each individual batch Returns: () a record of the trianing procedure """ if self._config is not None and "fit" in self._config and isinstance(self._config["fit"], dict): fit_params = self._config["fit"] kwargs.update(fit_params) return self._model.fit(*args, **kwargs)
[docs] def compile(self, *args, **kwargs): """ Compile the model if needed """ self._compile_params = kwargs self._model.compile(**kwargs) # Create evaluation metrics from compile parameters metrics = self._create_init_metrics(kwargs) if self.eval_metrics == []: self.eval_metrics = metrics
[docs] def disable_batch_norm(self): """ Disable batch normalization for activation of dropout during prediction. Parameters: - model (tf.Model) Tensorflow neural network model. """ disabled = False for l in self._model.layers: if l.__class__.__name__ == "BatchNormalization": disabled = True l.trainable = False if disabled: self.logger.info("Disabled BatchNorm-Layers.")
def clear_session(self): tf.keras.backend.clear_session() # ------ # Model runtime configurations # -----------------------------
[docs] def reset(self, pool, dataset): """ Use to reset states, weights and other stuff after each active learning loop iteration. Parameters: pool (Pool): The pool managing labeled and unlabeled indices. dataset (Dataset): The dataset containting the different splits. """ self._model = tf.keras.models.clone_model(self._model) self._model.compile(**self._compile_params)
[docs] def optimize(self, inputs, targets): """ Use to perform optimization during active learning loop. """ pass
# ---------------------- # Utilities # ------------------------------
[docs] def batch_prediction(self, inputs, batch_size=1, **kwargs): """ Parameters: inputs (numpy.ndarray): Inputs going into the model n_times (int): How many times to sample from posterior? batch_size (int): In how many batches to split the data? """ if batch_size < 1: raise ValueError("Error in McDropout.__call__(). Can't select negative amount of batches.") total_len = len(inputs) num_batches = math.ceil(total_len/batch_size) batches = np.array_split(inputs, num_batches, axis=0) predictions = [] for batch in batches: predictions.append(self._model(batch, training=True)) return np.vstack(predictions)
def _create_init_metrics(self, kwargs): """ Extract loss and other metrics passed to Model.compile() and transfer them to metrics which should be used for evaluation. """ loss = kwargs.get("loss", None) metrics = kwargs.get("metrics", []) if not isinstance(metrics, list): metrics = [metrics] if loss is not None: metrics = [loss] + metrics return metrics def _extract_metric_names(self, metrics): """ Get the metric names from a list of metric classes, passing already existing metric names on. Parameters: metrics (list()): A list can contain strings and metric objects. Returns: (list(str)) a list of metric names """ all_metrics = [] for metric in metrics: metric_name = metric if not isinstance(metric, str): metric_name = metric.name all_metrics.append(metric_name) return all_metrics def _init_metrics(self, prefix, metrics): """ Initializes metrics passed to the object during compilation. Using model specific metrics instead. Parameters: prefix (str): The sub-package of tf_al.stats to be used (sampling, stochastic, ...) metrics (list(str)): A list of metric names. Returns: (list) of initialized metrics """ initialized_metrics = [] for metric in metrics: initialized_metrics.append(get(prefix, metric)) return initialized_metrics # -------------- # Checkpoint creation/loading # ------------------------------ def empty_checkpoint(self): return self._checkpoints.empty() def new_checkpoint(self): self._checkpoints.new(self._model) def load_checkpoint(self, iteration=None): self._checkpoints.load(self._model, iteration) def clear_checkpoints(self): self._checkpoints.clean() def save_weights(self): path = self._checkpoints.PATH self._model.save_weights(path) def load_weights(self): path = self._checkpoints.PATH self._model.load_weights(path) def has_save_state(self): try: self.load_weights() return True except: return False # ----------- # Access Configuration # ---------------------------- def get_config_for(self, name): if (self._config is not None) and name in self._config: return self._config[name] return {} def get_config(self): return self._config.kwargs # -------------- # Access important flags for predictions # ---------------------------- def is_classification(self): return self._problem.classification def is_binary(self): return self._problem.is_binary # --------------- # Acquisition functions # --------------------------
[docs] def get_query_fn(self, name): """ Get model specific acquisition function. Parameters: name (str): The name of the acquisition function to return. Returns: (function) the acquisition function to use. """ pass
def __max_entropy(self, data, **kwargs): pass def __bald(self, data, **kwargs): pass def __max_var_ratio(self, data, **kwargs): pass def __std_mean(self, data, **kwargs): pass # ------------- # Metric hooks # ----------------- # def _on_evaluate_loss(self, **kwargs): # pass # def _on_evaluate_acc(self, **kwargs): # pass # ----------------- # Setter/-Getter # -------------------------- def get_id(self): return str(self.__id)
[docs] def get_model_name(self, prefix=True): """ Returns the model name. Parameters: prefix (bool): Prefix the model name with model type? Returns: (str) the model name. """ # Model name and model type? model_name_exists = self._name is not None if self._model_type is not None: if model_name_exists and prefix: return self._model_type.lower() + "_" + self._name return self._model_type.lower() # Only model name given if model_name_exists: return self._name # Default model name return "model"
def get_model(self): return self._model def get_metric_names(self): return self._model.metrics_names def get_base_model(self): return self._model # --------------- # Dunder # ---------------------- def __str__(self): return self.get_model_name() def __getstate__(self): """ Called when pickeling the object. Remove logger, because no use to pickle logger. """ d = dict(self.__dict__) del d["logger"] return d def __setstate__(self, d): self.__dict__.update(d) verbose = d["__verbose"] self.logger = setup_logger(verbose, "Model Logger")