sd_samplers.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared
  2. # imports for functions that previously were here and are used by other modules
  3. from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
  4. all_samplers = [
  5. *sd_samplers_kdiffusion.samplers_data_k_diffusion,
  6. *sd_samplers_timesteps.samplers_data_timesteps,
  7. *sd_samplers_lcm.samplers_data_lcm,
  8. ]
  9. all_samplers_map = {x.name: x for x in all_samplers}
  10. samplers = []
  11. samplers_for_img2img = []
  12. samplers_map = {}
  13. samplers_hidden = {}
  14. def find_sampler_config(name):
  15. if name is not None:
  16. config = all_samplers_map.get(name, None)
  17. else:
  18. config = all_samplers[0]
  19. return config
  20. def create_sampler(name, model):
  21. config = find_sampler_config(name)
  22. assert config is not None, f'bad sampler name: {name}'
  23. if model.is_sdxl and config.options.get("no_sdxl", False):
  24. raise Exception(f"Sampler {config.name} is not supported for SDXL")
  25. sampler = config.constructor(model)
  26. sampler.config = config
  27. return sampler
  28. def set_samplers():
  29. global samplers, samplers_for_img2img, samplers_hidden
  30. samplers_hidden = set(shared.opts.hide_samplers)
  31. samplers = all_samplers
  32. samplers_for_img2img = all_samplers
  33. samplers_map.clear()
  34. for sampler in all_samplers:
  35. samplers_map[sampler.name.lower()] = sampler.name
  36. for alias in sampler.aliases:
  37. samplers_map[alias.lower()] = sampler.name
  38. def visible_sampler_names():
  39. return [x.name for x in samplers if x.name not in samplers_hidden]
  40. set_samplers()