Originally posted by Arbiter-glitch May 22, 2024
I am a research student working on image processing. While using FINN, I was expecting the accuracy to be maintained between the trained Brevitas model and the FINN ONNX model after streamlining. But even though I am getting an output, there is a considerable drop in the image quality metrics. Is this drop in accuracy expected, or should the results be the same? My files have been posted in earlier closed issues, here and here.
[Update:] I have found that the output tensor already differs from the first step, the QONNX-to-FINN conversion, onwards. I do not get the same image quality from ONNX execution as from the quantized Brevitas software model. Is this normal? I am seeing a 5 dB drop in PSNR between the software output and the FINN output.
Is it because the quantizers I am using in Brevitas are not compatible with FINN? I took this model file from the Brevitas examples and modified it accordingly.
As a separate case, I also tried using bias quantization, since I thought that might be needed for accurate output in the FINN flow. But when I exported it to FINN and ran the ONNX execution, the outputs were all zeros, or they gradually became zeros through the layers. My model files are below. Maybe the weights are not accurate after FoldQuantWeights()?
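For reference, a minimal sketch of the kind of end-to-end comparison in question, using the FSRCNN model defined below, the standard Brevitas QONNX export, and FINN's ConvertQONNXtoFINN transform (the file names, scale factor, input size, and module name are placeholders, not the exact training setup):

import numpy as np
import torch
from brevitas.export import export_qonnx
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.util.cleanup import cleanup as qonnx_cleanup
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN
import finn.core.onnx_exec as oxe

from model import FSRCNN  # the Model.py file below; adjust to the actual module name


def psnr(ref, test, max_val=255.0):
    # Peak signal-to-noise ratio between two arrays of the same shape
    mse = np.mean((ref.astype(np.float64) - test.astype(np.float64)) ** 2)
    return 20.0 * np.log10(max_val) - 10.0 * np.log10(mse)


model = FSRCNN(scale_factor=2)  # placeholder scale factor
model.eval()
# model.load_state_dict(torch.load("fsrcnn_trained.pth"))  # trained weights go here

dummy = torch.randn(1, 1, 64, 64)  # placeholder input size

# Reference output from the quantized Brevitas model in software
with torch.no_grad():
    ref_out = model(dummy).numpy()

# Export to QONNX, clean up, then convert to FINN's internal ONNX dialect
export_qonnx(model, dummy, export_path="fsrcnn_qonnx.onnx")
qonnx_cleanup("fsrcnn_qonnx.onnx", out_file="fsrcnn_qonnx_clean.onnx")
finn_model = ModelWrapper("fsrcnn_qonnx_clean.onnx")
finn_model = finn_model.transform(ConvertQONNXtoFINN())
finn_model.save("fsrcnn_finn.onnx")

# Execute the converted model on the same input and compare
inp_name = finn_model.graph.input[0].name
out_name = finn_model.graph.output[0].name
finn_out = oxe.execute_onnx(finn_model, {inp_name: dummy.numpy()})[out_name]

print("PSNR(Brevitas vs FINN ONNX):", psnr(ref_out, finn_out))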
Model.py
import torch
import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFloat
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8Bias
from brevitas.nn.quant_layer import WeightQuantType
from common import CommonIntAccumulatorAwareWeightQuant
from common import CommonIntWeightPerChannelQuant
from common import CommonIntActQuant
from common import CommonUintActQuant
from common import ConstUint8ActQuant
from common import QuantNearestNeighborConvolution
from torch import nn
import math

IO_DATA_BIT_WIDTH = 8
IO_ACC_BIT_WIDTH = 32


class FSRCNN(nn.Module):
    def __init__(self,
                 scale_factor,
                 num_channels=1,
                 d=16,
                 s=12,
                 m=4,
                 weight_bit_width: int = 8,
                 act_bit_width: int = 8,
                 acc_bit_width: int = 32,
                 weight_quant: WeightQuantType = CommonIntAccumulatorAwareWeightQuant):
        super(FSRCNN, self).__init__()
        self.first_part = nn.Sequential(
            qnn.QuantConv2d(
                num_channels, d, kernel_size=5, padding=5 // 2,
                input_quant=CommonIntActQuant,
                input_bit_width=12,
                weight_bit_width=weight_bit_width,
                weight_accumulator_bit_width=acc_bit_width,
                weight_quant=weight_quant,
                bias=True),
            nn.ReLU(inplace=True))
        self.mid_part = [
            qnn.QuantConv2d(
                d, s, kernel_size=1,
                input_quant=CommonUintActQuant,
                input_bit_width=act_bit_width,
                weight_bit_width=weight_bit_width,
                weight_accumulator_bit_width=acc_bit_width,
                weight_quant=weight_quant,
                bias=True),
            nn.ReLU(inplace=True)]
        for _ in range(m):
            self.mid_part.extend([
                qnn.QuantConv2d(
                    s, s, kernel_size=3, padding=3 // 2,
                    input_quant=CommonUintActQuant,
                    input_bit_width=act_bit_width,
                    weight_bit_width=weight_bit_width,
                    weight_accumulator_bit_width=acc_bit_width,
                    weight_quant=weight_quant,
                    bias=True),
                nn.ReLU(inplace=True)])
        self.mid_part.extend([
            qnn.QuantConv2d(
                s, d, kernel_size=1,
                input_quant=CommonUintActQuant,
                input_bit_width=act_bit_width,
                weight_bit_width=weight_bit_width,
                weight_accumulator_bit_width=acc_bit_width,
                weight_quant=weight_quant,
                bias=True),
            nn.ReLU(inplace=True)])
        self.mid_part = nn.Sequential(*self.mid_part)
        self.upsample = QuantNearestNeighborConvolution(
            d, 1, kernel_size=3, stride=1, padding=3 // 2, upscale_factor=scale_factor)
        self.relu = nn.ReLU(inplace=True)
        # Using a QuantReLU here because we need to read out a uint8 image, but FINN
        # requires a ReLU node to precede an unsigned int quant node
        # self.out = qnn.QuantReLU(act_quant=ConstUint8ActQuant, bit_width=IO_DATA_BIT_WIDTH)
        self.out = qnn.QuantIdentity(act_quant=ConstUint8ActQuant, return_quant_tensor=False, bit_width=8)

    def forward(self, x):
        x = self.first_part(x)
        x = self.mid_part(x)
        x = self.upsample(x)
        x = self.relu(x)
        x = self.out(x)
        return x
Common.py
from typing import Optional

from torch import Tensor
import torch.nn as nn

from brevitas.core.restrict_val import FloatRestrictValue
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
import brevitas.nn as qnn
from brevitas.nn.quant_layer import WeightQuantType
from brevitas.quant import Int8AccumulatorAwareWeightQuant
from brevitas.quant import Int8AccumulatorAwareZeroCenterWeightQuant
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant import Int8WeightPerTensorFloat
from brevitas.quant import Uint8ActPerTensorFloat


class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):
    """
    Common per-channel weight quantizer with bit-width set to None so that it's forced to be
    specified by each layer.
    """
    scaling_per_output_channel = True


class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
    """A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance"""
    restrict_scaling_impl = FloatRestrictValue  # backwards compatibility
    bit_width = None


class CommonIntAccumulatorAwareZeroCenterWeightQuant(Int8AccumulatorAwareZeroCenterWeightQuant):
    """A2Q+: Improving Accumulator-Aware Weight Quantization"""
    bit_width = None


class CommonIntActQuant(Int8ActPerTensorFloat):
    """
    Common signed act quantizer with bit-width set to None so that it's forced to be specified by
    each layer.
    """
    bit_width = None
    restrict_scaling_type = RestrictValueType.LOG_FP


class CommonUintActQuant(Uint8ActPerTensorFloat):
    """
    Common unsigned act quantizer with bit-width set to None so that it's forced to be
    specified by each layer.
    """
    bit_width = None
    restrict_scaling_type = RestrictValueType.LOG_FP


class ConstUint8ActQuant(CommonUintActQuant):
    """
    8-bit unsigned integer activation quantizer with constant unit scaling factor, used
    by the models to quantize outputs into the image space.
    """
    scaling_impl_type = ScalingImplType.CONST
    scaling_init = 1.


class QuantNearestNeighborConvolution(nn.Module):
    """Quantized nearest neighbor resize convolution"""

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Optional[int] = 5,
            stride: Optional[int] = 1,
            padding: Optional[int] = 0,
            upscale_factor: Optional[int] = 2,
            signed_act: Optional[bool] = False,
            bias: Optional[bool] = True,
            weight_quant: WeightQuantType = CommonIntWeightPerChannelQuant,
            acc_bit_width: Optional[int] = 32,
            act_bit_width: Optional[int] = 8,
            weight_bit_width: Optional[int] = 8):
        super().__init__()
        # Using unsigned int activation quantization if the preceding layer has
        # a non-negative range (e.g., following a ReLU activation function)
        act_quant = CommonIntActQuant if signed_act else CommonUintActQuant
        self.upscale_factor = upscale_factor
        # Need to have the quantization node before the nearest neighbor upsampling node
        # for FINN compatibility since the FINN compiler will streamline the quantization
        # node with the preceding monotonic activation function. In the case of ESPCN, this
        # is a ReLU. We need to return the QuantTensor though so that the conv2d is aware
        # of the input bit-width for accumulator-aware quantization (A2Q). For more discussion
        # on this, see https://arxiv.org/abs/2301.13376.
        self.input_quant = qnn.QuantIdentity(
            act_quant=act_quant, return_quant_tensor=True, bit_width=act_bit_width)
        self.interp = qnn.QuantUpsample(scale_factor=upscale_factor)
        self.conv = qnn.QuantConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            stride=stride,
            bias=bias,
            input_quant=None,
            weight_accumulator_bit_width=acc_bit_width,
            weight_bit_width=weight_bit_width,
            weight_quant=weight_quant)

    def forward(self, inp: Tensor) -> Tensor:
        return self.conv(self.interp(self.input_quant(inp)))
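Regarding the outputs that gradually become zero through the layers: a minimal sketch of how the intermediate tensors could be inspected after conversion, using the full execution context that FINN's executor can return (the model file name is a placeholder and refers to the converted model saved in the sketch above):

import numpy as np
from qonnx.core.modelwrapper import ModelWrapper
import finn.core.onnx_exec as oxe

# Assumes the model after ConvertQONNXtoFINN was saved as fsrcnn_finn.onnx (placeholder name)
finn_model = ModelWrapper("fsrcnn_finn.onnx")
dummy = np.random.randn(1, 1, 64, 64).astype(np.float32)  # placeholder input size

inp_name = finn_model.graph.input[0].name
# return_full_exec_context=True keeps every intermediate tensor of the execution
ctx = oxe.execute_onnx(finn_model, {inp_name: dummy}, return_full_exec_context=True)

# Walk the graph in order and report where activations collapse to zero
for node in finn_model.graph.node:
    for out in node.output:
        t = ctx.get(out)
        if t is None:
            continue
        frac_zero = float(np.mean(t == 0))
        print(f"{node.op_type:20s} {out:30s} zeros={frac_zero:.3f} "
              f"min={t.min():.4g} max={t.max():.4g}")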
Discussed in #1086