ResNet builder from scratch

When building custom object detection pipelines, there is often a need to build or modify existing backbone networks.

ResNet is a backbone network that is commonly used in such models.

However, once you modify the structure, it becomes difficult to reuse pre-trained weights.

In this article, I'll try to make that process easier by building a ResNet from scratch.

I follow the layer-naming convention of the pre-trained Keras networks to make the weight transfer easier.

Once the model has been modified for your use case, the weights of the remaining layers can be copied from a pre-trained model (wherever applicable).

Features:

  • Concise code
  • Simple Architecture
  • Customizable
  • Keras pre-trained weights can be copied into the custom build
  • Modify a ResNet as you wish (e.g. truncate it) and still be able to copy weights from a pre-trained ResNet into the remaining applicable layers!

import tensorflow as tf
import cv2
import numpy as np

from tensorflow.keras.layers import *
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import *

Building blocks

class Basic_Block(tf.keras.layers.Layer):
    """ResNet basic residual block (two 3x3 convolutions) -- NOT IMPLEMENTED YET.

        X
        |
        |--skip_block
        |     |
        |     conv(3x3)->bn->relu
        |     |
        |     conv(3x3)->bn
        |- + -|
        |
        v
    RELU( X(residual) + skip_block )

    Basic_Block: resnet-18,34 -> shortcut connections skips 2 layers (skip_block)
    Bottleneck_Block: resnet 50,101 etc ->  shortcut connections skips 3 layers (skip_block)

    Constructing this class raises ``NotImplementedError`` so that selecting
    it (e.g. as ``block_type`` of ``Build_ResNet``) fails with a clear message
    instead of an obscure argument error or a silent identity layer.
    NOTE(review): ``Build_ResNet`` hard-codes an expansion factor of 4, which
    a basic block (no channel expansion) would also need changed.
    """

    def __init__(self, *args, **kwargs):
        raise NotImplementedError(
            "Basic_Block (ResNet-18/34) is not implemented yet; "
            "use Bottleneck_Block instead."
        )
