
class Swish(nn.Module):
    """Swish activation: f(x) = x * sigmoid(x).

    Keeps a sigmoid submodule so the activation composes cleanly
    inside ``nn.Sequential`` stacks.
    """

    def __init__(self):
        super().__init__()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Element-wise gate: the input scaled by its own sigmoid.
        return self.sigmoid(x) * x



def _DropPath(x, drop_prob, training): 

if drop_prob > 0 and training: 

keep_prob = 1 – drop_prob 

if x.is_cuda: 

mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) 

else: 

mask = Variable(torch.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) 

x.div_(keep_prob) 

x.mul_(mask) 

return x 



def _LayerNorm(channels): 

return nn.GroupNorm(channels, channels) #as groups = number of channels, group norm. is layer norm. 



def _Conv3x3Bn(in_channels, out_channels, stride):
    """3x3 conv -> per-channel norm -> Swish, as a single sequential unit."""
    layers = [
        OptConv2d(in_channels, out_channels, 3, stride=stride, padding=1,
                  bias=False, perfect_gradient=True),
        _LayerNorm(out_channels),
        Swish(),
    ]
    return nn.Sequential(*layers)



def _Conv1x1Bn(in_channels, out_channels):
    """1x1 (pointwise) conv -> per-channel norm -> Swish."""
    layers = [
        OptConv2d(in_channels, out_channels, kernel_size=1, stride=1,
                  padding=0, bias=False, perfect_gradient=True),
        _LayerNorm(out_channels),
        Swish(),
    ]
    return nn.Sequential(*layers)



class SqueezeAndExcite(nn.Module):
    """Squeeze-and-excitation channel attention.

    Globally average-pools the input, reduces to
    ``squeeze_channels * se_ratio`` channels, re-expands to ``channels``,
    and gates the input with the resulting sigmoid weights.

    Args:
        channels: number of channels of the tensor being gated.
        squeeze_channels: base channel count the ratio is applied to
            (typically the block's input channels).
        se_ratio: squeeze ratio; ``squeeze_channels * se_ratio`` must be
            a whole number.

    Raises:
        ValueError: if ``squeeze_channels * se_ratio`` is not an integer.
    """

    def __init__(self, channels, squeeze_channels, se_ratio):
        super(SqueezeAndExcite, self).__init__()

        squeeze_channels = squeeze_channels * se_ratio
        # The original called float.is_integer(), which raises
        # AttributeError when se_ratio is an int (int.is_integer needs
        # Python 3.12+); a modulo check works for both int and float.
        if squeeze_channels % 1 != 0:
            raise ValueError('channels must be divisible by 1/ratio')

        squeeze_channels = int(squeeze_channels)
        self.se_reduce = OptConv2d(channels, squeeze_channels, kernel_size=1, stride=1, padding=0, bias=True, perfect_gradient=True)
        self.non_linear1 = Swish()
        self.se_expand = OptConv2d(squeeze_channels, channels, kernel_size=1, stride=1, padding=0, bias=True, perfect_gradient=True)
        self.non_linear2 = nn.Sigmoid()

    def forward(self, x):
        # Squeeze: global average pool over the spatial dimensions.
        y = torch.mean(x, (2, 3), keepdim=True)
        # Excite: bottleneck MLP (as 1x1 convs) producing per-channel gates.
        y = self.non_linear1(self.se_reduce(y))
        y = self.non_linear2(self.se_expand(y))
        # Re-weight the input channels.
        return x * y



class MBConvBlock(nn.Module):
    """Mobile inverted bottleneck (MBConv) block.

    Pipeline: optional 1x1 expansion -> depthwise conv -> optional
    squeeze-and-excitation -> 1x1 linear projection. When stride is 1 and
    the channel counts match, the input is added back through a
    drop-path-regularized residual connection.

    Args:
        in_channels: channels of the block input.
        out_channels: channels of the block output.
        kernel_size: depthwise convolution kernel size (odd).
        stride: depthwise convolution stride.
        expand_ratio: channel expansion factor; 1 skips the expansion conv.
        se_ratio: squeeze-excite ratio; 0.0 disables the SE stage.
        drop_path_rate: stochastic-depth probability for the residual branch.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, expand_ratio, se_ratio, drop_path_rate):
        super(MBConvBlock, self).__init__()

        mid_channels = in_channels * expand_ratio
        self.residual_connection = (stride == 1 and in_channels == out_channels)
        self.drop_path_rate = drop_path_rate

        stages = []

        # Expansion phase: 1x1 conv widening the channels (skipped when
        # expand_ratio == 1, since it would be the identity width).
        if expand_ratio != 1:
            stages.append(nn.Sequential(
                OptConv2d(in_channels, mid_channels, kernel_size=1, stride=1,
                          padding=0, bias=False, perfect_gradient=True),
                _LayerNorm(mid_channels),
                Swish(),
            ))

        # Depthwise convolution phase: groups == channels makes it depthwise;
        # padding = kernel_size // 2 keeps spatial size (modulo stride).
        stages.append(nn.Sequential(
            OptConv2d(mid_channels, mid_channels, kernel_size=kernel_size,
                      stride=stride, padding=kernel_size // 2, bias=False,
                      perfect_gradient=True, groups=mid_channels),
            _LayerNorm(mid_channels),
            Swish(),
        ))

        # Optional squeeze-and-excitation on the expanded representation.
        if se_ratio != 0.0:
            stages.append(SqueezeAndExcite(mid_channels, in_channels, se_ratio))

        # Projection phase: linear 1x1 conv back down (no activation).
        stages.append(nn.Sequential(
            OptConv2d(mid_channels, out_channels, kernel_size=1, stride=1,
                      padding=0, bias=False, perfect_gradient=True),
            _LayerNorm(out_channels),
        ))

        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.conv(x)
        if not self.residual_connection:
            return out
        # Residual path with stochastic depth applied to the branch output.
        return x + _DropPath(out, self.drop_path_rate, self.training)