David Molony

Machine Learning, Data Science, Medical Imaging

Porting a pretrained ResNet from Pytorch to Tensorflow 2.0


While building a SimCLR model in Tensorflow 2.0 I decided to use a smaller model than the authors because of the limits of my GPU. The SimCLR paper uses a ResNet with 50 layers, so I decided to use a less resource-intensive ResNet18 or ResNet34. To my surprise, Tensorflow did not have pretrained ImageNet weights for either of these smaller models. The torchvision library for Pytorch, on the other hand, provides pretrained weights for all ResNets with 18, 34, 50, 101 and 152 layers. Since I had already decided to use Tensorflow for this project, I set out to port the model and weights from Pytorch to Tensorflow. In this post I will describe this process using tensorflow 2.0.0b1, pytorch 1.4.0 and torchvision 0.5.0.

Pytorch model exploration

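The first step is to import resnet from torchvision. We then print the model, which shows us the kernel_size, stride and padding used for each layer. Then we place the names of each layer that holds parameters/weights in a list torch_layer_names.

import torchvision.models as models
import torch
import tensorflow as tf
import numpy as np

resnet_torch = models.resnet18(pretrained=True)
# printing the module lists every layer with its kernel_size, stride and padding
print(resnet_torch)

# keep only modules that own parameters directly (conv, batch norm, fc);
# containers such as layer1 and parameter-free modules such as relu are skipped
torch_layer_names = []
for name, module in resnet_torch.named_modules():
    if len(list(module.parameters(recurse=False))) > 0:
        torch_layer_names.append(name)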

Create ResNet in Tensorflow

With the above knowledge of the model parameters we then create the ResNet model in Tensorflow. We also refer to the original ResNet paper to fully implement the model, as our torch_layer_names list only contains layers with parameters and so will be missing operations such as the residual connections. For each of the layers in torch_layer_names we make sure the corresponding layer in our Tensorflow model has the same name by setting the name argument. This whole process is mostly straightforward except for a few cases where the padding differs between the two frameworks.

Tensorflow padding

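Same padding refers to when we want the layer output to have the same size as the input (or, for a strided layer, the input size divided by the stride, rounded up). The Pytorch model pads each side of the input explicitly and symmetrically, but naively setting the padding to same in Tensorflow will not reproduce this in some layers. I did not comprehensively study the cases where it results in a different output, but it occurs when stride>1: Tensorflow may then need an odd total amount of padding and places the extra pixel on the bottom/right, so the padding is no longer symmetric. To get around this we insert a zero padding layer prior to our strided layer, where we now use valid padding. For example, for a max pool layer we can do the following: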

x = tf.keras.layers.ZeroPadding2D(padding=(1,1), name='pad1')(x)
x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='valid', name='maxpool')(x)

or for a convolutional layer

x = tf.keras.layers.ZeroPadding2D(padding=(3,3), name='pad')(inputs)
x = tf.keras.layers.Conv2D(filters=64, kernel_size=7, strides=2, padding='valid', activation='linear', use_bias=False, name='conv1')(x)
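
To see why padding='same' diverges from the Pytorch padding for strided layers, the sketch below computes the padding each framework applies along one spatial dimension. It assumes Tensorflow's documented rule for same padding and the symmetric integer padding argument used by torchvision's ResNet. The helper names tf_same_padding and torch_padding are made up for this illustration and are not part of either library.

```python
import math

def tf_same_padding(in_size, kernel_size, stride):
    """(before, after) padding Tensorflow applies for padding='same'."""
    out_size = math.ceil(in_size / stride)
    total = max((out_size - 1) * stride + kernel_size - in_size, 0)
    before = total // 2          # top/left gets the smaller half
    after = total - before       # bottom/right gets any extra pixel
    return before, after

def torch_padding(padding):
    """(before, after) padding Pytorch applies for an integer padding argument."""
    return padding, padding

# conv1 of ResNet18: 7x7 kernel, stride 2, padding=3, on a 224-pixel side
print(tf_same_padding(224, 7, 2), torch_padding(3))   # (2, 3) (3, 3)
# maxpool: 3x3 window, stride 2, padding=1, on a 112-pixel side
print(tf_same_padding(112, 3, 2), torch_padding(1))   # (0, 1) (1, 1)
# a stride-1 3x3 convolution on a 56-pixel side agrees in both frameworks
print(tf_same_padding(56, 3, 1), torch_padding(1))    # (1, 1) (1, 1)
```

The strided layers are the ones where the two tuples differ, which is why an explicit ZeroPadding2D layer followed by valid padding is used for them.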

Generating the model is straightforward from here. Again, to make things easy for ourselves, we just need to take care to give each layer in Tensorflow the same name as the corresponding layer in our Pytorch model. The code below generates the full ResNet model (note that this model does not include the bottleneck residual layers used by the 50 layer model).

def BasicBlock(inputs, num_channels, kernel_size, num_blocks, skip_blocks, name):
    """Basic residual block"""
    x = inputs

    for i in range(num_blocks):
        if i not in skip_blocks:
            x1 = ConvNormRelu(x, num_channels, kernel_size, strides=[1,1], name=name + '.'+str(i))
            x = tf.keras.layers.Add()([x, x1])
            x = tf.keras.layers.Activation('relu')(x)
    return x

def BasicBlockDown(inputs, num_channels, kernel_size, name):
    """Residual block with strided downsampling"""
    x = inputs
    x1 = ConvNormRelu(x, num_channels, kernel_size, strides=[2,1], name=name+'.0')
    x = tf.keras.layers.Conv2D(num_channels, kernel_size=1, strides=2, padding='same', activation='linear', use_bias=False, name=name+'.0.downsample.0')(x)
    # Keras momentum weights the old moving average while Pytorch momentum weights the new batch,
    # so Pytorch's default momentum=0.1 corresponds to momentum=0.9 here (only matters when training)
    x = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name=name+'.0.downsample.1')(x)
    x = tf.keras.layers.Add()([x, x1])
    x = tf.keras.layers.Activation('relu')(x)
    return x

def ConvNormRelu(x, num_channels, kernel_size, strides, name):
    """Two convolution + batch normalization stages with a relu after the first"""
    if strides[0] == 2:
        x = tf.keras.layers.ZeroPadding2D(padding=(1,1), name=name+'.pad')(x)
        x = tf.keras.layers.Conv2D(num_channels, kernel_size, strides[0], padding='valid', activation='linear', use_bias=False, name=name+'.conv1')(x)
    else:
        x = tf.keras.layers.Conv2D(num_channels, kernel_size, strides[0], padding='same', activation='linear', use_bias=False, name=name+'.conv1')(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name=name+'.bn1')(x)
    x = tf.keras.layers.Activation('relu')(x)

    x = tf.keras.layers.Conv2D(num_channels, kernel_size, strides[1], padding='same', activation='linear', use_bias=False, name=name+'.conv2')(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name=name+'.bn2')(x)
    return x

def ResNet18(inputs):
    x = tf.keras.layers.ZeroPadding2D(padding=(3,3), name='pad')(inputs)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=7, strides=2, padding='valid', activation='linear', use_bias=False, name='conv1')(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name='bn1')(x)
    x = tf.keras.layers.Activation('relu', name='relu')(x)
    x = tf.keras.layers.ZeroPadding2D(padding=(1,1), name='pad1')(x)
    x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='valid', name='maxpool')(x)

    x = BasicBlock(x, num_channels=64, kernel_size=3, num_blocks=2, skip_blocks=[], name='layer1')

    x = BasicBlockDown(x, num_channels=128, kernel_size=3, name='layer2')
    x = BasicBlock(x, num_channels=128, kernel_size=3, num_blocks=2, skip_blocks=[0], name='layer2')

    x = BasicBlockDown(x, num_channels=256, kernel_size=3, name='layer3')
    x = BasicBlock(x, num_channels=256, kernel_size=3, num_blocks=2, skip_blocks=[0], name='layer3')

    x = BasicBlockDown(x, num_channels=512, kernel_size=3, name='layer4')
    x = BasicBlock(x, num_channels=512, kernel_size=3, num_blocks=2, skip_blocks=[0], name='layer4')

    x = tf.keras.layers.GlobalAveragePooling2D(name='avgpool')(x)
    x = tf.keras.layers.Dense(units=1000, use_bias=True, activation='linear', name='fc')(x)

    return x

Copying the Pytorch weights to Tensorflow

Now that we have a ResNet18 Tensorflow model we need to copy the pretrained weights from the Pytorch model to the Tensorflow model.

Getting Pytorch weights and setting Tensorflow weights

To get weights from a Pytorch layer we can again use state_dict(), which returns an ordered dictionary. We use the layer name as the key, with the type of weights stored in the layer appended to it. For example, to get the parameters for a batch normalization layer:

layer='layer2.0.bn2'
gamma = resnet_torch.state_dict()[layer+'.weight'].numpy()
beta = resnet_torch.state_dict()[layer+'.bias'].numpy()
mean = resnet_torch.state_dict()[layer+'.running_mean'].numpy()
var = resnet_torch.state_dict()[layer+'.running_var'].numpy()

To set the weights for the corresponding batch norm layer in Tensorflow we first get the appropriate layer using the get_layer method of the model class. We then combine all the Pytorch parameters into a list and use the set_weights() method to set them in the Tensorflow model.

inputs = tf.keras.Input((None, None,3))
resnet_tf = ResNet18(inputs)
model = tf.keras.Model(inputs, resnet_tf)
model.get_layer(layer).set_weights([gamma, beta, mean, var])
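
As a reminder of what these four arrays do at inference time, here is a minimal NumPy sketch of batch normalization on a channels-last tensor. The function name batchnorm_inference and the toy values are made up for illustration; epsilon=1e-5 matches the value used in the model above. It is a sketch of the standard formula, not the internals of either framework.

```python
import numpy as np

def batchnorm_inference(x, gamma, beta, mean, var, epsilon=1e-5):
    """Normalize a channels-last tensor with fixed running statistics."""
    return gamma * (x - mean) / np.sqrt(var + epsilon) + beta

# toy input: batch of 1, 2x2 spatial, 3 channels
x = np.arange(12, dtype=np.float32).reshape(1, 2, 2, 3)
gamma = np.array([1.0, 2.0, 0.5], dtype=np.float32)   # per-channel scale
beta = np.array([0.0, 1.0, -1.0], dtype=np.float32)   # per-channel shift
mean = x.mean(axis=(0, 1, 2))                          # stands in for running_mean
var = x.var(axis=(0, 1, 2))                            # stands in for running_var

y = batchnorm_inference(x, gamma, beta, mean, var)
print(y.mean(axis=(0, 1, 2)))  # per-channel means, close to beta
print(y.std(axis=(0, 1, 2)))   # per-channel standard deviations, close to gamma
```

The list passed to set_weights() follows the same order as the function arguments: gamma, beta, mean, var.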

Since we have given the Tensorflow layers with parameters the same name as their Pytorch counterparts we can run a simple for loop over the layer names and set the layer weights for the entire model this way.

inputs = tf.keras.Input((None, None,3))
resnet_tf = ResNet18(inputs)
model = tf.keras.Model(inputs, resnet_tf)

tf_layer_names = [layer.name for layer in model.layers]
tf_layer_names = [layer for layer in tf_layer_names if layer in torch_layer_names]

for layer in tf_layer_names:
    if 'conv' in layer:
        tf_conv = model.get_layer(layer)
        weights = resnet_torch.state_dict()[layer+'.weight'].numpy()
        weights_list = [weights.transpose((2, 3, 1, 0))]
        if len(tf_conv.weights) == 2:
            bias = resnet_torch.state_dict()[layer+'.bias'].numpy()
            weights_list.append(bias)
        tf_conv.set_weights(weights_list)
    elif 'bn' in layer:
        tf_bn = model.get_layer(layer)
        gamma = resnet_torch.state_dict()[layer+'.weight'].numpy()
        beta = resnet_torch.state_dict()[layer+'.bias'].numpy()
        mean = resnet_torch.state_dict()[layer+'.running_mean'].numpy()
        var = resnet_torch.state_dict()[layer+'.running_var'].numpy()
        bn_list = [gamma, beta, mean, var]
        tf_bn.set_weights(bn_list)
    elif 'downsample.0' in layer:
        tf_downsample = model.get_layer(layer)
        weights = resnet_torch.state_dict()[layer+'.weight'].numpy()
        weights_list = [weights.transpose((2, 3, 1, 0))]
        if len(tf_downsample.weights) == 2:
            bias = resnet_torch.state_dict()[layer+'.bias'].numpy()
            weights_list.append(bias)
        tf_downsample.set_weights(weights_list)
    elif 'downsample.1' in layer:
        tf_downsample = model.get_layer(layer)
        gamma = resnet_torch.state_dict()[layer+'.weight'].numpy()
        beta = resnet_torch.state_dict()[layer+'.bias'].numpy()
        mean = resnet_torch.state_dict()[layer+'.running_mean'].numpy()
        var = resnet_torch.state_dict()[layer+'.running_var'].numpy()
        bn_list = [gamma, beta, mean, var]  # order Keras expects: gamma, beta, moving mean, moving variance
        tf_downsample.set_weights(bn_list)
    elif 'fc' in layer:
        tf_fc = model.get_layer(layer)
        weights = resnet_torch.state_dict()[layer+'.weight'].numpy() 
        weights_list = [weights.transpose((1, 0))]
        if len(tf_fc.weights) == 2:
            bias = resnet_torch.state_dict()[layer+'.bias'].numpy()
            weights_list.append(bias)
        tf_fc.set_weights(weights_list)
    else:
        print('No parameters found for {}'.format(layer))
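
The transposes in the loop come from the different weight layouts: Pytorch stores convolution kernels as (out_channels, in_channels, height, width) and linear weights as (out_features, in_features), while Keras expects (height, width, in_channels, out_channels) and (in_features, out_features). The following is a small NumPy check of that index mapping, using random arrays as stand-ins for the real weights (the variable names are illustrative only).

```python
import numpy as np

rng = np.random.default_rng(0)

# stand-in for a Pytorch conv weight: (out_channels, in_channels, height, width)
torch_conv_w = rng.normal(size=(64, 3, 7, 7)).astype(np.float32)
keras_conv_w = torch_conv_w.transpose((2, 3, 1, 0))
print(keras_conv_w.shape)  # (7, 7, 3, 64)

# the same scalar is reachable under both layouts
o, i, h, w = 5, 2, 4, 6
assert keras_conv_w[h, w, i, o] == torch_conv_w[o, i, h, w]

# stand-in for a Pytorch linear weight: (out_features, in_features)
torch_fc_w = rng.normal(size=(1000, 512)).astype(np.float32)
keras_fc_w = torch_fc_w.transpose((1, 0))
print(keras_fc_w.shape)  # (512, 1000)

# Pytorch computes x @ W.T, Keras computes x @ W; after the transpose they agree
x = rng.normal(size=(1, 512)).astype(np.float32)
assert np.allclose(x @ torch_fc_w.T, x @ keras_fc_w)
```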

Comparing the model output

The final step is to check that the output from Tensorflow matches that of Pytorch. Let's download an image of a cat and feed it into both models. We will see that the max difference in the logit values is very small, indicating that the models should be equivalent.

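import requests, shutil
from PIL import Image

image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/9/97/Kot_z_mysz%C4%85.jpg/480px-Kot_z_mysz%C4%85.jpg"
resp = requests.get(image_url, stream=True)
# the with block closes the file so the image is fully written before it is read back
with open('cat.jpg', 'wb') as local_file:
    shutil.copyfileobj(resp.raw, local_file)
img = np.expand_dims(np.array(Image.open('cat.jpg').convert('RGB')), 0).astype(np.float32)
img_torch = torch.tensor(img.transpose((0, 3, 1, 2)))
tf_output = model.predict(img)
# eval mode makes batch norm use its running statistics
resnet_torch.eval()
# no gradients are needed for a forward-only comparison
with torch.no_grad():
    torch_output = resnet_torch(img_torch)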

max_diff = np.max(np.abs(tf_output - torch_output.detach().numpy()))
print('Max difference in fully connected layer :{}'.format(max_diff))
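
Beyond the maximum absolute difference, it can help to check the outputs against an explicit tolerance and to confirm that both models predict the same class. A minimal sketch, assuming both outputs have already been converted to NumPy arrays of shape (batch, 1000). The helper compare_logits and the tolerance 1e-3 are illustrative choices, not values measured in the experiment above, and the toy arrays only stand in for tf_output and torch_output.detach().numpy().

```python
import numpy as np

def compare_logits(a, b, atol=1e-3):
    """Return max abs difference, a tolerance check and top-1 agreement."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    max_diff = float(np.max(np.abs(a - b)))
    return {
        'max_diff': max_diff,
        'within_tol': bool(max_diff <= atol),
        'same_top1': bool(np.all(a.argmax(axis=1) == b.argmax(axis=1))),
    }

# toy logits: evenly spaced values plus a tiny deterministic perturbation
logits_a = np.linspace(-5.0, 5.0, 1000).reshape(1, 1000)
logits_b = logits_a + 1e-5 * np.sin(np.arange(1000)).reshape(1, 1000)

result = compare_logits(logits_a, logits_b)
print(result)
```

With the real models one would call compare_logits(tf_output, torch_output.detach().numpy()) and pick a tolerance suited to float32 arithmetic.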

Save the Tensorflow model

We can save the Tensorflow model as follows

model.save('Resnet18')

Conclusion

Hopefully this post will be useful to anyone who needs to use ResNet18 or ResNet34 in Tensorflow or decides to port another Pytorch model to Tensorflow. The code above should work out of the box; otherwise, refer to my github.