class Bottleneck_Block(tf.keras.layers.Layer):
    """ResNet v1 bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions).

        X
        |
        |--skip_block
        |     |
        |     conv(1x1)->bn->relu
        |     |
        |     conv(3x3)->bn->relu
        |     |
        |     conv(1x1)->bn
        |- + -|
        |
        v
    RELU( X(residual) + skip_block )

    Basic_Block: resnet-18,34 -> shortcut connections skips 2 layers (skip_block)
    Bottleneck_Block: resnet 50,101 etc ->  shortcut connections skips 3 layers (skip_block)

    Sub-layers are named ``conv{stage}_block{block}_{i}_conv`` /
    ``conv{stage}_block{block}_{i}_bn`` (the Keras applications ResNet naming)
    so pre-trained Keras weights can be matched by name.

    Args:
        num_filters: Bottleneck width; the block outputs ``num_filters * 4``
            channels.
        modify_conv: If True the shortcut is a strided 1x1 conv + BN
            ("0_conv"/"0_bn"). If False the input is added unchanged, so it
            must already have ``num_filters * 4`` channels and the output's
            spatial size (i.e. ``stride`` must be 1).
        stride: Spatial stride of the block.
        idx: ``(stage_idx, block_idx)`` used to build layer names.
        stride_in_1x1: If False (default, original behaviour) the stride is
            applied in the 3x3 conv. If True it is applied in the first 1x1
            conv instead, as ``tf.keras.applications`` ResNet does; use it for
            exact parity with Keras pre-trained weights, e.g. via
            ``functools.partial(Bottleneck_Block, stride_in_1x1=True)``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (``trainable``,
            ``dtype``, ...). ``name`` defaults to
            ``RESNET_BLOCK_{stage_idx}_{block_idx}``.
    """

    # Keras' pre-trained ResNets use this epsilon for every BatchNormalization;
    # the Keras default (1e-3) would change activations after a weight copy.
    BN_EPSILON = 1.001e-5

    def __init__(self, num_filters, modify_conv=True, stride=1, idx=(None, None),
                 stride_in_1x1=False, **kwargs):
        stage_idx, block_idx = [str(i) for i in idx]
        kwargs.setdefault("name", f'RESNET_BLOCK_{stage_idx}_{block_idx}')
        super().__init__(**kwargs)

        # Constructor arguments, kept for get_config()/from_config().
        self.num_filters = num_filters
        self.modify_conv = modify_conv
        self.stride = stride
        self.idx = tuple(idx)
        self.stride_in_1x1 = stride_in_1x1

        self.conv_name_base = f'conv{stage_idx}_block{block_idx}_'
        self.bn_name_base = f'conv{stage_idx}_block{block_idx}_'

        stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride)

        self.conv1 = Conv2D(filters=num_filters, kernel_size=1, strides=stride_1x1, padding="same",
                            name=self.conv_name_base + '1_conv')
        self.bn1 = BatchNormalization(epsilon=self.BN_EPSILON, name=self.bn_name_base + '1_bn')

        self.conv2 = Conv2D(filters=num_filters, kernel_size=3, strides=stride_3x3, padding="same",
                            name=self.conv_name_base + '2_conv')
        self.bn2 = BatchNormalization(epsilon=self.BN_EPSILON, name=self.bn_name_base + '2_bn')

        self.conv3 = Conv2D(filters=num_filters * 4, kernel_size=1, strides=1, padding="same",
                            name=self.conv_name_base + '3_conv')
        self.bn3 = BatchNormalization(epsilon=self.BN_EPSILON, name=self.bn_name_base + '3_bn')

        # ReLU is stateless, so one instance is reused for all activations.
        self.relu = ReLU()
        # Created once here instead of on every call().
        self.residual_add = Add()

        if self.modify_conv:
            # Projection shortcut: matches the output channels and stride.
            self.modify_conv_block = tf.keras.Sequential([
                Conv2D(filters=num_filters * 4, kernel_size=1, strides=stride,
                       name=self.conv_name_base + '0_conv'),
                BatchNormalization(epsilon=self.BN_EPSILON, name=self.bn_name_base + '0_bn'),
            ])

    def call(self, x, training=None, **kwargs):
        """Apply the block; ``training`` is forwarded to every BatchNormalization."""
        if self.modify_conv:
            residual = self.modify_conv_block(x, training=training)
        else:
            residual = x

        x = self.conv1(x)
        x = self.bn1(x, training=training)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x, training=training)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.bn3(x, training=training)

        out = self.residual_add([residual, x])
        return self.relu(out)

    def get_config(self):
        """Return a config from which ``from_config`` can rebuild this block."""
        cfg = super().get_config()
        cfg.update({
            "num_filters": self.num_filters,
            "modify_conv": self.modify_conv,
            "stride": self.stride,
            "idx": list(self.idx),
            "stride_in_1x1": self.stride_in_1x1,
        })
        return cfg

    def summary(self, input_shape=None):
        """Print a summary by tracing the block on a symbolic input.

        Args:
            input_shape: ``(H, W, C)`` of the traced input. Defaults to
                ``(224, 224, num_filters * 4)``. Note that tracing builds the
                sub-layers, fixing their input channel count.
        """
        if input_shape is None:
            input_shape = (224, 224, self.num_filters * 4)
        x = Input(shape=input_shape)
        model = tf.keras.models.Model(inputs=[x], outputs=self.call(x))
        return model.summary()
