Browse Source

resolve some of circular import issues for kohaku

AUTOMATIC1111 2 years ago
parent
commit
f0c1063a70

+ 2 - 3
modules/hypernetworks/hypernetwork.py

@@ -10,7 +10,7 @@ import torch
 import tqdm
 import tqdm
 from einops import rearrange, repeat
 from einops import rearrange, repeat
 from ldm.util import default
 from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
+from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
 from modules.textual_inversion import textual_inversion, logging
 from modules.textual_inversion import textual_inversion, logging
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from torch import einsum
 from torch import einsum
@@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
 
 
 
 
 def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
 def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
-    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
-    from modules import images
+    from modules import images, processing
 
 
     save_hypernetwork_every = save_hypernetwork_every or 0
     save_hypernetwork_every = save_hypernetwork_every or 0
     create_image_every = create_image_every or 0
     create_image_every = create_image_every or 0

+ 1 - 6
modules/processing.py

@@ -30,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
 from einops import repeat, rearrange
 from einops import repeat, rearrange
 from blendmodes.blend import blendLayers, BlendType
 from blendmodes.blend import blendLayers, BlendType
 
 
+decode_first_stage = sd_samplers_common.decode_first_stage
 
 
 # some of those options should not be changed at all because they would break the model, so I removed them from options.
 # some of those options should not be changed at all because they would break the model, so I removed them from options.
 opt_C = 4
 opt_C = 4
@@ -572,12 +573,6 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
     return samples
     return samples
 
 
 
 
-def decode_first_stage(model, x):
-    x = model.decode_first_stage(x.to(devices.dtype_vae))
-
-    return x
-
-
 def get_fixed_seed(seed):
 def get_fixed_seed(seed):
     if seed is None or seed == '' or seed == -1:
     if seed is None or seed == '' or seed == -1:
         return int(random.randrange(4294967294))
         return int(random.randrange(4294967294))

+ 3 - 3
modules/sd_hijack.py

@@ -2,7 +2,6 @@ import torch
 from torch.nn.functional import silu
 from torch.nn.functional import silu
 from types import MethodType
 from types import MethodType
 
 
-import modules.textual_inversion.textual_inversion
 from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
 from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
 from modules.hypernetworks import hypernetwork
 from modules.hypernetworks import hypernetwork
 from modules.shared import cmd_opts
 from modules.shared import cmd_opts
@@ -164,12 +163,13 @@ class StableDiffusionModelHijack:
     clip = None
     clip = None
     optimization_method = None
     optimization_method = None
 
 
-    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
-
     def __init__(self):
     def __init__(self):
+        import modules.textual_inversion.textual_inversion
+
         self.extra_generation_params = {}
         self.extra_generation_params = {}
         self.comments = []
         self.comments = []
 
 
+        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
         self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
         self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
 
 
     def apply_optimizations(self, option=None):
     def apply_optimizations(self, option=None):

+ 8 - 2
modules/sd_samplers_common.py

@@ -2,7 +2,7 @@ from collections import namedtuple
 import numpy as np
 import numpy as np
 import torch
 import torch
 from PIL import Image
 from PIL import Image
-from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
+from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
 from modules.shared import opts, state
 from modules.shared import opts, state
 
 
 SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
 SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -35,7 +35,7 @@ def single_sample_to_image(sample, approximation=None):
         x_sample = sample * 1.5
         x_sample = sample * 1.5
         x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
         x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
     else:
     else:
-        x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
+        x_sample = decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
 
 
     x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
     x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
     x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
     x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@@ -44,6 +44,12 @@ def single_sample_to_image(sample, approximation=None):
     return Image.fromarray(x_sample)
     return Image.fromarray(x_sample)
 
 
 
 
+def decode_first_stage(model, x):
+    x = model.decode_first_stage(x.to(devices.dtype_vae))
+
+    return x
+
+
 def sample_to_image(samples, index=0, approximation=None):
 def sample_to_image(samples, index=0, approximation=None):
     return single_sample_to_image(samples[index], approximation)
     return single_sample_to_image(samples[index], approximation)
 
 

+ 3 - 1
modules/textual_inversion/textual_inversion.py

@@ -13,7 +13,7 @@ import numpy as np
 from PIL import Image, PngImagePlugin
 from PIL import Image, PngImagePlugin
 from torch.utils.tensorboard import SummaryWriter
 from torch.utils.tensorboard import SummaryWriter
 
 
-from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
+from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
 import modules.textual_inversion.dataset
 import modules.textual_inversion.dataset
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 
 
@@ -387,6 +387,8 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
 
 
 
 
 def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
 def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+    from modules import processing
+
     save_embedding_every = save_embedding_every or 0
     save_embedding_every = save_embedding_every or 0
     create_image_every = create_image_every or 0
     create_image_every = create_image_every or 0
     template_file = textual_inversion_templates.get(template_filename, None)
     template_file = textual_inversion_templates.get(template_filename, None)