Generating a model without functions? #1935

Open
noahcoolboy opened this issue Nov 7, 2024 · 0 comments

@noahcoolboy

Hello! I've been trying to port a model from PyTorch to ONNX manually, using onnxscript.
I've tried to come up with an elegant way of doing this by creating "custom blocks" with attributes.
However, because of how onnxscript currently works, there are some issues.

This is my current code:

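from onnxscript import FLOAT, script
from onnxscript import opset18 as op  # opset version is an assumption; not stated in the report

# `weights` is assumed to be the original PyTorch model's state dict (name -> torch.Tensor);
# how it is loaded is not shown in this report.

def GConv2D(key: str, kernel_size: int, padding: int):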
    weight = weights[key + ".weight"].numpy()
    bias = weights[key + ".bias"].numpy()

    @script()
    def GConv2D(r: FLOAT[...]):
        r = op.Conv(
            r,
            weight,
            bias,
            kernel_shape=[kernel_size, kernel_size],
            pads=[padding, padding, padding, padding],
        )

        return r

    return GConv2D

def GroupResBlock(key: str, in_dim: int, out_dim: int):
    # Identity() is a pass-through block defined elsewhere (not shown in this report).
    downsample = GConv2D(key + ".downsample", 1, 0) if in_dim != out_dim else Identity()
    conv1 = GConv2D(key + ".conv1", 3, 1)
    conv2 = GConv2D(key + ".conv2", 3, 1)

    @script()
    def GroupResBlock(x: FLOAT[...]):
        x = conv1(op.Relu(x))
        x = conv2(op.Relu(x))
        x = downsample(x)
        return x
    
    return GroupResBlock

def MaskDecoderBlock(key: str):
    up_16_8 = GroupResBlock(key + ".up_16_8.out_conv", 256, 128)
    up_8_4 = GroupResBlock(key + ".up_8_4.out_conv", 128, 128)

    @script()
    def MaskDecoderBlock(x: FLOAT[...]):
        x = up_16_8(x)
        x = up_8_4(x)
        return x

    return MaskDecoderBlock

model = MaskDecoderBlock("mask_decoder").to_model_proto()

"downsample" in GroupResBlock is set conditionally: it should only downsample when in_dim and out_dim differ. To keep this if statement out of the model itself, the check runs in Python beforehand, so the chosen block is baked into the model as is.

The problem is that when up_16_8 is created, the function GroupResBlock is defined with the downsample block. When up_8_4 is created afterwards, a function named GroupResBlock already exists, so that definition is reused, downsample block and wrong weights included!
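The reuse described here behaves like a registry of functions keyed only by name. The sketch below is a pure-Python toy model of that behaviour, not onnxscript's actual implementation: `FunctionRegistry` and `make_group_res_block` are hypothetical names, and the "first definition wins" rule is an assumption taken from what this report observes. It also shows one possible workaround, giving each generated block a name derived from its weight key, which is likewise an untested assumption about how onnxscript would treat distinct names.

```python
from typing import Callable, Dict


class FunctionRegistry:
    """Toy stand-in for a model's local functions, keyed by name only (hypothetical)."""

    def __init__(self) -> None:
        self.functions: Dict[str, Callable[[], str]] = {}

    def register(self, name: str, body: Callable[[], str]) -> Callable[[], str]:
        # Keep the first body registered under `name`; later bodies with the same
        # name are silently dropped, mirroring the behaviour reported in the issue.
        return self.functions.setdefault(name, body)


def make_group_res_block(registry: FunctionRegistry, key: str, downsample: bool,
                         unique_name: bool) -> Callable[[], str]:
    # The toy body only reports which weights and layers it would use.
    def body() -> str:
        return f"{key} (downsample={downsample})"

    # Hypothetical workaround: derive the function name from the weight key.
    name = ("GroupResBlock_" + key.replace(".", "_")) if unique_name else "GroupResBlock"
    return registry.register(name, body)


shared = FunctionRegistry()
make_group_res_block(shared, "up_16_8", True, unique_name=False)
up_8_4 = make_group_res_block(shared, "up_8_4", False, unique_name=False)
print(up_8_4())  # prints "up_16_8 (downsample=True)", i.e. the wrong block

separate = FunctionRegistry()
make_group_res_block(separate, "up_16_8", True, unique_name=True)
up_8_4 = make_group_res_block(separate, "up_8_4", False, unique_name=True)
print(up_8_4())  # prints "up_8_4 (downsample=False)"
```

With a shared name, the second block silently picks up the first block's body; with per-key names, each block keeps its own.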

Is there a way to generate a model proto without functions, so that blocks are not reused and a flat graph is generated instead?
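A flat graph is what you get when every function call is replaced by a copy of the function's body, with the body's internal tensor names made unique per call. The toy pass below illustrates that idea on a made-up node representation; `Node`, `Function` and `inline` are hypothetical and are not onnx or onnxscript types. It relies on one assumption worth stating: each call site carries its own body or inputs (here, its own weight tensor). Inlining a model whose GroupResBlock definitions had already been merged by name would just copy the wrong body into every call site.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class Node:
    op: str                   # primitive op ("Relu", "Conv") or a local function name
    inputs: Tuple[str, ...]
    outputs: Tuple[str, ...]


@dataclass(frozen=True)
class Function:
    inputs: Tuple[str, ...]   # formal input names
    outputs: Tuple[str, ...]  # formal output names
    nodes: Tuple[Node, ...]


def inline(nodes: List[Node], functions: Dict[str, Function], prefix: str = "") -> List[Node]:
    """Recursively replace every call to a local function with a copy of its body."""
    flat: List[Node] = []
    for i, node in enumerate(nodes):
        fn = functions.get(node.op)
        if fn is None:
            flat.append(node)  # primitive op: keep as-is
            continue
        scope = f"{prefix}{node.op}_{i}/"
        # Formal parameters map to the caller's tensor names; every other tensor
        # inside the body gets a per-call prefix so two calls never clash.
        rename = dict(zip(fn.inputs, node.inputs))
        rename.update(zip(fn.outputs, node.outputs))
        body = [
            Node(n.op,
                 tuple(rename.get(t, scope + t) for t in n.inputs),
                 tuple(rename.get(t, scope + t) for t in n.outputs))
            for n in fn.nodes
        ]
        flat.extend(inline(body, functions, scope))
    return flat


# Hypothetical block: Relu followed by Conv, with the weight passed in explicitly.
block = Function(
    inputs=("x", "w"),
    outputs=("y",),
    nodes=(Node("Relu", ("x",), ("r",)), Node("Conv", ("r", "w"), ("y",))),
)
model = [
    Node("Block", ("input", "w_up_16_8"), ("h",)),
    Node("Block", ("h", "w_up_8_4"), ("output",)),
]
for n in inline(model, {"Block": block}):
    print(n.op, n.inputs, "->", n.outputs)
```

Running it prints four primitive nodes (Relu, Conv, Relu, Conv), and each Conv reads its own weight tensor, so nothing is shared between the two calls.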