class Build_ResNet(tf.keras.Model):
    """ResNet classifier assembled from a residual block class.

    Layer names follow ``tf.keras.applications`` ResNet so pre-trained Keras
    weights can be copied by name.

    Args:
        block_type: Residual block class, called as
            ``block_type(num_filters, modify_conv, stride, idx=[stage, block])``
            and expected to output ``num_filters * 4`` channels.
        num_filters_list: Bottleneck width of each of the four stages (2..5).
        num_blocks: Number of residual blocks in each of the four stages.
        num_classes: Number of units of the softmax classifier.
    """

    def __init__(self, block_type, num_filters_list=(64, 128, 256, 512), num_blocks=(3, 4, 6, 3), num_classes=1000):
        super().__init__()
        self.expansion_factor = 4
        # Channel count of the tensor entering the next stage (the stem emits 64).
        self.num_filters_pre_block = 64

        # Stem laid out like the Keras ResNet50 stem: pad 3 -> 7x7/2 valid conv
        # -> BN -> ReLU -> pad 1 -> 3x3/2 valid max-pool (224x224 -> 56x56).
        # The padding layers carry no weights, so the weight list is unchanged.
        self.resnet_stage_1 = tf.keras.Sequential([
            ZeroPadding2D(padding=3),
            Conv2D(filters=64, kernel_size=7, strides=2, padding="valid", name="conv1_conv"),
            BatchNormalization(epsilon=1.001e-5, name="conv1_bn"),
            ReLU(),
            ZeroPadding2D(padding=1),
            MaxPool2D(pool_size=3, strides=2, padding="valid"),
        ])

        self.resnet_stage_2 = self._make_resnet_stage(block_type, num_filters_list[0], num_blocks[0], stride=1, stage_idx=2)
        self.resnet_stage_3 = self._make_resnet_stage(block_type, num_filters_list[1], num_blocks[1], stride=2, stage_idx=3)
        self.resnet_stage_4 = self._make_resnet_stage(block_type, num_filters_list[2], num_blocks[2], stride=2, stage_idx=4)
        self.resnet_stage_5 = self._make_resnet_stage(block_type, num_filters_list[3], num_blocks[3], stride=2, stage_idx=5)

        self.avgpool = GlobalAveragePooling2D()
        self.fc = Dense(units=num_classes, activation=tf.keras.activations.softmax, name="predictions")

    def _make_resnet_stage(self, block_type, num_filters, num_blocks, stride=1, stage_idx=None):
        """Build one stage as a Sequential of ``num_blocks`` residual blocks.

        Only the first block may downsample; it gets a projection shortcut
        whenever its input shape differs from its output shape.
        """
        out_channels = num_filters * self.expansion_factor
        modify_conv = stride != 1 or self.num_filters_pre_block != out_channels

        residual_blocks = [block_type(num_filters, modify_conv, stride, idx=[stage_idx, 1])]
        residual_blocks.extend(
            block_type(num_filters, False, stride=1, idx=[stage_idx, block_idx])
            for block_idx in range(2, num_blocks + 1)
        )

        # The next stage sees this stage's output width (not a running product).
        self.num_filters_pre_block = out_channels
        return tf.keras.Sequential(residual_blocks)

    def call(self, x, training=None):
        """Forward pass; ``training`` controls BatchNormalization mode.

        Left as None, Keras picks the mode (training inside ``fit``,
        inference otherwise) instead of always using batch statistics.
        """
        x = self.resnet_stage_1(x, training=training)

        x = self.resnet_stage_2(x, training=training)
        x = self.resnet_stage_3(x, training=training)
        x = self.resnet_stage_4(x, training=training)
        x = self.resnet_stage_5(x, training=training)

        x = self.avgpool(x)
        return self.fc(x)

    def get_config(self):
        cfg = super().get_config()
        return cfg

    def summary(self, input_shape=(224, 224, 3)):
        """Print a summary by tracing the model on a symbolic ``input_shape`` input.

        Tracing also creates all weights, which is required before copying
        pre-trained weights into the model.
        """
        x = Input(shape=input_shape)
        model = Model(inputs=[x], outputs=self.call(x))
        return model.summary()

ResNet model zoo

