@@ -74,12 +74,9 @@ class WindowAttention(nn.Module):
     """
 
     def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
-                 pretrained_window_size=None):
+                 pretrained_window_size=(0, 0)):
 
         super().__init__()
-
-        pretrained_window_size = pretrained_window_size or [0, 0]
-
         self.dim = dim
         self.window_size = window_size  # Wh, Ww
         self.pretrained_window_size = pretrained_window_size
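For illustration only, outside the patch: a minimal, self-contained sketch (hypothetical function names) of why an immutable tuple default such as (0, 0) can replace the None-plus-fallback pattern, whereas a mutable list default would be risky.

def with_list_default(sizes=[0, 0]):
    sizes[0] += 1                        # mutates the one shared default list
    return sizes

def with_tuple_default(sizes=(0, 0)):
    return (sizes[0] + 1, sizes[1])      # tuples cannot be mutated in place

print(with_list_default())   # [1, 0]
print(with_list_default())   # [2, 0]  <- state leaked between calls
print(with_tuple_default())  # (1, 0)
print(with_tuple_default())  # (1, 0)  <- no shared state, no runtime fallback needed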
@@ -701,17 +698,13 @@ class Swin2SR(nn.Module):
     """
 
    def __init__(self, img_size=64, patch_size=1, in_chans=3,
-                 embed_dim=96, depths=None, num_heads=None,
+                 embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                  window_size=7, mlp_ratio=4., qkv_bias=True,
                  drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                  norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                  use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                  **kwargs):
         super(Swin2SR, self).__init__()
-
-        depths = depths or [6, 6, 6, 6]
-        num_heads = num_heads or [6, 6, 6, 6]
-
         num_in_ch = in_chans
         num_out_ch = in_chans
         num_feat = 64
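Likewise for illustration only, with assumed (not patch-provided) usage: read-only consumers of depths and num_heads, such as the layer count or a stochastic-depth schedule, behave identically with tuple defaults and with the old list fallback.

import torch

depths = (6, 6, 6, 6)                                          # tuple defaults from the signature above
num_heads = (6, 6, 6, 6)

num_layers = len(depths)                                       # len() and indexing work on tuples
dpr = [x.item() for x in torch.linspace(0, 0.1, sum(depths))]  # drop-path schedule sketch
per_layer = list(zip(depths, num_heads))                       # call list() explicitly if a mutable copy is needed

print(num_layers, len(dpr), per_layer[0])                      # 4 24 (6, 6)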