swinir_model_arch.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867
  1. # -----------------------------------------------------------------------------------
  2. # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
  3. # Originally Written by Ze Liu, Modified by Jingyun Liang.
  4. # -----------------------------------------------------------------------------------
  5. import math
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. import torch.utils.checkpoint as checkpoint
  10. from timm.models.layers import DropPath, to_2tuple, trunc_normal_
  11. class Mlp(nn.Module):
  12. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
  13. super().__init__()
  14. out_features = out_features or in_features
  15. hidden_features = hidden_features or in_features
  16. self.fc1 = nn.Linear(in_features, hidden_features)
  17. self.act = act_layer()
  18. self.fc2 = nn.Linear(hidden_features, out_features)
  19. self.drop = nn.Dropout(drop)
  20. def forward(self, x):
  21. x = self.fc1(x)
  22. x = self.act(x)
  23. x = self.drop(x)
  24. x = self.fc2(x)
  25. x = self.drop(x)
  26. return x
  27. def window_partition(x, window_size):
  28. """
  29. Args:
  30. x: (B, H, W, C)
  31. window_size (int): window size
  32. Returns:
  33. windows: (num_windows*B, window_size, window_size, C)
  34. """
  35. B, H, W, C = x.shape
  36. x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
  37. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  38. return windows
  39. def window_reverse(windows, window_size, H, W):
  40. """
  41. Args:
  42. windows: (num_windows*B, window_size, window_size, C)
  43. window_size (int): Window size
  44. H (int): Height of image
  45. W (int): Width of image
  46. Returns:
  47. x: (B, H, W, C)
  48. """
  49. B = int(windows.shape[0] / (H * W / window_size / window_size))
  50. x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
  51. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
  52. return x
  53. class WindowAttention(nn.Module):
  54. r""" Window based multi-head self attention (W-MSA) module with relative position bias.
  55. It supports both of shifted and non-shifted window.
  56. Args:
  57. dim (int): Number of input channels.
  58. window_size (tuple[int]): The height and width of the window.
  59. num_heads (int): Number of attention heads.
  60. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  61. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
  62. attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
  63. proj_drop (float, optional): Dropout ratio of output. Default: 0.0
  64. """
  65. def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
  66. super().__init__()
  67. self.dim = dim
  68. self.window_size = window_size # Wh, Ww
  69. self.num_heads = num_heads
  70. head_dim = dim // num_heads
  71. self.scale = qk_scale or head_dim ** -0.5
  72. # define a parameter table of relative position bias
  73. self.relative_position_bias_table = nn.Parameter(
  74. torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
  75. # get pair-wise relative position index for each token inside the window
  76. coords_h = torch.arange(self.window_size[0])
  77. coords_w = torch.arange(self.window_size[1])
  78. coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
  79. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  80. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
  81. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
  82. relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
  83. relative_coords[:, :, 1] += self.window_size[1] - 1
  84. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  85. relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  86. self.register_buffer("relative_position_index", relative_position_index)
  87. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  88. self.attn_drop = nn.Dropout(attn_drop)
  89. self.proj = nn.Linear(dim, dim)
  90. self.proj_drop = nn.Dropout(proj_drop)
  91. trunc_normal_(self.relative_position_bias_table, std=.02)
  92. self.softmax = nn.Softmax(dim=-1)
  93. def forward(self, x, mask=None):
  94. """
  95. Args:
  96. x: input features with shape of (num_windows*B, N, C)
  97. mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
  98. """
  99. B_, N, C = x.shape
  100. qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  101. q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
  102. q = q * self.scale
  103. attn = (q @ k.transpose(-2, -1))
  104. relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
  105. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
  106. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  107. attn = attn + relative_position_bias.unsqueeze(0)
  108. if mask is not None:
  109. nW = mask.shape[0]
  110. attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
  111. attn = attn.view(-1, self.num_heads, N, N)
  112. attn = self.softmax(attn)
  113. else:
  114. attn = self.softmax(attn)
  115. attn = self.attn_drop(attn)
  116. x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
  117. x = self.proj(x)
  118. x = self.proj_drop(x)
  119. return x
  120. def extra_repr(self) -> str:
  121. return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
  122. def flops(self, N):
  123. # calculate flops for 1 window with token length of N
  124. flops = 0
  125. # qkv = self.qkv(x)
  126. flops += N * self.dim * 3 * self.dim
  127. # attn = (q @ k.transpose(-2, -1))
  128. flops += self.num_heads * N * (self.dim // self.num_heads) * N
  129. # x = (attn @ v)
  130. flops += self.num_heads * N * N * (self.dim // self.num_heads)
  131. # x = self.proj(x)
  132. flops += N * self.dim * self.dim
  133. return flops
  134. class SwinTransformerBlock(nn.Module):
  135. r""" Swin Transformer Block.
  136. Args:
  137. dim (int): Number of input channels.
  138. input_resolution (tuple[int]): Input resolution.
  139. num_heads (int): Number of attention heads.
  140. window_size (int): Window size.
  141. shift_size (int): Shift size for SW-MSA.
  142. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  143. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  144. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  145. drop (float, optional): Dropout rate. Default: 0.0
  146. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  147. drop_path (float, optional): Stochastic depth rate. Default: 0.0
  148. act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
  149. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  150. """
  151. def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
  152. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
  153. act_layer=nn.GELU, norm_layer=nn.LayerNorm):
  154. super().__init__()
  155. self.dim = dim
  156. self.input_resolution = input_resolution
  157. self.num_heads = num_heads
  158. self.window_size = window_size
  159. self.shift_size = shift_size
  160. self.mlp_ratio = mlp_ratio
  161. if min(self.input_resolution) <= self.window_size:
  162. # if window size is larger than input resolution, we don't partition windows
  163. self.shift_size = 0
  164. self.window_size = min(self.input_resolution)
  165. assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
  166. self.norm1 = norm_layer(dim)
  167. self.attn = WindowAttention(
  168. dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
  169. qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
  170. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  171. self.norm2 = norm_layer(dim)
  172. mlp_hidden_dim = int(dim * mlp_ratio)
  173. self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
  174. if self.shift_size > 0:
  175. attn_mask = self.calculate_mask(self.input_resolution)
  176. else:
  177. attn_mask = None
  178. self.register_buffer("attn_mask", attn_mask)
  179. def calculate_mask(self, x_size):
  180. # calculate attention mask for SW-MSA
  181. H, W = x_size
  182. img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
  183. h_slices = (slice(0, -self.window_size),
  184. slice(-self.window_size, -self.shift_size),
  185. slice(-self.shift_size, None))
  186. w_slices = (slice(0, -self.window_size),
  187. slice(-self.window_size, -self.shift_size),
  188. slice(-self.shift_size, None))
  189. cnt = 0
  190. for h in h_slices:
  191. for w in w_slices:
  192. img_mask[:, h, w, :] = cnt
  193. cnt += 1
  194. mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
  195. mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
  196. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  197. attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
  198. return attn_mask
  199. def forward(self, x, x_size):
  200. H, W = x_size
  201. B, L, C = x.shape
  202. # assert L == H * W, "input feature has wrong size"
  203. shortcut = x
  204. x = self.norm1(x)
  205. x = x.view(B, H, W, C)
  206. # cyclic shift
  207. if self.shift_size > 0:
  208. shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
  209. else:
  210. shifted_x = x
  211. # partition windows
  212. x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
  213. x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
  214. # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
  215. if self.input_resolution == x_size:
  216. attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
  217. else:
  218. attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
  219. # merge windows
  220. attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
  221. shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
  222. # reverse cyclic shift
  223. if self.shift_size > 0:
  224. x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
  225. else:
  226. x = shifted_x
  227. x = x.view(B, H * W, C)
  228. # FFN
  229. x = shortcut + self.drop_path(x)
  230. x = x + self.drop_path(self.mlp(self.norm2(x)))
  231. return x
  232. def extra_repr(self) -> str:
  233. return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
  234. f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
  235. def flops(self):
  236. flops = 0
  237. H, W = self.input_resolution
  238. # norm1
  239. flops += self.dim * H * W
  240. # W-MSA/SW-MSA
  241. nW = H * W / self.window_size / self.window_size
  242. flops += nW * self.attn.flops(self.window_size * self.window_size)
  243. # mlp
  244. flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
  245. # norm2
  246. flops += self.dim * H * W
  247. return flops
  248. class PatchMerging(nn.Module):
  249. r""" Patch Merging Layer.
  250. Args:
  251. input_resolution (tuple[int]): Resolution of input feature.
  252. dim (int): Number of input channels.
  253. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  254. """
  255. def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
  256. super().__init__()
  257. self.input_resolution = input_resolution
  258. self.dim = dim
  259. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  260. self.norm = norm_layer(4 * dim)
  261. def forward(self, x):
  262. """
  263. x: B, H*W, C
  264. """
  265. H, W = self.input_resolution
  266. B, L, C = x.shape
  267. assert L == H * W, "input feature has wrong size"
  268. assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
  269. x = x.view(B, H, W, C)
  270. x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
  271. x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
  272. x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
  273. x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
  274. x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
  275. x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
  276. x = self.norm(x)
  277. x = self.reduction(x)
  278. return x
  279. def extra_repr(self) -> str:
  280. return f"input_resolution={self.input_resolution}, dim={self.dim}"
  281. def flops(self):
  282. H, W = self.input_resolution
  283. flops = H * W * self.dim
  284. flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
  285. return flops
  286. class BasicLayer(nn.Module):
  287. """ A basic Swin Transformer layer for one stage.
  288. Args:
  289. dim (int): Number of input channels.
  290. input_resolution (tuple[int]): Input resolution.
  291. depth (int): Number of blocks.
  292. num_heads (int): Number of attention heads.
  293. window_size (int): Local window size.
  294. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  295. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  296. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  297. drop (float, optional): Dropout rate. Default: 0.0
  298. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  299. drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  300. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  301. downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
  302. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
  303. """
  304. def __init__(self, dim, input_resolution, depth, num_heads, window_size,
  305. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
  306. drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
  307. super().__init__()
  308. self.dim = dim
  309. self.input_resolution = input_resolution
  310. self.depth = depth
  311. self.use_checkpoint = use_checkpoint
  312. # build blocks
  313. self.blocks = nn.ModuleList([
  314. SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
  315. num_heads=num_heads, window_size=window_size,
  316. shift_size=0 if (i % 2 == 0) else window_size // 2,
  317. mlp_ratio=mlp_ratio,
  318. qkv_bias=qkv_bias, qk_scale=qk_scale,
  319. drop=drop, attn_drop=attn_drop,
  320. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  321. norm_layer=norm_layer)
  322. for i in range(depth)])
  323. # patch merging layer
  324. if downsample is not None:
  325. self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
  326. else:
  327. self.downsample = None
  328. def forward(self, x, x_size):
  329. for blk in self.blocks:
  330. if self.use_checkpoint:
  331. x = checkpoint.checkpoint(blk, x, x_size)
  332. else:
  333. x = blk(x, x_size)
  334. if self.downsample is not None:
  335. x = self.downsample(x)
  336. return x
  337. def extra_repr(self) -> str:
  338. return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
  339. def flops(self):
  340. flops = 0
  341. for blk in self.blocks:
  342. flops += blk.flops()
  343. if self.downsample is not None:
  344. flops += self.downsample.flops()
  345. return flops
  346. class RSTB(nn.Module):
  347. """Residual Swin Transformer Block (RSTB).
  348. Args:
  349. dim (int): Number of input channels.
  350. input_resolution (tuple[int]): Input resolution.
  351. depth (int): Number of blocks.
  352. num_heads (int): Number of attention heads.
  353. window_size (int): Local window size.
  354. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  355. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  356. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
  357. drop (float, optional): Dropout rate. Default: 0.0
  358. attn_drop (float, optional): Attention dropout rate. Default: 0.0
  359. drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  360. norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
  361. downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
  362. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
  363. img_size: Input image size.
  364. patch_size: Patch size.
  365. resi_connection: The convolutional block before residual connection.
  366. """
  367. def __init__(self, dim, input_resolution, depth, num_heads, window_size,
  368. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
  369. drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
  370. img_size=224, patch_size=4, resi_connection='1conv'):
  371. super(RSTB, self).__init__()
  372. self.dim = dim
  373. self.input_resolution = input_resolution
  374. self.residual_group = BasicLayer(dim=dim,
  375. input_resolution=input_resolution,
  376. depth=depth,
  377. num_heads=num_heads,
  378. window_size=window_size,
  379. mlp_ratio=mlp_ratio,
  380. qkv_bias=qkv_bias, qk_scale=qk_scale,
  381. drop=drop, attn_drop=attn_drop,
  382. drop_path=drop_path,
  383. norm_layer=norm_layer,
  384. downsample=downsample,
  385. use_checkpoint=use_checkpoint)
  386. if resi_connection == '1conv':
  387. self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
  388. elif resi_connection == '3conv':
  389. # to save parameters and memory
  390. self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
  391. nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
  392. nn.LeakyReLU(negative_slope=0.2, inplace=True),
  393. nn.Conv2d(dim // 4, dim, 3, 1, 1))
  394. self.patch_embed = PatchEmbed(
  395. img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
  396. norm_layer=None)
  397. self.patch_unembed = PatchUnEmbed(
  398. img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
  399. norm_layer=None)
  400. def forward(self, x, x_size):
  401. return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
  402. def flops(self):
  403. flops = 0
  404. flops += self.residual_group.flops()
  405. H, W = self.input_resolution
  406. flops += H * W * self.dim * self.dim * 9
  407. flops += self.patch_embed.flops()
  408. flops += self.patch_unembed.flops()
  409. return flops
  410. class PatchEmbed(nn.Module):
  411. r""" Image to Patch Embedding
  412. Args:
  413. img_size (int): Image size. Default: 224.
  414. patch_size (int): Patch token size. Default: 4.
  415. in_chans (int): Number of input image channels. Default: 3.
  416. embed_dim (int): Number of linear projection output channels. Default: 96.
  417. norm_layer (nn.Module, optional): Normalization layer. Default: None
  418. """
  419. def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
  420. super().__init__()
  421. img_size = to_2tuple(img_size)
  422. patch_size = to_2tuple(patch_size)
  423. patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
  424. self.img_size = img_size
  425. self.patch_size = patch_size
  426. self.patches_resolution = patches_resolution
  427. self.num_patches = patches_resolution[0] * patches_resolution[1]
  428. self.in_chans = in_chans
  429. self.embed_dim = embed_dim
  430. if norm_layer is not None:
  431. self.norm = norm_layer(embed_dim)
  432. else:
  433. self.norm = None
  434. def forward(self, x):
  435. x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
  436. if self.norm is not None:
  437. x = self.norm(x)
  438. return x
  439. def flops(self):
  440. flops = 0
  441. H, W = self.img_size
  442. if self.norm is not None:
  443. flops += H * W * self.embed_dim
  444. return flops
  445. class PatchUnEmbed(nn.Module):
  446. r""" Image to Patch Unembedding
  447. Args:
  448. img_size (int): Image size. Default: 224.
  449. patch_size (int): Patch token size. Default: 4.
  450. in_chans (int): Number of input image channels. Default: 3.
  451. embed_dim (int): Number of linear projection output channels. Default: 96.
  452. norm_layer (nn.Module, optional): Normalization layer. Default: None
  453. """
  454. def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
  455. super().__init__()
  456. img_size = to_2tuple(img_size)
  457. patch_size = to_2tuple(patch_size)
  458. patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
  459. self.img_size = img_size
  460. self.patch_size = patch_size
  461. self.patches_resolution = patches_resolution
  462. self.num_patches = patches_resolution[0] * patches_resolution[1]
  463. self.in_chans = in_chans
  464. self.embed_dim = embed_dim
  465. def forward(self, x, x_size):
  466. B, HW, C = x.shape
  467. x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
  468. return x
  469. def flops(self):
  470. flops = 0
  471. return flops
  472. class Upsample(nn.Sequential):
  473. """Upsample module.
  474. Args:
  475. scale (int): Scale factor. Supported scales: 2^n and 3.
  476. num_feat (int): Channel number of intermediate features.
  477. """
  478. def __init__(self, scale, num_feat):
  479. m = []
  480. if (scale & (scale - 1)) == 0: # scale = 2^n
  481. for _ in range(int(math.log(scale, 2))):
  482. m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
  483. m.append(nn.PixelShuffle(2))
  484. elif scale == 3:
  485. m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
  486. m.append(nn.PixelShuffle(3))
  487. else:
  488. raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
  489. super(Upsample, self).__init__(*m)
  490. class UpsampleOneStep(nn.Sequential):
  491. """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
  492. Used in lightweight SR to save parameters.
  493. Args:
  494. scale (int): Scale factor. Supported scales: 2^n and 3.
  495. num_feat (int): Channel number of intermediate features.
  496. """
  497. def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
  498. self.num_feat = num_feat
  499. self.input_resolution = input_resolution
  500. m = []
  501. m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
  502. m.append(nn.PixelShuffle(scale))
  503. super(UpsampleOneStep, self).__init__(*m)
  504. def flops(self):
  505. H, W = self.input_resolution
  506. flops = H * W * self.num_feat * 3 * 9
  507. return flops
  508. class SwinIR(nn.Module):
  509. r""" SwinIR
  510. A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
  511. Args:
  512. img_size (int | tuple(int)): Input image size. Default 64
  513. patch_size (int | tuple(int)): Patch size. Default: 1
  514. in_chans (int): Number of input image channels. Default: 3
  515. embed_dim (int): Patch embedding dimension. Default: 96
  516. depths (tuple(int)): Depth of each Swin Transformer layer.
  517. num_heads (tuple(int)): Number of attention heads in different layers.
  518. window_size (int): Window size. Default: 7
  519. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
  520. qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
  521. qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
  522. drop_rate (float): Dropout rate. Default: 0
  523. attn_drop_rate (float): Attention dropout rate. Default: 0
  524. drop_path_rate (float): Stochastic depth rate. Default: 0.1
  525. norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
  526. ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
  527. patch_norm (bool): If True, add normalization after patch embedding. Default: True
  528. use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
  529. upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
  530. img_range: Image range. 1. or 255.
  531. upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
  532. resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
  533. """
  534. def __init__(self, img_size=64, patch_size=1, in_chans=3,
  535. embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
  536. window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
  537. drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
  538. norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
  539. use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
  540. **kwargs):
  541. super(SwinIR, self).__init__()
  542. num_in_ch = in_chans
  543. num_out_ch = in_chans
  544. num_feat = 64
  545. self.img_range = img_range
  546. if in_chans == 3:
  547. rgb_mean = (0.4488, 0.4371, 0.4040)
  548. self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
  549. else:
  550. self.mean = torch.zeros(1, 1, 1, 1)
  551. self.upscale = upscale
  552. self.upsampler = upsampler
  553. self.window_size = window_size
  554. #####################################################################################################
  555. ################################### 1, shallow feature extraction ###################################
  556. self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
  557. #####################################################################################################
  558. ################################### 2, deep feature extraction ######################################
  559. self.num_layers = len(depths)
  560. self.embed_dim = embed_dim
  561. self.ape = ape
  562. self.patch_norm = patch_norm
  563. self.num_features = embed_dim
  564. self.mlp_ratio = mlp_ratio
  565. # split image into non-overlapping patches
  566. self.patch_embed = PatchEmbed(
  567. img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
  568. norm_layer=norm_layer if self.patch_norm else None)
  569. num_patches = self.patch_embed.num_patches
  570. patches_resolution = self.patch_embed.patches_resolution
  571. self.patches_resolution = patches_resolution
  572. # merge non-overlapping patches into image
  573. self.patch_unembed = PatchUnEmbed(
  574. img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
  575. norm_layer=norm_layer if self.patch_norm else None)
  576. # absolute position embedding
  577. if self.ape:
  578. self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
  579. trunc_normal_(self.absolute_pos_embed, std=.02)
  580. self.pos_drop = nn.Dropout(p=drop_rate)
  581. # stochastic depth
  582. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
  583. # build Residual Swin Transformer blocks (RSTB)
  584. self.layers = nn.ModuleList()
  585. for i_layer in range(self.num_layers):
  586. layer = RSTB(dim=embed_dim,
  587. input_resolution=(patches_resolution[0],
  588. patches_resolution[1]),
  589. depth=depths[i_layer],
  590. num_heads=num_heads[i_layer],
  591. window_size=window_size,
  592. mlp_ratio=self.mlp_ratio,
  593. qkv_bias=qkv_bias, qk_scale=qk_scale,
  594. drop=drop_rate, attn_drop=attn_drop_rate,
  595. drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
  596. norm_layer=norm_layer,
  597. downsample=None,
  598. use_checkpoint=use_checkpoint,
  599. img_size=img_size,
  600. patch_size=patch_size,
  601. resi_connection=resi_connection
  602. )
  603. self.layers.append(layer)
  604. self.norm = norm_layer(self.num_features)
  605. # build the last conv layer in deep feature extraction
  606. if resi_connection == '1conv':
  607. self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
  608. elif resi_connection == '3conv':
  609. # to save parameters and memory
  610. self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
  611. nn.LeakyReLU(negative_slope=0.2, inplace=True),
  612. nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
  613. nn.LeakyReLU(negative_slope=0.2, inplace=True),
  614. nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
  615. #####################################################################################################
  616. ################################ 3, high quality image reconstruction ################################
  617. if self.upsampler == 'pixelshuffle':
  618. # for classical SR
  619. self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
  620. nn.LeakyReLU(inplace=True))
  621. self.upsample = Upsample(upscale, num_feat)
  622. self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
  623. elif self.upsampler == 'pixelshuffledirect':
  624. # for lightweight SR (to save parameters)
  625. self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
  626. (patches_resolution[0], patches_resolution[1]))
  627. elif self.upsampler == 'nearest+conv':
  628. # for real-world SR (less artifacts)
  629. self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
  630. nn.LeakyReLU(inplace=True))
  631. self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
  632. if self.upscale == 4:
  633. self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
  634. self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
  635. self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
  636. self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
  637. else:
  638. # for image denoising and JPEG compression artifact reduction
  639. self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
  640. self.apply(self._init_weights)
  641. def _init_weights(self, m):
  642. if isinstance(m, nn.Linear):
  643. trunc_normal_(m.weight, std=.02)
  644. if isinstance(m, nn.Linear) and m.bias is not None:
  645. nn.init.constant_(m.bias, 0)
  646. elif isinstance(m, nn.LayerNorm):
  647. nn.init.constant_(m.bias, 0)
  648. nn.init.constant_(m.weight, 1.0)
  649. @torch.jit.ignore
  650. def no_weight_decay(self):
  651. return {'absolute_pos_embed'}
  652. @torch.jit.ignore
  653. def no_weight_decay_keywords(self):
  654. return {'relative_position_bias_table'}
  655. def check_image_size(self, x):
  656. _, _, h, w = x.size()
  657. mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
  658. mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
  659. x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
  660. return x
  661. def forward_features(self, x):
  662. x_size = (x.shape[2], x.shape[3])
  663. x = self.patch_embed(x)
  664. if self.ape:
  665. x = x + self.absolute_pos_embed
  666. x = self.pos_drop(x)
  667. for layer in self.layers:
  668. x = layer(x, x_size)
  669. x = self.norm(x) # B L C
  670. x = self.patch_unembed(x, x_size)
  671. return x
  672. def forward(self, x):
  673. H, W = x.shape[2:]
  674. x = self.check_image_size(x)
  675. self.mean = self.mean.type_as(x)
  676. x = (x - self.mean) * self.img_range
  677. if self.upsampler == 'pixelshuffle':
  678. # for classical SR
  679. x = self.conv_first(x)
  680. x = self.conv_after_body(self.forward_features(x)) + x
  681. x = self.conv_before_upsample(x)
  682. x = self.conv_last(self.upsample(x))
  683. elif self.upsampler == 'pixelshuffledirect':
  684. # for lightweight SR
  685. x = self.conv_first(x)
  686. x = self.conv_after_body(self.forward_features(x)) + x
  687. x = self.upsample(x)
  688. elif self.upsampler == 'nearest+conv':
  689. # for real-world SR
  690. x = self.conv_first(x)
  691. x = self.conv_after_body(self.forward_features(x)) + x
  692. x = self.conv_before_upsample(x)
  693. x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
  694. if self.upscale == 4:
  695. x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
  696. x = self.conv_last(self.lrelu(self.conv_hr(x)))
  697. else:
  698. # for image denoising and JPEG compression artifact reduction
  699. x_first = self.conv_first(x)
  700. res = self.conv_after_body(self.forward_features(x_first)) + x_first
  701. x = x + self.conv_last(res)
  702. x = x / self.img_range + self.mean
  703. return x[:, :, :H*self.upscale, :W*self.upscale]
  704. def flops(self):
  705. flops = 0
  706. H, W = self.patches_resolution
  707. flops += H * W * 3 * self.embed_dim * 9
  708. flops += self.patch_embed.flops()
  709. for layer in self.layers:
  710. flops += layer.flops()
  711. flops += H * W * 3 * self.embed_dim * self.embed_dim
  712. flops += self.upsample.flops()
  713. return flops
  714. if __name__ == '__main__':
  715. upscale = 4
  716. window_size = 8
  717. height = (1024 // upscale // window_size + 1) * window_size
  718. width = (720 // upscale // window_size + 1) * window_size
  719. model = SwinIR(upscale=2, img_size=(height, width),
  720. window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
  721. embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
  722. print(model)
  723. print(height, width, model.flops() / 1e9)
  724. x = torch.randn((1, 3, height, width))
  725. x = model(x)
  726. print(x.shape)