Module kerod.layers.smca.reference_points
None
None
View Source
import tensorflow as tf
from kerod.utils.documentation import remove_unwanted_doc
# Module-level documentation-override mapping. It is handed to
# `remove_unwanted_doc` together with the class at the end of this module;
# presumably that helper fills it with pdoc entries to hide — TODO confirm.
__pdoc__ = {}
class SMCAReferencePoints(tf.keras.layers.Layer):
    """Multi head reference points from the paper [Fast Convergence of DETR with Spatially Modulated Co-Attention](https://arxiv.org/pdf/2101.07448.pdf).

    Based on the object queries, this layer creates a set of reference points
    used to build [spatial dynamical weight maps](./weight_map.py), which
    modulate the co-attention inside the transformer.

    Arguments:
        hidden_dim: Positive integer, dimensionality of the hidden space.
        num_heads: Positive integer, each head starts from a head-shared center
            and then predicts a head-specific center offset
            and head-specific scales.

    Call arguments:
        object_queries: A 3-D float32 Tensor of shape
            [batch_size, num_object_queries, d_model]: a small, fixed number of
            learned positional embeddings fed as input to the decoder.

    Call returns:
        Tuple:
            - `reference_points`: A float tensor of shape
              [batch_size, num_object_queries, num_heads, (y, x, w, h)].
            - `embedding_reference_points`: A tensor of shape
              [batch_size, num_object_queries, 2].
              The embedding of y and x without the sigmoid applied
              (shared by all heads).
    """

    def __init__(self, hidden_dim: int, num_heads: int, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        # Head-shared center: an MLP mapping each object query to one
        # (y_cent, x_cent) pair. The attribute keeps its historical name
        # `xy_embed` although it predicts (y, x); renaming a tracked attribute
        # would likely break loading of existing checkpoints.
        self.xy_embed = tf.keras.models.Sequential([
            tf.keras.layers.Dense(hidden_dim, activation='relu'),
            tf.keras.layers.Dense(hidden_dim, activation='relu'),
            tf.keras.layers.Dense(2)  # (y_cent, x_cent)
        ])
        # Each head predicts its own center offset and its own scales,
        # hence 4 outputs per head.
        self.yx_offset_hw_embed = tf.keras.layers.Dense(4 * num_heads)

    def call(self, object_queries):
        """Compute the per-head reference points from the object queries.

        Args:
            object_queries: A 3-D float32 Tensor of shape
                [batch_size, num_object_queries, d_model]: a small, fixed
                number of learned positional embeddings fed as input to the
                decoder.

        Call returns:
            Tuple:
                - `reference_points`: A float tensor of shape
                  [batch_size, num_object_queries, num_heads, (y, x, w, h)].
                - `embedding_reference_points`: A tensor of shape
                  [batch_size, num_object_queries, 2].
                  The embedding of y and x without the sigmoid applied.
        """
        # Head-shared center before and after the sigmoid: [bs, num_queries, 2]
        yx_pre_sigmoid = self.xy_embed(object_queries)
        yx = tf.nn.sigmoid(yx_pre_sigmoid)

        # Per-head predictions: the first two channels are offsets for 'yx'
        # (above), the last two are the scale predictions.
        # [bs, num_queries, num_heads * 4]
        yx_offset_hw = self.yx_offset_hw_embed(object_queries)

        batch_size = tf.shape(object_queries)[0]

        # Repeat the shared center for every head: [bs, num_queries, num_heads, 2]
        yx = tf.tile(yx[:, :, None], (1, 1, self.num_heads, 1))
        # [bs, num_queries, num_heads, 4]
        yx_offset_hw = tf.reshape(yx_offset_hw, (batch_size, -1, self.num_heads, 4))

        # Zero-pad the scale channels so that a single addition offsets (y, x)
        # and passes the scale predictions through unchanged. `zeros_like`
        # matches both the shape and the dtype of `yx`; a plain `tf.zeros`
        # would be float32 and fail to concat under a float16 compute dtype.
        # NOTE(review): the names here suggest (h, w) order for the last two
        # channels while the docstring says (w, h); confirm with the
        # weight-map consumer.
        yxhw = tf.concat([yx, tf.zeros_like(yx)], axis=-1)
        return yxhw + yx_offset_hw, yx_pre_sigmoid

    def get_config(self):
        config = super().get_config()
        config['num_heads'] = self.num_heads
        config['hidden_dim'] = self.hidden_dim
        return config
# Hand the layer class and this module's `__pdoc__` mapping to the project
# documentation helper. NOTE(review): judging by its name, it hides unwanted
# (e.g. inherited Keras) members from the generated docs — confirm in
# kerod.utils.documentation.
remove_unwanted_doc(SMCAReferencePoints, __pdoc__)
Classes
SMCAReferencePoints
class SMCAReferencePoints(
hidden_dim: int,
num_heads: int,
**kwargs
)
Based on the object queries, this layer creates a set of reference points used to build spatial dynamical weight maps, which modulate the co-attention inside the transformer.
Arguments
Name | Description |
---|---|
hidden_dim | Positive integer, dimensionality of the hidden space. |
num_heads | Positive integer, each head starts from a head-shared center and then predicts a head-specific center offset and head-specific scales. |
Call arguments
Name | Description |
---|---|
object_queries | A 3-D float32 Tensor of shape [batch_size, num_object_queries, d_model]: a small, fixed number of learned positional embeddings fed as input to the decoder. |
Call returns
Type | Description |
---|---|
Tuple | - reference_points: A float tensor of shape [batch_size, num_object_queries, num_heads, (y, x, w, h)]. - embedding_reference_points: A tensor of shape [batch_size, num_object_queries, 2]. The embedding of y and x without the sigmoid applied (shared by all heads). |
Ancestors (in MRO)
- tensorflow.python.keras.engine.base_layer.Layer
- tensorflow.python.module.module.Module
- tensorflow.python.training.tracking.tracking.AutoTrackable
- tensorflow.python.training.tracking.base.Trackable
- tensorflow.python.keras.utils.version_utils.LayerVersionSelector
Methods
call
def call(
self,
object_queries
)
Parameters:
Name | Description |
---|---|
object_queries | A 3-D float32 Tensor of shape [batch_size, num_object_queries, d_model]: a small, fixed number of learned positional embeddings fed as input to the decoder. |
View Source
def call(self, object_queries):
    """Compute the per-head reference points from the object queries.

    Args:
        object_queries: A 3-D float32 Tensor of shape
            [batch_size, num_object_queries, d_model]: a small, fixed number of
            learned positional embeddings fed as input to the decoder.

    Call returns:
        Tuple:
            - `reference_points`: A float tensor of shape
              [batch_size, num_object_queries, num_heads, (y, x, w, h)].
            - `embedding_reference_points`: A tensor of shape
              [batch_size, num_object_queries, 2].
              The embedding of y and x without the sigmoid applied.
    """
    yx_pre_sigmoid = self.xy_embed(object_queries)  # [bs, num_queries, 2]
    yx = tf.nn.sigmoid(yx_pre_sigmoid)
    # Per-head predictions: the first two channels are offsets for 'yx'
    # (above), the last two are the scale predictions.
    yx_offset_hw = self.yx_offset_hw_embed(object_queries)  # [bs, num_queries, head * 4]
    batch_size = tf.shape(object_queries)[0]
    num_queries = tf.shape(object_queries)[1]
    # Add offset coordinates to yx and concatenate with scale predictions
    # yx => [bs, num_queries, head, 2] (the shared center repeated per head)
    yx = tf.tile(yx[:, :, None], (1, 1, self.num_heads, 1))
    # yx_offset_hw => [bs, num_queries, head, 4]
    yx_offset_hw = tf.reshape(yx_offset_hw, (batch_size, -1, self.num_heads, 4))
    # Zero-pad the scale channels so one addition only offsets (y, x).
    # NOTE(review): tf.zeros defaults to float32, which may not match the
    # dtype of `yx` under a mixed-precision policy.
    yxhw = tf.concat([yx, tf.zeros((batch_size, num_queries, self.num_heads, 2))], axis=-1)
    return yxhw + yx_offset_hw, yx_pre_sigmoid