sd_samplers.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. from __future__ import annotations
  2. import functools
  3. import logging
  4. from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers
  5. # imports for functions that previously were here and are used by other modules
  6. samples_to_image_grid = sd_samplers_common.samples_to_image_grid
  7. sample_to_image = sd_samplers_common.sample_to_image
  8. all_samplers = [
  9. *sd_samplers_kdiffusion.samplers_data_k_diffusion,
  10. *sd_samplers_timesteps.samplers_data_timesteps,
  11. *sd_samplers_lcm.samplers_data_lcm,
  12. ]
  13. all_samplers_map = {x.name: x for x in all_samplers}
  14. samplers: list[sd_samplers_common.SamplerData] = []
  15. samplers_for_img2img: list[sd_samplers_common.SamplerData] = []
  16. samplers_map = {}
  17. samplers_hidden = {}
  18. def find_sampler_config(name):
  19. if name is not None:
  20. config = all_samplers_map.get(name, None)
  21. else:
  22. config = all_samplers[0]
  23. return config
  24. def create_sampler(name, model):
  25. config = find_sampler_config(name)
  26. assert config is not None, f'bad sampler name: {name}'
  27. if model.is_sdxl and config.options.get("no_sdxl", False):
  28. raise Exception(f"Sampler {config.name} is not supported for SDXL")
  29. sampler = config.constructor(model)
  30. sampler.config = config
  31. return sampler
  32. def set_samplers():
  33. global samplers, samplers_for_img2img, samplers_hidden
  34. samplers_hidden = set(shared.opts.hide_samplers)
  35. samplers = all_samplers
  36. samplers_for_img2img = all_samplers
  37. samplers_map.clear()
  38. for sampler in all_samplers:
  39. samplers_map[sampler.name.lower()] = sampler.name
  40. for alias in sampler.aliases:
  41. samplers_map[alias.lower()] = sampler.name
  42. def visible_sampler_names():
  43. return [x.name for x in samplers if x.name not in samplers_hidden]
  44. def visible_samplers():
  45. return [x for x in samplers if x.name not in samplers_hidden]
  46. def get_sampler_from_infotext(d: dict):
  47. return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[0]
  48. def get_scheduler_from_infotext(d: dict):
  49. return get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))[1]
  50. def get_hr_sampler_and_scheduler(d: dict):
  51. hr_sampler = d.get("Hires sampler", "Use same sampler")
  52. sampler = d.get("Sampler") if hr_sampler == "Use same sampler" else hr_sampler
  53. hr_scheduler = d.get("Hires schedule type", "Use same scheduler")
  54. scheduler = d.get("Schedule type") if hr_scheduler == "Use same scheduler" else hr_scheduler
  55. sampler, scheduler = get_sampler_and_scheduler(sampler, scheduler)
  56. sampler = sampler if sampler != d.get("Sampler") else "Use same sampler"
  57. scheduler = scheduler if scheduler != d.get("Schedule type") else "Use same scheduler"
  58. return sampler, scheduler
  59. def get_hr_sampler_from_infotext(d: dict):
  60. return get_hr_sampler_and_scheduler(d)[0]
  61. def get_hr_scheduler_from_infotext(d: dict):
  62. return get_hr_sampler_and_scheduler(d)[1]
  63. @functools.cache
  64. def get_sampler_and_scheduler(sampler_name, scheduler_name, *, convert_automatic=True):
  65. default_sampler = samplers[0]
  66. found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])
  67. name = sampler_name or default_sampler.name
  68. for scheduler in sd_schedulers.schedulers:
  69. name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]
  70. for name_option in name_options:
  71. if name.endswith(" " + name_option):
  72. found_scheduler = scheduler
  73. name = name[0:-(len(name_option) + 1)]
  74. break
  75. sampler = all_samplers_map.get(name, default_sampler)
  76. # revert back to Automatic if it's the default scheduler for the selected sampler
  77. if convert_automatic and sampler.options.get('scheduler', None) == found_scheduler.name:
  78. found_scheduler = sd_schedulers.schedulers[0]
  79. return sampler.name, found_scheduler.label
  80. def fix_p_invalid_sampler_and_scheduler(p):
  81. i_sampler_name, i_scheduler = p.sampler_name, p.scheduler
  82. p.sampler_name, p.scheduler = get_sampler_and_scheduler(p.sampler_name, p.scheduler, convert_automatic=False)
  83. if p.sampler_name != i_sampler_name or i_scheduler != p.scheduler:
  84. logging.warning(f'Sampler Scheduler autocorrection: "{i_sampler_name}" -> "{p.sampler_name}", "{i_scheduler}" -> "{p.scheduler}"')
# Populate samplers, samplers_for_img2img, samplers_hidden and samplers_map
# once, when this module is first imported.
set_samplers()