Pytorch to tensorflow inference results are different · Issue #595 · microsoft/MMdnn · GitHub

Pytorch to tensorflow inference results are different #595


Closed
EnisBerk opened this issue Feb 25, 2019 · 1 comment

EnisBerk commented Feb 25, 2019

Hi,
I am trying to convert TensorFlow VGGish to PyTorch.
You can see what I am doing step by step in this Google Colab notebook.

Problems:

  • MMdnn handles the NHWC order incorrectly. With the master branch I get reshape(1,W,H,1), while with the latest release I get reshape(-1,W,H,1). (In the notebook I am using the latest master.)
  • Once I work around that, the model expects float32 input and does not do the cast itself. (That might be expected.)
  • My biggest problem is that the inference results are different. As you can see in the notebook I shared, the checksums differ (see the comparison sketch below). Is that expected?
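
For reference, the checksum comparison boils down to something like the sketch below. The names tf_out and pt_out are illustrative: they stand for the embeddings produced by each framework on the same input batch.

import numpy as np

def compare_outputs(tf_out, pt_out, atol=1e-5):
    # Both arrays hold the (batch, 128) VGGish embeddings from one framework.
    print("max abs diff:", np.abs(tf_out - pt_out).max())
    print("allclose:", np.allclose(tf_out, pt_out, atol=atol))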

I am using Google Colab.
Platform (like ubuntu 16.04/win10): Ubuntu 17.10 (Artful Aardvark)
Python version: 3.6.3
Source framework with version (like Tensorflow 1.4.1 with GPU): Tensorflow 1.13.0-rc1 with GPU
Destination framework with version (like CNTK 2.3 with GPU): Pytorch 1.0.1 with GPU

Pre-trained model path (webpath or webdisk path):
https://github.com/tensorflow/models/tree/master/research/audioset
https://storage.googleapis.com/audioset/vggish_model.ckpt
Running scripts:
Google Colab notebook (linked above)
mirror on GitHub

rainLiuplus (Contributor) commented

Hi @EnisBerk,

MMdnn's IR format is NHWC, but PyTorch only supports NCHW. MMdnn handles this by transposing the input data once at the beginning, which is obviously not completely correct. In your case, the model contains a flatten operator: on the PyTorch side its input, the maxpool output, is laid out as NCHW, but the fully connected weight that multiplies the flattened maxpool output is still ordered for TensorFlow's NHWC layout, thus causing different results.
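
To see the mismatch concretely, here is a toy sketch (illustrative shapes, not the real model). Flattening the same values in NHWC and NCHW layouts visits them in different orders, so a dense layer whose weights expect the NHWC order receives scrambled input:

import torch

x_nhwc = torch.arange(12).reshape(1, 2, 2, 3)  # TensorFlow layout: N, H, W, C
x_nchw = x_nhwc.permute(0, 3, 1, 2)            # PyTorch layout: N, C, H, W

print(x_nhwc.reshape(1, -1)[0])  # tensor([0, 1, 2, ..., 11])        NHWC order
print(x_nchw.reshape(1, -1)[0])  # tensor([0, 3, 6, 9, 1, 4, ...])   NCHW order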

And we can't solve this by inserting transposes on both sides of every CNN-related op, because when a transpose operation is followed by a reshape, the tensor must first be made contiguous, which will also lead to inaccurate results.
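
A minimal illustration of that point: after a permute, .view() rejects the non-contiguous tensor, and making it contiguous materializes the transposed element order.

import torch

t = torch.arange(12).reshape(1, 2, 2, 3).permute(0, 3, 1, 2)
# t.view(1, -1)                    # RuntimeError: t is no longer contiguous
flat = t.contiguous().view(1, -1)  # works, but in the transposed element order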

For your situation, you can manually add permute operations to work around the problem, like this:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

__weights_dict = dict()

def load_weights(weight_file):
    if weight_file is None:
        return None

    try:
        # allow_pickle=True is required on newer NumPy versions to load a dict.
        weights_dict = np.load(weight_file, allow_pickle=True).item()
    except Exception:
        # Fall back for weight files pickled under Python 2.
        weights_dict = np.load(weight_file, encoding='bytes', allow_pickle=True).item()

    return weights_dict

class KitModel(nn.Module):

    
    def __init__(self, weight_file):
        super(KitModel, self).__init__()
        global __weights_dict
        __weights_dict = load_weights(weight_file)

        self.vggish_conv1_Conv2D = self.__conv(2, name='vggish/conv1/Conv2D', in_channels=1, out_channels=64, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
        self.vggish_conv2_Conv2D = self.__conv(2, name='vggish/conv2/Conv2D', in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
        self.vggish_conv3_conv3_1_Conv2D = self.__conv(2, name='vggish/conv3/conv3_1/Conv2D', in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
        self.vggish_conv3_conv3_2_Conv2D = self.__conv(2, name='vggish/conv3/conv3_2/Conv2D', in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
        self.vggish_conv4_conv4_1_Conv2D = self.__conv(2, name='vggish/conv4/conv4_1/Conv2D', in_channels=256, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
        self.vggish_conv4_conv4_2_Conv2D = self.__conv(2, name='vggish/conv4/conv4_2/Conv2D', in_channels=512, out_channels=512, kernel_size=(3, 3), stride=(1, 1), groups=1, bias=True)
        self.vggish_fc1_fc1_1_MatMul = self.__dense(name = 'vggish/fc1/fc1_1/MatMul', in_features = 12288, out_features = 4096, bias = True)
        self.vggish_fc1_fc1_2_MatMul = self.__dense(name = 'vggish/fc1/fc1_2/MatMul', in_features = 4096, out_features = 4096, bias = True)
        self.vggish_fc2_MatMul = self.__dense(name = 'vggish/fc2/MatMul', in_features = 4096, out_features = 128, bias = True)

    def forward(self, x):
        # Unused remnant of the TensorFlow graph's dynamic flatten shape.
        self.vggish_Flatten_flatten_Reshape_shape_1 = torch.autograd.Variable(torch.Tensor([-1]), requires_grad=False)
        # Manually added permute: NHWC -> NCHW so the convolutions below see
        # the layout PyTorch expects.
        vggish_Reshape  = torch.reshape(input = x, shape = (1,96,64,1)).permute(0, 3, 1, 2)

        vggish_conv1_Conv2D_pad = F.pad(vggish_Reshape, (1, 1, 1, 1))
        vggish_conv1_Conv2D = self.vggish_conv1_Conv2D(vggish_conv1_Conv2D_pad)
        vggish_conv1_Relu = F.relu(vggish_conv1_Conv2D)
        vggish_pool1_MaxPool = F.max_pool2d(vggish_conv1_Relu, kernel_size=(2, 2), stride=(2, 2), padding=0, ceil_mode=False)
        vggish_conv2_Conv2D_pad = F.pad(vggish_pool1_MaxPool, (1, 1, 1, 1))
        vggish_conv2_Conv2D = self.vggish_conv2_Conv2D(vggish_conv2_Conv2D_pad)
        vggish_conv2_Relu = F.relu(vggish_conv2_Conv2D)
        vggish_pool2_MaxPool = F.max_pool2d(vggish_conv2_Relu, kernel_size=(2, 2), stride=(2, 2), padding=0, ceil_mode=False)
        vggish_conv3_conv3_1_Conv2D_pad = F.pad(vggish_pool2_MaxPool, (1, 1, 1, 1))
        vggish_conv3_conv3_1_Conv2D = self.vggish_conv3_conv3_1_Conv2D(vggish_conv3_conv3_1_Conv2D_pad)
        vggish_conv3_conv3_1_Relu = F.relu(vggish_conv3_conv3_1_Conv2D)
        vggish_conv3_conv3_2_Conv2D_pad = F.pad(vggish_conv3_conv3_1_Relu, (1, 1, 1, 1))
        vggish_conv3_conv3_2_Conv2D = self.vggish_conv3_conv3_2_Conv2D(vggish_conv3_conv3_2_Conv2D_pad)
        vggish_conv3_conv3_2_Relu = F.relu(vggish_conv3_conv3_2_Conv2D)
        vggish_pool3_MaxPool = F.max_pool2d(vggish_conv3_conv3_2_Relu, kernel_size=(2, 2), stride=(2, 2), padding=0, ceil_mode=False)
        vggish_conv4_conv4_1_Conv2D_pad = F.pad(vggish_pool3_MaxPool, (1, 1, 1, 1))
        vggish_conv4_conv4_1_Conv2D = self.vggish_conv4_conv4_1_Conv2D(vggish_conv4_conv4_1_Conv2D_pad)
        vggish_conv4_conv4_1_Relu = F.relu(vggish_conv4_conv4_1_Conv2D)
        vggish_conv4_conv4_2_Conv2D_pad = F.pad(vggish_conv4_conv4_1_Relu, (1, 1, 1, 1))
        vggish_conv4_conv4_2_Conv2D = self.vggish_conv4_conv4_2_Conv2D(vggish_conv4_conv4_2_Conv2D_pad)
        vggish_conv4_conv4_2_Relu = F.relu(vggish_conv4_conv4_2_Conv2D)
        # Manually added permute: NCHW -> NHWC before flattening, so the element
        # order matches what the TensorFlow-trained fc1 weights expect.
        vggish_pool4_MaxPool = F.max_pool2d(vggish_conv4_conv4_2_Relu, kernel_size=(2, 2), stride=(2, 2), padding=0, ceil_mode=False).permute(0, 2, 3, 1)
        vggish_Flatten_flatten_Shape = torch.Tensor(list(vggish_pool4_MaxPool.size()))  # unused remnant of the TF graph
        vggish_Flatten_flatten_Reshape = torch.reshape(input = vggish_pool4_MaxPool, shape = (1,12288))
        vggish_Flatten_flatten_strided_slice = vggish_Flatten_flatten_Shape[0:1][0]  # unused remnant
        vggish_fc1_fc1_1_MatMul = self.vggish_fc1_fc1_1_MatMul(vggish_Flatten_flatten_Reshape)
        vggish_Flatten_flatten_Reshape_shape = [vggish_Flatten_flatten_strided_slice,self.vggish_Flatten_flatten_Reshape_shape_1]  # unused remnant
        vggish_fc1_fc1_1_Relu = F.relu(vggish_fc1_fc1_1_MatMul)
        vggish_fc1_fc1_2_MatMul = self.vggish_fc1_fc1_2_MatMul(vggish_fc1_fc1_1_Relu)
        vggish_fc1_fc1_2_Relu = F.relu(vggish_fc1_fc1_2_MatMul)
        vggish_fc2_MatMul = self.vggish_fc2_MatMul(vggish_fc1_fc1_2_Relu)
        vggish_fc2_Relu = F.relu(vggish_fc2_MatMul)
        return vggish_fc2_Relu


    @staticmethod
    def __dense(name, **kwargs):
        layer = nn.Linear(**kwargs)
        # The converted .npy file already stores dense weights in PyTorch's
        # (out_features, in_features) layout, so they can be copied directly.
        layer.state_dict()['weight'].copy_(torch.from_numpy(__weights_dict[name]['weights']))
        if 'bias' in __weights_dict[name]:
            layer.state_dict()['bias'].copy_(torch.from_numpy(__weights_dict[name]['bias']))
        return layer

    @staticmethod
    def __conv(dim, name, **kwargs):
        if   dim == 1:  layer = nn.Conv1d(**kwargs)
        elif dim == 2:  layer = nn.Conv2d(**kwargs)
        elif dim == 3:  layer = nn.Conv3d(**kwargs)
        else:           raise NotImplementedError()

        # Conv weights in the .npy file are already in PyTorch's
        # (out_channels, in_channels, kH, kW) layout.
        layer.state_dict()['weight'].copy_(torch.from_numpy(__weights_dict[name]['weights']))
        if 'bias' in __weights_dict[name]:
            layer.state_dict()['bias'].copy_(torch.from_numpy(__weights_dict[name]['bias']))
        return layer
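
A usage sketch (the weight-file name is illustrative; pass whatever .npy your MMdnn conversion produced). Note the explicit float32 cast, since the model does not cast its input itself:

import numpy as np
import torch

model = KitModel('kit_model.npy')  # hypothetical path to the converted weights
model.eval()

patch = np.random.rand(96, 64).astype(np.float32)  # one 96x64 log-mel example
with torch.no_grad():
    embedding = model(torch.from_numpy(patch))
print(embedding.shape)  # torch.Size([1, 128])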
