refiner.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import gradio as gr
  2. from modules import scripts, sd_models
  3. from modules.infotext import PasteField
  4. from modules.ui_common import create_refresh_button
  5. from modules.ui_components import InputAccordion
  6. class ScriptRefiner(scripts.ScriptBuiltinUI):
  7. section = "accordions"
  8. create_group = False
  9. def __init__(self):
  10. pass
  11. def title(self):
  12. return "Refiner"
  13. def show(self, is_img2img):
  14. return scripts.AlwaysVisible
  15. def ui(self, is_img2img):
  16. with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner:
  17. with gr.Row():
  18. refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation")
  19. create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh"))
  20. refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation")
  21. def lookup_checkpoint(title):
  22. info = sd_models.get_closet_checkpoint_match(title)
  23. return None if info is None else info.title
  24. self.infotext_fields = [
  25. PasteField(enable_refiner, lambda d: 'Refiner' in d),
  26. PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"),
  27. PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"),
  28. ]
  29. return enable_refiner, refiner_checkpoint, refiner_switch_at
  30. def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
  31. # the actual implementation is in sd_samplers_common.py, apply_refiner
  32. if not enable_refiner or refiner_checkpoint in (None, "", "None"):
  33. p.refiner_checkpoint = None
  34. p.refiner_switch_at = None
  35. else:
  36. p.refiner_checkpoint = refiner_checkpoint
  37. p.refiner_switch_at = refiner_switch_at