class ResnetModels:
    """Factories for standard ResNet variants built with ``Build_ResNet``.

    Stage depths/widths follow keras-applications' ``resnet_common``:
    https://github1s.com/keras-team/keras-applications/blob/bc89834ed36935ab4a4994446e34ff81c0d8e1b7/keras_applications/resnet_common.py#L463

    Every factory is a ``staticmethod``, so it works both as
    ``ResnetModels.ResNet50()`` and when called on an instance. Variants
    whose residual block is not implemented yet (the V2 pre-activation
    models) raise ``NotImplementedError``.
    """

    basic_block = Basic_Block
    bottleneck_block = Bottleneck_Block
    bottleneck_block_v2 = None  # pre-activation (V2) block: not implemented yet
    num_classes = 1000

    @staticmethod
    def _build(block_type, num_blocks, num_filters_list=(64, 128, 256, 512), num_classes=num_classes):
        """Return a ``Build_ResNet``, or raise if ``block_type`` is unavailable."""
        if block_type is None:
            raise NotImplementedError(
                "The residual block for this ResNet variant is not implemented yet."
            )
        return Build_ResNet(block_type,
                            num_filters_list=list(num_filters_list),
                            num_blocks=list(num_blocks),
                            num_classes=num_classes)

    @staticmethod
    def ResNet50(num_classes=num_classes):
        """ResNet-50: bottleneck blocks, stage depths (3, 4, 6, 3)."""
        return ResnetModels._build(ResnetModels.bottleneck_block, [3, 4, 6, 3],
                                   num_classes=num_classes)

    @staticmethod
    def ResNet101(num_classes=num_classes):
        """ResNet-101: bottleneck blocks, stage depths (3, 4, 23, 3)."""
        return ResnetModels._build(ResnetModels.bottleneck_block, [3, 4, 23, 3],
                                   num_classes=num_classes)

    @staticmethod
    def ResNet152(num_classes=num_classes):
        """ResNet-152: bottleneck blocks, stage depths (3, 8, 36, 3)."""
        return ResnetModels._build(ResnetModels.bottleneck_block, [3, 8, 36, 3],
                                   num_classes=num_classes)

    @staticmethod
    def ResNet50V2(num_classes=num_classes):
        """ResNet-50 V2 (pre-activation); raises until the V2 block exists."""
        return ResnetModels._build(ResnetModels.bottleneck_block_v2, [3, 4, 6, 3],
                                   num_classes=num_classes)

    @staticmethod
    def ResNet101V2(num_classes=num_classes):
        """ResNet-101 V2 (pre-activation); raises until the V2 block exists."""
        return ResnetModels._build(ResnetModels.bottleneck_block_v2, [3, 4, 23, 3],
                                   num_classes=num_classes)

    @staticmethod
    def ResNet152V2(num_classes=num_classes):
        """ResNet-152 V2 (pre-activation); raises until the V2 block exists."""
        return ResnetModels._build(ResnetModels.bottleneck_block_v2, [3, 8, 36, 3],
                                   num_classes=num_classes)

    @staticmethod
    def ResNeXt50(num_classes=num_classes):
        """Approximate ResNeXt-50: widths (128, 256, 512, 1024), depths (3, 4, 6, 3).

        NOTE: uses the plain bottleneck block (no grouped convolutions), so
        this is a widened ResNet-50 rather than a true ResNeXt; its weights
        are not expected to match pre-trained ResNeXt weights.
        """
        return ResnetModels._build(ResnetModels.bottleneck_block, [3, 4, 6, 3],
                                   num_filters_list=[128, 256, 512, 1024],
                                   num_classes=num_classes)

    @staticmethod
    def ResNeXt101(num_classes=num_classes):
        """Approximate ResNeXt-101: widths (128, 256, 512, 1024), depths (3, 4, 23, 3).

        Uses the same stage widths as ``ResNeXt50`` (as keras-applications
        does). NOTE: plain bottleneck block, no grouped convolutions.
        """
        return ResnetModels._build(ResnetModels.bottleneck_block, [3, 4, 23, 3],
                                   num_filters_list=[128, 256, 512, 1024],
                                   num_classes=num_classes)

Example

# Custom ResNet-50. summary() traces the model on a 224x224x3 input, which
# creates its weights -- run it before copying weights into the model.
model = ResnetModels.ResNet50(num_classes=1000)
model.summary()

