Source code for beexplainable.modelling.classifiers

"""Library for constructing CNN classifiers"""

from typing import Optional, Tuple

import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dropout, GlobalAveragePooling2D, Layer

class CNN_ResNet50(Layer):
    """Keras layer wrapping a `ResNet50` CNN backbone used as a feature extractor.

    The backbone output is optionally reduced with global average pooling and
    then passed through dropout (optionally Monte-Carlo dropout).
    """

    def __init__(self, in_shape: Tuple[int, int, int], drop_rate: float,
                 add_global_avg: bool = True, trainable: bool = False,
                 mcdrop: bool = False, weights_file: Optional[str] = None,
                 **kwargs):
        """Class for a CNN backbone `ResNet50`.

        :param in_shape: Input shape of the images in the format `(height, width, channels)`
            (passed as ``input_shape`` to ``ResNet50``). Not used when **weights_file** is given.
        :type in_shape: tuple of 3 integers.
        :param drop_rate: Dropout rate.
        :type drop_rate: float
        :param add_global_avg: Whether to add a `tf.keras.layers.GlobalAveragePooling2D` layer after the feature extractor. \
            Some pretrained CNNs come along with such a layer already. Defaults to *True*.
        :type add_global_avg: bool
        :param trainable: Whether the backbone weights should be retrained. Defaults to *False*.
        :type trainable: bool
        :param mcdrop: Whether to apply Monte-Carlo Dropout, i.e. keep dropout active at inference time. \
            Defaults to *False*.
        :type mcdrop: bool
        :param weights_file: Path to a saved Keras model, loaded with `tf.keras.models.load_model` and used \
            as the backbone. Defaults to *None*, in which case a `ResNet50` with ImageNet weights is built.
        :type weights_file: str, optional
        :param kwargs: Other arguments such as model name.
        """
        super().__init__(**kwargs)
        self.in_shape = in_shape
        self.drop_rate = drop_rate
        self.add_global_avg = add_global_avg
        self.mcdrop = mcdrop
        self.weights_file = weights_file
        if weights_file is None:
            self.resnet = ResNet50(include_top=False, weights='imagenet', input_shape=in_shape)
        else:
            self.resnet = tf.keras.models.load_model(weights_file)
        self.global_avg = GlobalAveragePooling2D() if add_global_avg else None
        # Build the dropout sub-layer once so it is tracked by this layer and is not
        # re-created on every forward pass / graph trace.
        self.dropout = Dropout(drop_rate, name='dropout')
        # Only the backbone is frozen/unfrozen; the wrapper's own `trainable` flag is untouched.
        self.resnet.trainable = trainable

    def call(self, inputs, training: Optional[bool] = None):
        """Apply forward propagation to **inputs**.

        :param inputs: Tensor batch of shape *(batch_size, height, width, 3)*.
        :type inputs: Tensor
        :param training: Whether to run in training mode (relevant for Batch Normalization). Defaults to *None*.
        :type training: bool, optional
        :return: Features extracted from **inputs**.
        """
        x = self.resnet(inputs, training=training)
        if self.global_avg is not None:
            x = self.global_avg(x)
        # Monte-Carlo dropout forces dropout on regardless of the training flag.
        return self.dropout(x, training=True if self.mcdrop else training)

    def get_config(self):
        """Get configuration (needed for serialization when extra arguments are provided in ``__init__()``).

        The Monte-Carlo dropout flag is stored under the key ``"monte_carlo_dropout"``
        and ``"trainable"`` records whether the backbone is trainable; :meth:`from_config`
        maps these back onto the ``__init__`` arguments.

        :return: Configuration dictionary.
        :rtype: dict
        """
        config = super().get_config()
        config.update({
            "in_shape": self.in_shape,
            "drop_rate": self.drop_rate,
            "add_global_avg": self.add_global_avg,
            # Overrides the inherited wrapper flag so the backbone freeze survives a round trip.
            "trainable": self.resnet.trainable,
            "monte_carlo_dropout": self.mcdrop,
            "weights_file": self.weights_file,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Re-create the layer from a dictionary produced by :meth:`get_config`.

        :param config: Configuration dictionary.
        :type config: dict
        :return: A new layer instance.
        :rtype: CNN_ResNet50
        """
        config = dict(config)
        # ``get_config`` stores the flag as "monte_carlo_dropout", but ``__init__`` expects ``mcdrop``.
        if "monte_carlo_dropout" in config:
            legacy_mcdrop = config.pop("monte_carlo_dropout")
            config.setdefault("mcdrop", legacy_mcdrop)
        # JSON round trips turn tuples into lists.
        if config.get("in_shape") is not None:
            config["in_shape"] = tuple(config["in_shape"])
        return cls(**config)