123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 |
- from __future__ import annotations
- import functools
- from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers
# Re-exported for backward compatibility: these functions previously lived in
# this module and other modules still access them from here.
samples_to_image_grid = sd_samplers_common.samples_to_image_grid
sample_to_image = sd_samplers_common.sample_to_image

# Every registered sampler, in order: k-diffusion, then timestep-based, then LCM.
all_samplers = [
    *sd_samplers_kdiffusion.samplers_data_k_diffusion,
    *sd_samplers_timesteps.samplers_data_timesteps,
    *sd_samplers_lcm.samplers_data_lcm,
]

# Sampler data keyed by its exact name.
all_samplers_map = {x.name: x for x in all_samplers}

# The containers below are (re)populated by set_samplers().
samplers: list[sd_samplers_common.SamplerData] = []
samplers_for_img2img: list[sd_samplers_common.SamplerData] = []

# Lower-cased sampler name or alias -> canonical sampler name.
samplers_map: dict[str, str] = {}

# Names excluded from the visible sampler lists. Initialised as a set (not a
# dict) so its type matches what set_samplers() assigns later.
samplers_hidden: set[str] = set()
def find_sampler_config(name):
    """Look up the sampler data registered under ``name``.

    Returns the first entry of ``all_samplers`` when ``name`` is None, and
    None when ``name`` is given but missing from ``all_samplers_map``.
    """
    if name is None:
        return all_samplers[0]

    return all_samplers_map.get(name)
def create_sampler(name, model):
    """Instantiate the sampler registered under ``name`` for ``model``.

    The sampler is built with the config's ``constructor`` and the config is
    attached to the result as ``sampler.config``.

    Args:
        name: sampler name; None selects the first registered sampler.
        model: model object passed to the sampler constructor; its
            ``is_sdxl`` attribute is checked against the ``no_sdxl`` option.

    Raises:
        AssertionError: if no sampler is registered under ``name``.
        Exception: if the model is SDXL and the sampler is flagged ``no_sdxl``.
    """
    config = find_sampler_config(name)

    # Raised explicitly rather than through `assert` so the check is not
    # stripped under `python -O`; the exception type stays AssertionError
    # so existing callers see the same error.
    if config is None:
        raise AssertionError(f'bad sampler name: {name}')

    if model.is_sdxl and config.options.get("no_sdxl", False):
        raise Exception(f"Sampler {config.name} is not supported for SDXL")

    sampler = config.constructor(model)
    sampler.config = config

    return sampler
def set_samplers():
    """Refresh the module-level sampler lists, hidden set and name map.

    ``samplers_hidden`` is rebuilt from ``shared.opts.hide_samplers``; both
    sampler lists are pointed at ``all_samplers``; ``samplers_map`` is
    rebuilt in place so that every lower-cased name and alias maps to the
    sampler's canonical name.
    """
    global samplers, samplers_for_img2img, samplers_hidden

    samplers_hidden = set(shared.opts.hide_samplers)
    samplers = all_samplers
    samplers_for_img2img = all_samplers

    # Mutated in place (not rebound) so existing references stay valid.
    samplers_map.clear()
    for entry in all_samplers:
        canonical = entry.name
        samplers_map[canonical.lower()] = canonical
        samplers_map.update((alias.lower(), canonical) for alias in entry.aliases)
def visible_sampler_names():
    """Return the names of all samplers that are not in ``samplers_hidden``."""
    names = []
    for entry in samplers:
        if entry.name in samplers_hidden:
            continue
        names.append(entry.name)

    return names
def visible_samplers():
    """Return the sampler entries whose names are not in ``samplers_hidden``."""
    shown = []
    for entry in samplers:
        if entry.name not in samplers_hidden:
            shown.append(entry)

    return shown
def get_sampler_from_infotext(d: dict):
    """Return the sampler name resolved from the "Sampler" and "Schedule type" infotext fields."""
    sampler_name, _ = get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))
    return sampler_name
def get_scheduler_from_infotext(d: dict):
    """Return the scheduler label resolved from the "Sampler" and "Schedule type" infotext fields."""
    _, scheduler_label = get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))
    return scheduler_label
def get_hr_sampler_and_scheduler(d: dict):
    """Resolve the hires-fix sampler and scheduler from infotext fields.

    Reads "Hires sampler" / "Hires schedule type", falling back to the base
    "Sampler" / "Schedule type" when a field is absent or set to its
    "Use same ..." placeholder. The pair is then normalized through
    get_sampler_and_scheduler().

    Returns:
        A ``(sampler, scheduler)`` tuple in which a value equal to the base
        infotext field is replaced by its "Use same ..." placeholder.
    """
    same_sampler = "Use same sampler"
    same_scheduler = "Use same scheduler"

    base_sampler = d.get("Sampler")
    base_scheduler = d.get("Schedule type")

    requested_sampler = d.get("Hires sampler", same_sampler)
    if requested_sampler == same_sampler:
        requested_sampler = base_sampler

    requested_scheduler = d.get("Hires schedule type", same_scheduler)
    if requested_scheduler == same_scheduler:
        requested_scheduler = base_scheduler

    resolved_sampler, resolved_scheduler = get_sampler_and_scheduler(requested_sampler, requested_scheduler)

    # NOTE(review): the resolved names are compared against the raw infotext
    # values, not against their normalized form - confirm this is intended.
    if resolved_sampler == base_sampler:
        resolved_sampler = same_sampler
    if resolved_scheduler == base_scheduler:
        resolved_scheduler = same_scheduler

    return resolved_sampler, resolved_scheduler
def get_hr_sampler_from_infotext(d: dict):
    """Return only the sampler part of get_hr_sampler_and_scheduler(d)."""
    hr_sampler, _ = get_hr_sampler_and_scheduler(d)
    return hr_sampler
def get_hr_scheduler_from_infotext(d: dict):
    """Return only the scheduler part of get_hr_sampler_and_scheduler(d)."""
    _, hr_scheduler = get_hr_sampler_and_scheduler(d)
    return hr_scheduler
@functools.cache
def get_sampler_and_scheduler(sampler_name, scheduler_name):
    """Normalize a sampler/scheduler name pair into ``(sampler name, scheduler label)``.

    A scheduler label, name or alias appearing as a space-separated suffix of
    the sampler name is split off and takes precedence over ``scheduler_name``.
    An unknown sampler falls back to the first entry of ``samplers``; an
    unknown scheduler falls back to the first entry of
    ``sd_schedulers.schedulers``. Results are memoized per argument pair.
    """
    fallback_sampler = samplers[0]
    chosen_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])

    remaining = sampler_name or fallback_sampler.name

    for candidate in sd_schedulers.schedulers:
        spellings = [candidate.label, candidate.name, *(candidate.aliases or [])]

        # First spelling of this scheduler that ends the name, if any.
        suffix = next((" " + s for s in spellings if remaining.endswith(" " + s)), None)
        if suffix is not None:
            chosen_scheduler = candidate
            remaining = remaining[:-len(suffix)]

    sampler = all_samplers_map.get(remaining, fallback_sampler)

    # revert back to Automatic if it's the default scheduler for the selected sampler
    if sampler.options.get('scheduler', None) == chosen_scheduler.name:
        chosen_scheduler = sd_schedulers.schedulers[0]

    return sampler.name, chosen_scheduler.label
- set_samplers()
|