Browse Source

Add SSD-1B as a supported model

Ritesh Gangnani 1 year ago
parent
commit
ff1609f91e
3 changed files with 21 additions and 3 deletions
  1. 11 0
      modules/sd_hijack.py
  2. 6 2
      modules/sd_models.py
  3. 4 1
      modules/sd_models_types.py

+ 11 - 0
modules/sd_hijack.py

@@ -180,6 +180,17 @@ class StableDiffusionModelHijack:
         except Exception as e:
         except Exception as e:
             errors.display(e, "applying cross attention optimization")
             errors.display(e, "applying cross attention optimization")
             undo_optimizations()
             undo_optimizations()
+            
+    def conv_ssd(self, m):
+            delattr(m.model.diffusion_model.middle_block, '1')
+            delattr(m.model.diffusion_model.middle_block, '2')
+            for i in ['9','8','7','6','5','4']:
+                delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks,i)
+                delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks,i)
+                delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks,i)
+                delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks,i)
+            delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks,'1')
+            delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks,'1') 
 
 
     def hijack(self, m):
     def hijack(self, m):
         conditioner = getattr(m, 'conditioner', None)
         conditioner = getattr(m, 'conditioner', None)

+ 6 - 2
modules/sd_models.py

@@ -346,10 +346,14 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     model.is_sdxl = hasattr(model, 'conditioner')
     model.is_sdxl = hasattr(model, 'conditioner')
     model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
     model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
     model.is_sd1 = not model.is_sdxl and not model.is_sd2
     model.is_sd1 = not model.is_sdxl and not model.is_sd2
-
+    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
+    
     if model.is_sdxl:
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
         sd_models_xl.extend_sdxl(model)
-
+    
+    if model.is_ssd:
+        sd_hijack.model_hijack.conv_ssd(model)
+        
     model.load_state_dict(state_dict, strict=False)
     model.load_state_dict(state_dict, strict=False)
     timer.record("apply weights to model")
     timer.record("apply weights to model")
 
 

+ 4 - 1
modules/sd_models_types.py

@@ -22,7 +22,10 @@ class WebuiSdModel(LatentDiffusion):
     """structure with additional information about the file with model's weights"""
     """structure with additional information about the file with model's weights"""
 
 
     is_sdxl: bool
     is_sdxl: bool
-    """True if the model's architecture is SDXL"""
+    """True if the model's architecture is SDXL or SSD"""
+    
+    is_ssd: bool
+    """True if the model is SSD"""
 
 
     is_sd2: bool
     is_sd2: bool
     """True if the model's architecture is SD 2.x"""
     """True if the model's architecture is SD 2.x"""