# scunet_model_arch.py

# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_, DropPath

class WMSA(nn.Module):
    """ Self-attention module in Swin Transformer
    """

    def __init__(self, input_dim, output_dim, head_dim, window_size, type):
        super(WMSA, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.head_dim = head_dim
        self.scale = self.head_dim ** -0.5
        self.n_heads = input_dim // head_dim
        self.window_size = window_size
        self.type = type
        self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
        self.relative_position_params = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
        self.linear = nn.Linear(self.input_dim, self.output_dim)

        trunc_normal_(self.relative_position_params, std=.02)
        self.relative_position_params = torch.nn.Parameter(
            self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads)
            .transpose(1, 2).transpose(0, 1))

    def generate_mask(self, h, w, p, shift):
        """ generating the mask of SW-MSA
        Args:
            shift: shift parameters in CyclicShift.
        Returns:
            attn_mask: should be (1 1 w p p),
        """
        # supporting square.
        attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
        if self.type == 'W':
            return attn_mask

        s = p - shift
        attn_mask[-1, :, :s, :, s:, :] = True
        attn_mask[-1, :, s:, :, :s, :] = True
        attn_mask[:, -1, :, :s, :, s:] = True
        attn_mask[:, -1, :, s:, :, :s] = True
        attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
        return attn_mask

    def forward(self, x):
        """ Forward pass of Window Multi-head Self-attention module.
        Args:
            x: input tensor with shape of [b h w c];
            attn_mask: attention mask, fill -inf where the value is True;
        Returns:
            output: tensor shape [b h w c]
        """
        if self.type != 'W':
            x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))

        x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
        h_windows = x.size(1)
        w_windows = x.size(2)
        # square validation
        # assert h_windows == w_windows

        x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
        qkv = self.embedding_layer(x)
        q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
        sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
        # Adding learnable relative embedding
        sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
        # Using Attn Mask to distinguish different subwindows.
        if self.type != 'W':
            attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
            sim = sim.masked_fill_(attn_mask, float("-inf"))

        probs = nn.functional.softmax(sim, dim=-1)
        output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
        output = rearrange(output, 'h b w p c -> b w p (h c)')
        output = self.linear(output)
        output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)

        if self.type != 'W':
            output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))

        return output

    def relative_embedding(self):
        cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
        relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
        # negative is allowed
        return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
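
# Usage sketch (added comment, not part of the original file; values below are
# illustrative assumptions): WMSA expects channel-last input whose spatial sides
# are multiples of window_size, e.g.
#   wmsa = WMSA(input_dim=64, output_dim=64, head_dim=32, window_size=8, type='W')
#   out = wmsa(torch.randn(1, 64, 64, 64))  # [b h w c] -> [b h w c]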

class Block(nn.Module):
    def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
        """ SwinTransformer Block
        """
        super(Block, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        assert type in ['W', 'SW']
        self.type = type
        if input_resolution <= window_size:
            self.type = 'W'

        self.ln1 = nn.LayerNorm(input_dim)
        self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.ln2 = nn.LayerNorm(input_dim)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 4 * input_dim),
            nn.GELU(),
            nn.Linear(4 * input_dim, output_dim),
        )

    def forward(self, x):
        x = x + self.drop_path(self.msa(self.ln1(x)))
        x = x + self.drop_path(self.mlp(self.ln2(x)))
        return x

class ConvTransBlock(nn.Module):
    def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
        """ SwinTransformer and Conv Block
        """
        super(ConvTransBlock, self).__init__()
        self.conv_dim = conv_dim
        self.trans_dim = trans_dim
        self.head_dim = head_dim
        self.window_size = window_size
        self.drop_path = drop_path
        self.type = type
        self.input_resolution = input_resolution

        assert self.type in ['W', 'SW']
        if self.input_resolution <= self.window_size:
            self.type = 'W'

        self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
                                 self.type, self.input_resolution)
        self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
        self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)

        self.conv_block = nn.Sequential(
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
        )

    def forward(self, x):
        conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
        conv_x = self.conv_block(conv_x) + conv_x
        trans_x = Rearrange('b c h w -> b h w c')(trans_x)
        trans_x = self.trans_block(trans_x)
        trans_x = Rearrange('b h w c -> b c h w')(trans_x)
        res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
        x = x + res
        return x
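
# Note (added comment, not in the original source): ConvTransBlock takes an NCHW
# feature map with conv_dim + trans_dim channels, splits it into a residual
# convolutional branch and a Swin Transformer (Block) branch, then fuses the two
# with a 1x1 conv plus a residual connection. Spatial sides must be divisible by
# window_size for the transformer branch. Illustrative assumption:
#   blk = ConvTransBlock(32, 32, head_dim=32, window_size=8, drop_path=0.0,
#                        type='W', input_resolution=128)
#   y = blk(torch.randn(1, 64, 128, 128))  # -> [1, 64, 128, 128]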

class SCUNet(nn.Module):
    # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
    def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
        super(SCUNet, self).__init__()
        if config is None:
            config = [2, 2, 2, 2, 2, 2, 2]
        self.config = config
        self.dim = dim
        self.head_dim = 32
        self.window_size = 8

        # drop path rate for each layer
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]

        self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]

        begin = 0
        self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
                                       'W' if not i % 2 else 'SW', input_resolution)
                        for i in range(config[0])] + \
                       [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]

        begin += config[0]
        self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
                                       'W' if not i % 2 else 'SW', input_resolution // 2)
                        for i in range(config[1])] + \
                       [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]

        begin += config[1]
        self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
                                       'W' if not i % 2 else 'SW', input_resolution // 4)
                        for i in range(config[2])] + \
                       [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]

        begin += config[2]
        self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
                                      'W' if not i % 2 else 'SW', input_resolution // 8)
                       for i in range(config[3])]

        begin += config[3]
        self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
                     [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
                                     'W' if not i % 2 else 'SW', input_resolution // 4)
                      for i in range(config[4])]

        begin += config[4]
        self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
                     [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
                                     'W' if not i % 2 else 'SW', input_resolution // 2)
                      for i in range(config[5])]

        begin += config[5]
        self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
                     [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
                                     'W' if not i % 2 else 'SW', input_resolution)
                      for i in range(config[6])]

        self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]

        self.m_head = nn.Sequential(*self.m_head)
        self.m_down1 = nn.Sequential(*self.m_down1)
        self.m_down2 = nn.Sequential(*self.m_down2)
        self.m_down3 = nn.Sequential(*self.m_down3)
        self.m_body = nn.Sequential(*self.m_body)
        self.m_up3 = nn.Sequential(*self.m_up3)
        self.m_up2 = nn.Sequential(*self.m_up2)
        self.m_up1 = nn.Sequential(*self.m_up1)
        self.m_tail = nn.Sequential(*self.m_tail)
        # self.apply(self._init_weights)

    def forward(self, x0):
        h, w = x0.size()[-2:]
        paddingBottom = int(np.ceil(h / 64) * 64 - h)
        paddingRight = int(np.ceil(w / 64) * 64 - w)
        x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)

        x1 = self.m_head(x0)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body(x4)
        x = self.m_up3(x + x4)
        x = self.m_up2(x + x3)
        x = self.m_up1(x + x2)
        x = self.m_tail(x + x1)

        x = x[..., :h, :w]
        return x

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
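

if __name__ == "__main__":
    # Minimal smoke test added for illustration (not part of the original file).
    # Assumes a 3-channel image tensor; SCUNet.forward replication-pads the input
    # to a multiple of 64 and crops the output back to the original size.
    net = SCUNet(in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64)
    net.eval()
    with torch.no_grad():
        x = torch.randn(1, 3, 128, 128)
        y = net(x)
    print(y.shape)  # expected: torch.Size([1, 3, 128, 128])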