# Reference Keras ResNet50 whose weights will be copied.
pre_trained_model = ResNet50()
Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_8 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
sequential_46 (Sequential)   (None, 55, 55, 64)        9728      
_________________________________________________________________
sequential_48 (Sequential)   (None, 55, 55, 256)       220032    
_________________________________________________________________
sequential_50 (Sequential)   (None, 28, 28, 512)       1230336   
_________________________________________________________________
sequential_52 (Sequential)   (None, 14, 14, 1024)      7129088   
_________________________________________________________________
sequential_54 (Sequential)   (None, 7, 7, 2048)        14998528  
_________________________________________________________________
global_average_pooling2d_4 ( (None, 2048)              0         
_________________________________________________________________
predictions (Dense)          (None, 1000)              2049000   
=================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
_________________________________________________________________
# Sanity check: both networks must expose the same number of weight arrays.
n_pretrained_weights = len(pre_trained_model.get_weights())
n_custom_weights = len(model.get_weights())
assert n_pretrained_weights == n_custom_weights, "dim_check"

Modifying the model for a use case

Let us keep only the first two Sequential blocks of the ResNet-50.

The remaining layers are discarded.

After the modification, we will automatically copy the weights of these first two blocks from the pre-trained network.

# Top-level layers of the custom ResNet: the stem and four residual stages
# (each a Sequential), then global average pooling and the classifier.
model.layers
[<tensorflow.python.keras.engine.sequential.Sequential at 0x7f506cf9b810>,
 <tensorflow.python.keras.engine.sequential.Sequential at 0x7f506d9d5b10>,
 <tensorflow.python.keras.engine.sequential.Sequential at 0x7f506ad5d510>,
 <tensorflow.python.keras.engine.sequential.Sequential at 0x7f506d88fb50>,
 <tensorflow.python.keras.engine.sequential.Sequential at 0x7f506ad7f910>,
 <tensorflow.python.keras.layers.pooling.GlobalAveragePooling2D at 0x7f506ad66ed0>,
 <tensorflow.python.keras.layers.core.Dense at 0x7f506ad75350>]
# Keep only the stem and the first residual stage; later layers are dropped.
model = tf.keras.models.Sequential(model.layers[:2])

# One forward pass on a random image builds the new Sequential wrapper
# so that summary() can report its shapes.
x = tf.convert_to_tensor(np.random.normal(size=(1, 224, 224, 3)))
model(x)
model.summary()
Model: "sequential_55"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential_46 (Sequential)   (None, 55, 55, 64)        9728      
_________________________________________________________________
sequential_48 (Sequential)   (None, 55, 55, 256)       220032    
=================================================================
Total params: 229,760
Trainable params: 226,816
Non-trainable params: 2,944
_________________________________________________________________

Notice that the number of parameters has been reduced.

# Index every pre-trained weight by its full variable name
# (e.g. "conv5_block1_2_bn/gamma:0") -> (name, numpy value).
dict_1 = {
    pretrained_var.name: (pretrained_var.name, pretrained_var.numpy())
    for pretrained_layer in pre_trained_model.layers
    for pretrained_var in pretrained_layer.weights
}
# Diagnostic: group the custom model's variables by kind. The kind is the
# last name component minus its 2-char suffix, e.g. ".../gamma:0" -> "gamma".
dict_2 = {
    kind: []
    for kind in ("kernel", "bias", "gamma", "beta", "moving_mean", "moving_variance")
}
for layer in model.layers:
    for weight in layer.weights:
        name = weight.name
        layer_type = name.split('/')[-1][:-2]
        dict_2[layer_type].append((name, weight))

Copy the weights from the pre-trained model

# Copy pre-trained values into the custom model, matching variables by name.
# The pre-trained names have two components ("conv5_block1_2_bn/gamma:0"),
# so only the last two components of each custom variable name are used;
# any outer scopes (e.g. "RESNET_BLOCK_5_1/...") are ignored regardless of
# nesting depth. Variables with no pre-trained counterpart or with a
# different shape are skipped and reported, so a modified network still
# receives every weight that does apply.
skipped = []
for layer in model.layers:
    for weight in layer.weights:
        name = weight.name
        key = "/".join(name.split("/")[-2:])

        entry = dict_1.get(key)
        if entry is None:
            skipped.append((name, "no pre-trained weight named " + key))
            continue

        pretrained_weight = entry[1]
        custom_shape = tuple(weight.shape.as_list())
        if custom_shape != pretrained_weight.shape:
            skipped.append((name, f"shape {custom_shape} != {pretrained_weight.shape}"))
            continue

        print(weight.shape, pretrained_weight.shape)
        weight.assign(pretrained_weight)

for name, reason in skipped:
    print(f"skipped {name}: {reason}")
(7, 7, 3, 64) (7, 7, 3, 64)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(1, 1, 64, 64) (1, 1, 64, 64)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(3, 3, 64, 64) (3, 3, 64, 64)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(1, 1, 64, 256) (1, 1, 64, 256)
(256,) (256,)
(256,) (256,)
(256,) (256,)
(1, 1, 64, 256) (1, 1, 64, 256)
(256,) (256,)
(256,) (256,)
(256,) (256,)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(256,) (256,)
(256,) (256,)
(256,) (256,)
(256,) (256,)
(1, 1, 256, 64) (1, 1, 256, 64)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(3, 3, 64, 64) (3, 3, 64, 64)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(1, 1, 64, 256) (1, 1, 64, 256)
(256,) (256,)
(256,) (256,)
(256,) (256,)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(256,) (256,)
(256,) (256,)
(1, 1, 256, 64) (1, 1, 256, 64)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(3, 3, 64, 64) (3, 3, 64, 64)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(1, 1, 64, 256) (1, 1, 64, 256)
(256,) (256,)
(256,) (256,)
(256,) (256,)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(64,) (64,)
(256,) (256,)
(256,) (256,)
# Summary after the copy: parameter counts are unchanged; only the
# variable values were overwritten by assign().
model.summary()
Model: "sequential_55"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential_46 (Sequential)   (None, 55, 55, 64)        9728      
_________________________________________________________________
sequential_48 (Sequential)   (None, 55, 55, 256)       220032    
=================================================================
Total params: 229,760
Trainable params: 226,816
Non-trainable params: 2,944
_________________________________________________________________