Skip to content

Module kerod.utils.training

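Training utilities for Keras models: freezing the layers that precede a given layer, freezing `BatchNormalization` layers, and registering kernel-regularization losses.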

View Source
import tensorflow as tf

from typing import Callable

def freeze_layers_before(model: tf.keras.Model, layer_name: str):
    """Freeze every layer of a Keras `model` that precedes `layer_name`.

    The layer called `layer_name` itself is excluded (it keeps its current
    `trainable` flag), and so are all the layers that follow it in
    `model.layers`.

    Arguments:
        model: The Keras model whose layers are frozen in place.
        layer_name: Name of the first layer that must not be frozen.
    """
    boundary_layer = model.get_layer(layer_name)
    # Position of the boundary layer in the flat list of the model's layers.
    stop = model.layers.index(boundary_layer)

    for position, current in enumerate(model.layers):
        if position >= stop:
            break
        current.trainable = False

def freeze_batch_normalization(model: tf.keras.Model):
    """Mark every `BatchNormalization` layer of `model` as non-trainable.

    In object detection we usually do not train with big batches, so
    BatchNormalization is not useful and its layers should be frozen.

    Only the layers listed in `model.layers` are inspected; sub-layers that
    this list does not contain are not visited.

    Arguments:
        model: The Keras model modified in place.
    """
    batch_norms = [
        candidate
        for candidate in model.layers
        if isinstance(candidate, tf.keras.layers.BatchNormalization)
    ]
    for batch_norm in batch_norms:
        batch_norm.trainable = False

def apply_kernel_regularization(func: Callable, model: tf.keras.Model):
    """Apply kernel regularization on all the trainable layers of a Model.

    For every layer in `model.layers` that has a `kernel` attribute and is
    trainable, `func(layer.kernel)` is registered as an additional loss of
    `model`. A Model (anything exposing `.layers` and `.add_loss`) is
    required; a bare Layer has no `.layers` list to iterate.

    The loss is registered as a zero-argument callable rather than as a
    tensor. Outside of `call`, Keras `add_loss` only accepts symbolic tensors
    or callables, so an eagerly computed `func(layer.kernel)` is rejected.
    A callable is also re-evaluated whenever the model's losses are gathered,
    so the penalty tracks the current kernel values instead of a snapshot.

    Arguments:
        func: Callable applied to a layer's kernel; presumably returns a
            scalar loss tensor (e.g. a Keras regularizer instance) — confirm
            against callers.
        model: The Keras model whose layers are regularized. Losses are added
            to it in place.
    """
    for layer in model.layers:
        if hasattr(layer, 'kernel') and layer.trainable:
            # Bind `layer` through a default argument: a plain closure would be
            # late-bound and every loss would regularize the loop's last layer.
            model.add_loss(lambda layer=layer: func(layer.kernel))

Functions

apply_kernel_regularization

def apply_kernel_regularization(
    func: Callable,
    model: tensorflow.python.keras.engine.training.Model
)

Apply kernel regularization to every trainable layer of a Model that has a `kernel` attribute: `func(layer.kernel)` is added as a loss of the model.

View Source
def apply_kernel_regularization(func: Callable, model: tf.keras.Model):
    """Apply kernel regularization on all the trainable layers of a Model.

    For every layer in `model.layers` that has a `kernel` attribute and is
    trainable, `func(layer.kernel)` is registered as an additional loss of
    `model`. A Model (anything exposing `.layers` and `.add_loss`) is
    required; a bare Layer has no `.layers` list to iterate.

    The loss is registered as a zero-argument callable rather than as a
    tensor. Outside of `call`, Keras `add_loss` only accepts symbolic tensors
    or callables, so an eagerly computed `func(layer.kernel)` is rejected.
    A callable is also re-evaluated whenever the model's losses are gathered,
    so the penalty tracks the current kernel values instead of a snapshot.

    Arguments:
        func: Callable applied to a layer's kernel; presumably returns a
            scalar loss tensor (e.g. a Keras regularizer instance) — confirm
            against callers.
        model: The Keras model whose layers are regularized. Losses are added
            to it in place.
    """
    for layer in model.layers:
        if hasattr(layer, 'kernel') and layer.trainable:
            # Bind `layer` through a default argument: a plain closure would be
            # late-bound and every loss would regularize the loop's last layer.
            model.add_loss(lambda layer=layer: func(layer.kernel))

freeze_batch_normalization

def freeze_batch_normalization(
    model: tensorflow.python.keras.engine.training.Model
)

In object detection we usually do not train with big batches, so BatchNormalization is not useful and its layers should be frozen.

View Source
def freeze_batch_normalization(model: tf.keras.Model):
    """Mark every `BatchNormalization` layer of `model` as non-trainable.

    In object detection we usually do not train with big batches, so
    BatchNormalization is not useful and its layers should be frozen.

    Only the layers listed in `model.layers` are inspected; sub-layers that
    this list does not contain are not visited.

    Arguments:
        model: The Keras model modified in place.
    """
    batch_norms = [
        candidate
        for candidate in model.layers
        if isinstance(candidate, tf.keras.layers.BatchNormalization)
    ]
    for batch_norm in batch_norms:
        batch_norm.trainable = False

freeze_layers_before

def freeze_layers_before(
    model: tensorflow.python.keras.engine.training.Model,
    layer_name: str
)

Freezes layers of a Keras model before a given layer_name (excluded).

View Source
def freeze_layers_before(model: tf.keras.Model, layer_name: str):
    """Freeze every layer of a Keras `model` that precedes `layer_name`.

    The layer called `layer_name` itself is excluded (it keeps its current
    `trainable` flag), and so are all the layers that follow it in
    `model.layers`.

    Arguments:
        model: The Keras model whose layers are frozen in place.
        layer_name: Name of the first layer that must not be frozen.
    """
    boundary_layer = model.get_layer(layer_name)
    # Position of the boundary layer in the flat list of the model's layers.
    stop = model.layers.index(boundary_layer)

    for position, current in enumerate(model.layers):
        if position >= stop:
            break
        current.trainable = False