Răsfoiți Sursa

add an option to show selected setting in main txt2img/img2img UI
split some code from ui.py into ui_settings.py ui_gradio_edxtensions.py
add before_process callback for scripts
add ability for alwayson scripts to specify section and let user reorder those sections

AUTOMATIC 2 ani în urmă
părinte
comite
df02498d03

+ 48 - 0
extensions-builtin/extra-options-section/scripts/extra_options_section.py

@@ -0,0 +1,48 @@
+import gradio as gr
+from modules import scripts, shared, ui_components, ui_settings
+from modules.ui_components import FormColumn
+
+
+class ExtraOptionsSection(scripts.Script):
+    section = "extra_options"
+
+    def __init__(self):
+        self.comps = None
+        self.setting_names = None
+
+    def title(self):
+        return "Extra options"
+
+    def show(self, is_img2img):
+        return scripts.AlwaysVisible
+
+    def ui(self, is_img2img):
+        self.comps = []
+        self.setting_names = []
+
+        with gr.Blocks() as interface:
+            with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and len(shared.opts.extra_options) > 0 else gr.Group(), gr.Row():
+                for setting_name in shared.opts.extra_options:
+                    with FormColumn():
+                        comp = ui_settings.create_setting_component(setting_name)
+
+                    self.comps.append(comp)
+                    self.setting_names.append(setting_name)
+
+        def get_settings_values():
+            return [ui_settings.get_value_for_setting(key) for key in self.setting_names]
+
+        interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)
+
+        return self.comps
+
+    def before_process(self, p, *args):
+        for name, value in zip(self.setting_names, args):
+            if name not in p.override_settings:
+                p.override_settings[name] = value
+
+
+shared.options_templates.update(shared.options_section(('ui', "User interface"), {
+    "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(),
+    "extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion")
+}))

+ 5 - 1
modules/processing.py

@@ -588,11 +588,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
 
 
 
 
 def process_images(p: StableDiffusionProcessing) -> Processed:
 def process_images(p: StableDiffusionProcessing) -> Processed:
+    if p.scripts is not None:
+        p.scripts.before_process(p)
+
     stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
     stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
 
 
     try:
     try:
         # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
         # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
-        if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None:
+        override_checkpoint = p.override_settings.get('sd_model_checkpoint')
+        if override_checkpoint is not None and sd_models.checkpoint_alisases.get(override_checkpoint) is None:
             p.override_settings.pop('sd_model_checkpoint', None)
             p.override_settings.pop('sd_model_checkpoint', None)
             sd_models.reload_model_weights()
             sd_models.reload_model_weights()
 
 

+ 71 - 45
modules/scripts.py

@@ -19,6 +19,9 @@ class Script:
     name = None
     name = None
     """script's internal name derived from title"""
     """script's internal name derived from title"""
 
 
+    section = None
+    """name of UI section that the script's controls will be placed into"""
+
     filename = None
     filename = None
     args_from = None
     args_from = None
     args_to = None
     args_to = None
@@ -81,6 +84,15 @@ class Script:
 
 
         pass
         pass
 
 
+    def before_process(self, p, *args):
+        """
+        This function is called very early before processing begins for AlwaysVisible scripts.
+        You can modify the processing object (p) here, inject hooks, etc.
+        args contains all values returned by components from ui()
+        """
+
+        pass
+
     def process(self, p, *args):
     def process(self, p, *args):
         """
         """
         This function is called before processing begins for AlwaysVisible scripts.
         This function is called before processing begins for AlwaysVisible scripts.
@@ -293,6 +305,7 @@ class ScriptRunner:
         self.titles = []
         self.titles = []
         self.infotext_fields = []
         self.infotext_fields = []
         self.paste_field_names = []
         self.paste_field_names = []
+        self.inputs = [None]
 
 
     def initialize_scripts(self, is_img2img):
     def initialize_scripts(self, is_img2img):
         from modules import scripts_auto_postprocessing
         from modules import scripts_auto_postprocessing
@@ -320,69 +333,73 @@ class ScriptRunner:
                 self.scripts.append(script)
                 self.scripts.append(script)
                 self.selectable_scripts.append(script)
                 self.selectable_scripts.append(script)
 
 
-    def setup_ui(self):
+    def create_script_ui(self, script):
         import modules.api.models as api_models
         import modules.api.models as api_models
 
 
-        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
+        script.args_from = len(self.inputs)
+        script.args_to = len(self.inputs)
 
 
-        inputs = [None]
-        inputs_alwayson = [True]
+        controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
 
 
-        def create_script_ui(script, inputs, inputs_alwayson):
-            script.args_from = len(inputs)
-            script.args_to = len(inputs)
+        if controls is None:
+            return
 
 
-            controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)
+        script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
+        api_args = []
 
 
-            if controls is None:
-                return
+        for control in controls:
+            control.custom_script_source = os.path.basename(script.filename)
 
 
-            script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
-            api_args = []
+            arg_info = api_models.ScriptArg(label=control.label or "")
 
 
-            for control in controls:
-                control.custom_script_source = os.path.basename(script.filename)
+            for field in ("value", "minimum", "maximum", "step", "choices"):
+                v = getattr(control, field, None)
+                if v is not None:
+                    setattr(arg_info, field, v)
 
 
-                arg_info = api_models.ScriptArg(label=control.label or "")
+            api_args.append(arg_info)
 
 
-                for field in ("value", "minimum", "maximum", "step", "choices"):
-                    v = getattr(control, field, None)
-                    if v is not None:
-                        setattr(arg_info, field, v)
+        script.api_info = api_models.ScriptInfo(
+            name=script.name,
+            is_img2img=script.is_img2img,
+            is_alwayson=script.alwayson,
+            args=api_args,
+        )
 
 
-                api_args.append(arg_info)
+        if script.infotext_fields is not None:
+            self.infotext_fields += script.infotext_fields
 
 
-            script.api_info = api_models.ScriptInfo(
-                name=script.name,
-                is_img2img=script.is_img2img,
-                is_alwayson=script.alwayson,
-                args=api_args,
-            )
+        if script.paste_field_names is not None:
+            self.paste_field_names += script.paste_field_names
 
 
-            if script.infotext_fields is not None:
-                self.infotext_fields += script.infotext_fields
+        self.inputs += controls
+        script.args_to = len(self.inputs)
 
 
-            if script.paste_field_names is not None:
-                self.paste_field_names += script.paste_field_names
+    def setup_ui_for_section(self, section, scriptlist=None):
+        if scriptlist is None:
+            scriptlist = self.alwayson_scripts
 
 
-            inputs += controls
-            inputs_alwayson += [script.alwayson for _ in controls]
-            script.args_to = len(inputs)
+        for script in scriptlist:
+            if script.alwayson and script.section != section:
+                continue
 
 
-        for script in self.alwayson_scripts:
-            with gr.Group() as group:
-                create_script_ui(script, inputs, inputs_alwayson)
+            with gr.Group(visible=script.alwayson) as group:
+                self.create_script_ui(script)
 
 
             script.group = group
             script.group = group
 
 
-        dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
-        inputs[0] = dropdown
+    def prepare_ui(self):
+        self.inputs = [None]
 
 
-        for script in self.selectable_scripts:
-            with gr.Group(visible=False) as group:
-                create_script_ui(script, inputs, inputs_alwayson)
+    def setup_ui(self):
+        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
 
 
-            script.group = group
+        self.setup_ui_for_section(None)
+
+        dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
+        self.inputs[0] = dropdown
+
+        self.setup_ui_for_section(None, self.selectable_scripts)
 
 
         def select_script(script_index):
         def select_script(script_index):
             selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
             selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
@@ -407,6 +424,7 @@ class ScriptRunner:
         )
         )
 
 
         self.script_load_ctr = 0
         self.script_load_ctr = 0
+
         def onload_script_visibility(params):
         def onload_script_visibility(params):
             title = params.get('Script', None)
             title = params.get('Script', None)
             if title:
             if title:
@@ -417,10 +435,10 @@ class ScriptRunner:
             else:
             else:
                 return gr.update(visible=False)
                 return gr.update(visible=False)
 
 
-        self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) )
-        self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] )
+        self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
+        self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
 
 
-        return inputs
+        return self.inputs
 
 
     def run(self, p, *args):
     def run(self, p, *args):
         script_index = args[0]
         script_index = args[0]
@@ -440,6 +458,14 @@ class ScriptRunner:
 
 
         return processed
         return processed
 
 
+    def before_process(self, p):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.before_process(p, *script_args)
+            except Exception:
+                errors.report(f"Error running before_process: {script.filename}", exc_info=True)
+
     def process(self, p):
     def process(self, p):
         for script in self.alwayson_scripts:
         for script in self.alwayson_scripts:
             try:
             try:

+ 10 - 0
modules/shared_items.py

@@ -55,5 +55,15 @@ ui_reorder_categories_builtin_items = [
 
 
 
 
 def ui_reorder_categories():
 def ui_reorder_categories():
+    from modules import scripts
+
     yield from ui_reorder_categories_builtin_items
     yield from ui_reorder_categories_builtin_items
+
+    sections = {}
+    for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:
+        if isinstance(script.section, str):
+            sections[script.section] = 1
+
+    yield from sections
+
     yield "scripts"
     yield "scripts"

+ 23 - 328
modules/ui.py

@@ -6,15 +6,17 @@ from functools import reduce
 import warnings
 import warnings
 
 
 import gradio as gr
 import gradio as gr
-import gradio.routes
 import gradio.utils
 import gradio.utils
 import numpy as np
 import numpy as np
 from PIL import Image, PngImagePlugin  # noqa: F401
 from PIL import Image, PngImagePlugin  # noqa: F401
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
 
 
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
-from modules.paths import script_path, data_path
+from modules.paths import script_path
+from modules.ui_common import create_refresh_button
+from modules.ui_gradio_extensions import reload_javascript
+
 
 
 from modules.shared import opts, cmd_opts
 from modules.shared import opts, cmd_opts
 
 
@@ -34,6 +36,8 @@ import modules.hypernetworks.ui
 from modules.generation_parameters_copypaste import image_from_url_text
 from modules.generation_parameters_copypaste import image_from_url_text
 import modules.extras
 import modules.extras
 
 
+create_setting_component = ui_settings.create_setting_component
+
 warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
 warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
 
 
 # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
 # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
@@ -366,25 +370,6 @@ def apply_setting(key, value):
     return getattr(opts, key)
     return getattr(opts, key)
 
 
 
 
-def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
-    def refresh():
-        refresh_method()
-        args = refreshed_args() if callable(refreshed_args) else refreshed_args
-
-        for k, v in args.items():
-            setattr(refresh_component, k, v)
-
-        return gr.update(**(args or {}))
-
-    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
-    refresh_button.click(
-        fn=refresh,
-        inputs=[],
-        outputs=[refresh_component]
-    )
-    return refresh_button
-
-
 def create_output_panel(tabname, outdir):
 def create_output_panel(tabname, outdir):
     return ui_common.create_output_panel(tabname, outdir)
     return ui_common.create_output_panel(tabname, outdir)
 
 
@@ -409,16 +394,6 @@ def ordered_ui_categories():
         yield category
         yield category
 
 
 
 
-def get_value_for_setting(key):
-    value = getattr(opts, key)
-
-    info = opts.data_labels[key]
-    args = info.component_args() if callable(info.component_args) else info.component_args or {}
-    args = {k: v for k, v in args.items() if k not in {'precision'}}
-
-    return gr.update(value=value, **args)
-
-
 def create_override_settings_dropdown(tabname, row):
 def create_override_settings_dropdown(tabname, row):
     dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
     dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True)
 
 
@@ -454,6 +429,8 @@ def create_ui():
 
 
         with gr.Row().style(equal_height=False):
         with gr.Row().style(equal_height=False):
             with gr.Column(variant='compact', elem_id="txt2img_settings"):
             with gr.Column(variant='compact', elem_id="txt2img_settings"):
+                modules.scripts.scripts_txt2img.prepare_ui()
+
                 for category in ordered_ui_categories():
                 for category in ordered_ui_categories():
                     if category == "sampler":
                     if category == "sampler":
                         steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
                         steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img")
@@ -522,6 +499,9 @@ def create_ui():
                         with FormGroup(elem_id="txt2img_script_container"):
                         with FormGroup(elem_id="txt2img_script_container"):
                             custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
                             custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
 
 
+                    else:
+                        modules.scripts.scripts_txt2img.setup_ui_for_section(category)
+
             hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
             hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
 
 
             for component in hr_resolution_preview_inputs:
             for component in hr_resolution_preview_inputs:
@@ -778,6 +758,8 @@ def create_ui():
                 with FormRow():
                 with FormRow():
                     resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
                     resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
 
 
+                modules.scripts.scripts_img2img.prepare_ui()
+
                 for category in ordered_ui_categories():
                 for category in ordered_ui_categories():
                     if category == "sampler":
                     if category == "sampler":
                         steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
                         steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img")
@@ -887,6 +869,8 @@ def create_ui():
                                     inputs=[],
                                     inputs=[],
                                     outputs=[inpaint_controls, mask_alpha],
                                     outputs=[inpaint_controls, mask_alpha],
                                 )
                                 )
+                    else:
+                        modules.scripts.scripts_img2img.setup_ui_for_section(category)
 
 
             img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
             img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
 
 
@@ -1460,195 +1444,10 @@ def create_ui():
             outputs=[],
             outputs=[],
         )
         )
 
 
-    def create_setting_component(key, is_quicksettings=False):
-        def fun():
-            return opts.data[key] if key in opts.data else opts.data_labels[key].default
-
-        info = opts.data_labels[key]
-        t = type(info.default)
-
-        args = info.component_args() if callable(info.component_args) else info.component_args
-
-        if info.component is not None:
-            comp = info.component
-        elif t == str:
-            comp = gr.Textbox
-        elif t == int:
-            comp = gr.Number
-        elif t == bool:
-            comp = gr.Checkbox
-        else:
-            raise Exception(f'bad options item type: {t} for key {key}')
-
-        elem_id = f"setting_{key}"
-
-        if info.refresh is not None:
-            if is_quicksettings:
-                res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
-                create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
-            else:
-                with FormRow():
-                    res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
-                    create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
-        else:
-            res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
-
-        return res
-
     loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
     loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
 
 
-    components = []
-    component_dict = {}
-    shared.settings_components = component_dict
-
-    script_callbacks.ui_settings_callback()
-    opts.reorder()
-
-    def run_settings(*args):
-        changed = []
-
-        for key, value, comp in zip(opts.data_labels.keys(), args, components):
-            assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
-
-        for key, value, comp in zip(opts.data_labels.keys(), args, components):
-            if comp == dummy_component:
-                continue
-
-            if opts.set(key, value):
-                changed.append(key)
-
-        try:
-            opts.save(shared.config_filename)
-        except RuntimeError:
-            return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
-        return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
-
-    def run_settings_single(value, key):
-        if not opts.same_type(value, opts.data_labels[key].default):
-            return gr.update(visible=True), opts.dumpjson()
-
-        if not opts.set(key, value):
-            return gr.update(value=getattr(opts, key)), opts.dumpjson()
-
-        opts.save(shared.config_filename)
-
-        return get_value_for_setting(key), opts.dumpjson()
-
-    with gr.Blocks(analytics_enabled=False) as settings_interface:
-        with gr.Row():
-            with gr.Column(scale=6):
-                settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
-            with gr.Column():
-                restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
-
-        result = gr.HTML(elem_id="settings_result")
-
-        quicksettings_names = opts.quicksettings_list
-        quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'}
-
-        quicksettings_list = []
-
-        previous_section = None
-        current_tab = None
-        current_row = None
-        with gr.Tabs(elem_id="settings"):
-            for i, (k, item) in enumerate(opts.data_labels.items()):
-                section_must_be_skipped = item.section[0] is None
-
-                if previous_section != item.section and not section_must_be_skipped:
-                    elem_id, text = item.section
-
-                    if current_tab is not None:
-                        current_row.__exit__()
-                        current_tab.__exit__()
-
-                    gr.Group()
-                    current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
-                    current_tab.__enter__()
-                    current_row = gr.Column(variant='compact')
-                    current_row.__enter__()
-
-                    previous_section = item.section
-
-                if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
-                    quicksettings_list.append((i, k, item))
-                    components.append(dummy_component)
-                elif section_must_be_skipped:
-                    components.append(dummy_component)
-                else:
-                    component = create_setting_component(k)
-                    component_dict[k] = component
-                    components.append(component)
-
-            if current_tab is not None:
-                current_row.__exit__()
-                current_tab.__exit__()
-
-            with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
-                loadsave.create_ui()
-
-            with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
-                request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
-                download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
-                reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
-                with gr.Row():
-                    unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
-                    reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
-
-            with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
-                gr.HTML(shared.html("licenses.html"), elem_id="licenses")
-
-            gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
-
-
-        def unload_sd_weights():
-            modules.sd_models.unload_model_weights()
-
-        def reload_sd_weights():
-            modules.sd_models.reload_model_weights()
-
-        unload_sd_model.click(
-            fn=unload_sd_weights,
-            inputs=[],
-            outputs=[]
-        )
-
-        reload_sd_model.click(
-            fn=reload_sd_weights,
-            inputs=[],
-            outputs=[]
-        )
-
-        request_notifications.click(
-            fn=lambda: None,
-            inputs=[],
-            outputs=[],
-            _js='function(){}'
-        )
-
-        download_localization.click(
-            fn=lambda: None,
-            inputs=[],
-            outputs=[],
-            _js='download_localization'
-        )
-
-        def reload_scripts():
-            modules.scripts.reload_script_body_only()
-            reload_javascript()  # need to refresh the html page
-
-        reload_script_bodies.click(
-            fn=reload_scripts,
-            inputs=[],
-            outputs=[]
-        )
-
-        restart_gradio.click(
-            fn=shared.state.request_restart,
-            _js='restart_reload',
-            inputs=[],
-            outputs=[],
-        )
+    settings = ui_settings.UiSettings()
+    settings.create_ui(loadsave, dummy_component)
 
 
     interfaces = [
     interfaces = [
         (txt2img_interface, "txt2img", "txt2img"),
         (txt2img_interface, "txt2img", "txt2img"),
@@ -1660,7 +1459,7 @@ def create_ui():
     ]
     ]
 
 
     interfaces += script_callbacks.ui_tabs_callback()
     interfaces += script_callbacks.ui_tabs_callback()
-    interfaces += [(settings_interface, "Settings", "settings")]
+    interfaces += [(settings.interface, "Settings", "settings")]
 
 
     extensions_interface = ui_extensions.create_ui()
     extensions_interface = ui_extensions.create_ui()
     interfaces += [(extensions_interface, "Extensions", "extensions")]
     interfaces += [(extensions_interface, "Extensions", "extensions")]
@@ -1670,10 +1469,7 @@ def create_ui():
         shared.tab_names.append(label)
         shared.tab_names.append(label)
 
 
     with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
     with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
-        with gr.Row(elem_id="quicksettings", variant="compact"):
-            for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
-                component = create_setting_component(k, is_quicksettings=True)
-                component_dict[k] = component
+        settings.add_quicksettings()
 
 
         parameters_copypaste.connect_paste_params_buttons()
         parameters_copypaste.connect_paste_params_buttons()
 
 
@@ -1704,49 +1500,12 @@ def create_ui():
         footer = footer.format(versions=versions_html())
         footer = footer.format(versions=versions_html())
         gr.HTML(footer, elem_id="footer")
         gr.HTML(footer, elem_id="footer")
 
 
-        text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
-        settings_submit.click(
-            fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
-            inputs=components,
-            outputs=[text_settings, result],
-        )
-
-        for _i, k, _item in quicksettings_list:
-            component = component_dict[k]
-            info = opts.data_labels[k]
-
-            change_handler = component.release if hasattr(component, 'release') else component.change
-            change_handler(
-                fn=lambda value, k=k: run_settings_single(value, key=k),
-                inputs=[component],
-                outputs=[component, text_settings],
-                show_progress=info.refresh is not None,
-            )
+        settings.add_functionality(demo)
 
 
         update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
         update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
-        text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
+        settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
         demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
         demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
 
 
-        button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
-        button_set_checkpoint.click(
-            fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
-            _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
-            inputs=[component_dict['sd_model_checkpoint'], dummy_component],
-            outputs=[component_dict['sd_model_checkpoint'], text_settings],
-        )
-
-        component_keys = [k for k in opts.data_labels.keys() if k in component_dict]
-
-        def get_settings_values():
-            return [get_value_for_setting(key) for key in component_keys]
-
-        demo.load(
-            fn=get_settings_values,
-            inputs=[],
-            outputs=[component_dict[k] for k in component_keys],
-            queue=False,
-        )
-
         def modelmerger(*args):
         def modelmerger(*args):
             try:
             try:
                 results = modules.extras.run_modelmerger(*args)
                 results = modules.extras.run_modelmerger(*args)
@@ -1779,7 +1538,7 @@ def create_ui():
                 primary_model_name,
                 primary_model_name,
                 secondary_model_name,
                 secondary_model_name,
                 tertiary_model_name,
                 tertiary_model_name,
-                component_dict['sd_model_checkpoint'],
+                settings.component_dict['sd_model_checkpoint'],
                 modelmerger_result,
                 modelmerger_result,
             ]
             ]
         )
         )
@@ -1793,70 +1552,6 @@ def create_ui():
     return demo
     return demo
 
 
 
 
-def webpath(fn):
-    if fn.startswith(script_path):
-        web_path = os.path.relpath(fn, script_path).replace('\\', '/')
-    else:
-        web_path = os.path.abspath(fn)
-
-    return f'file={web_path}?{os.path.getmtime(fn)}'
-
-
-def javascript_html():
-    # Ensure localization is in `window` before scripts
-    head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'
-
-    script_js = os.path.join(script_path, "script.js")
-    head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
-
-    for script in modules.scripts.list_scripts("javascript", ".js"):
-        head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
-
-    for script in modules.scripts.list_scripts("javascript", ".mjs"):
-        head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
-
-    if cmd_opts.theme:
-        head += f'<script type="text/javascript">set_theme(\"{cmd_opts.theme}\");</script>\n'
-
-    return head
-
-
-def css_html():
-    head = ""
-
-    def stylesheet(fn):
-        return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
-
-    for cssfile in modules.scripts.list_files_with_name("style.css"):
-        if not os.path.isfile(cssfile):
-            continue
-
-        head += stylesheet(cssfile)
-
-    if os.path.exists(os.path.join(data_path, "user.css")):
-        head += stylesheet(os.path.join(data_path, "user.css"))
-
-    return head
-
-
-def reload_javascript():
-    js = javascript_html()
-    css = css_html()
-
-    def template_response(*args, **kwargs):
-        res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
-        res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
-        res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
-        res.init_headers()
-        return res
-
-    gradio.routes.templates.TemplateResponse = template_response
-
-
-if not hasattr(shared, 'GradioTemplateResponseOriginal'):
-    shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
-
-
 def versions_html():
 def versions_html():
     import torch
     import torch
     import launch
     import launch

+ 23 - 0
modules/ui_common.py

@@ -10,8 +10,11 @@ import subprocess as sp
 from modules import call_queue, shared
 from modules import call_queue, shared
 from modules.generation_parameters_copypaste import image_from_url_text
 from modules.generation_parameters_copypaste import image_from_url_text
 import modules.images
 import modules.images
+from modules.ui_components import ToolButton
+
 
 
 folder_symbol = '\U0001f4c2'  # 📂
 folder_symbol = '\U0001f4c2'  # 📂
+refresh_symbol = '\U0001f504'  # 🔄
 
 
 
 
 def update_generation_info(generation_info, html_info, img_index):
 def update_generation_info(generation_info, html_info, img_index):
@@ -216,3 +219,23 @@ Requested path was: {f}
                 ))
                 ))
 
 
             return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
             return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+
+
+def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+    def refresh():
+        refresh_method()
+        args = refreshed_args() if callable(refreshed_args) else refreshed_args
+
+        for k, v in args.items():
+            setattr(refresh_component, k, v)
+
+        return gr.update(**(args or {}))
+
+    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+    refresh_button.click(
+        fn=refresh,
+        inputs=[],
+        outputs=[refresh_component]
+    )
+    return refresh_button
+

+ 69 - 0
modules/ui_gradio_extensions.py

@@ -0,0 +1,69 @@
+import os
+import gradio as gr
+
+from modules import localization, shared, scripts
+from modules.paths import script_path, data_path
+
+
+def webpath(fn):
+    if fn.startswith(script_path):
+        web_path = os.path.relpath(fn, script_path).replace('\\', '/')
+    else:
+        web_path = os.path.abspath(fn)
+
+    return f'file={web_path}?{os.path.getmtime(fn)}'
+
+
+def javascript_html():
+    # Ensure localization is in `window` before scripts
+    head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n'
+
+    script_js = os.path.join(script_path, "script.js")
+    head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n'
+
+    for script in scripts.list_scripts("javascript", ".js"):
+        head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n'
+
+    for script in scripts.list_scripts("javascript", ".mjs"):
+        head += f'<script type="module" src="{webpath(script.path)}"></script>\n'
+
+    if shared.cmd_opts.theme:
+        head += f'<script type="text/javascript">set_theme(\"{shared.cmd_opts.theme}\");</script>\n'
+
+    return head
+
+
+def css_html():
+    head = ""
+
+    def stylesheet(fn):
+        return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">'
+
+    for cssfile in scripts.list_files_with_name("style.css"):
+        if not os.path.isfile(cssfile):
+            continue
+
+        head += stylesheet(cssfile)
+
+    if os.path.exists(os.path.join(data_path, "user.css")):
+        head += stylesheet(os.path.join(data_path, "user.css"))
+
+    return head
+
+
+def reload_javascript():
+    js = javascript_html()
+    css = css_html()
+
+    def template_response(*args, **kwargs):
+        res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
+        res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8"))
+        res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8"))
+        res.init_headers()
+        return res
+
+    gr.routes.templates.TemplateResponse = template_response
+
+
+if not hasattr(shared, 'GradioTemplateResponseOriginal'):
+    shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse

+ 263 - 0
modules/ui_settings.py

@@ -0,0 +1,263 @@
+import gradio as gr
+
+from modules import ui_common, shared, script_callbacks, scripts, sd_models
+from modules.call_queue import wrap_gradio_call
+from modules.shared import opts
+from modules.ui_components import FormRow
+from modules.ui_gradio_extensions import reload_javascript
+
+
+def get_value_for_setting(key):
+    value = getattr(opts, key)
+
+    info = opts.data_labels[key]
+    args = info.component_args() if callable(info.component_args) else info.component_args or {}
+    args = {k: v for k, v in args.items() if k not in {'precision'}}
+
+    return gr.update(value=value, **args)
+
+
+def create_setting_component(key, is_quicksettings=False):
+    def fun():
+        return opts.data[key] if key in opts.data else opts.data_labels[key].default
+
+    info = opts.data_labels[key]
+    t = type(info.default)
+
+    args = info.component_args() if callable(info.component_args) else info.component_args
+
+    if info.component is not None:
+        comp = info.component
+    elif t == str:
+        comp = gr.Textbox
+    elif t == int:
+        comp = gr.Number
+    elif t == bool:
+        comp = gr.Checkbox
+    else:
+        raise Exception(f'bad options item type: {t} for key {key}')
+
+    elem_id = f"setting_{key}"
+
+    if info.refresh is not None:
+        if is_quicksettings:
+            res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+            ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
+        else:
+            with FormRow():
+                res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+                ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}")
+    else:
+        res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))
+
+    return res
+
+
+class UiSettings:
+    submit = None
+    result = None
+    interface = None
+    components = None
+    component_dict = None
+    dummy_component = None
+    quicksettings_list = None
+    quicksettings_names = None
+    text_settings = None
+
+    def run_settings(self, *args):
+        changed = []
+
+        for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
+            assert comp == self.dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
+
+        for key, value, comp in zip(opts.data_labels.keys(), args, self.components):
+            if comp == self.dummy_component:
+                continue
+
+            if opts.set(key, value):
+                changed.append(key)
+
+        try:
+            opts.save(shared.config_filename)
+        except RuntimeError:
+            return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
+        return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.'
+
+    def run_settings_single(self, value, key):
+        if not opts.same_type(value, opts.data_labels[key].default):
+            return gr.update(visible=True), opts.dumpjson()
+
+        if not opts.set(key, value):
+            return gr.update(value=getattr(opts, key)), opts.dumpjson()
+
+        opts.save(shared.config_filename)
+
+        return get_value_for_setting(key), opts.dumpjson()
+
+    def create_ui(self, loadsave, dummy_component):
+        self.components = []
+        self.component_dict = {}
+        self.dummy_component = dummy_component
+
+        shared.settings_components = self.component_dict
+
+        script_callbacks.ui_settings_callback()
+        opts.reorder()
+
+        with gr.Blocks(analytics_enabled=False) as settings_interface:
+            with gr.Row():
+                with gr.Column(scale=6):
+                    self.submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit")
+                with gr.Column():
+                    restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio")
+
+            self.result = gr.HTML(elem_id="settings_result")
+
+            self.quicksettings_names = opts.quicksettings_list
+            self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'}
+
+            self.quicksettings_list = []
+
+            previous_section = None
+            current_tab = None
+            current_row = None
+            with gr.Tabs(elem_id="settings"):
+                for i, (k, item) in enumerate(opts.data_labels.items()):
+                    section_must_be_skipped = item.section[0] is None
+
+                    if previous_section != item.section and not section_must_be_skipped:
+                        elem_id, text = item.section
+
+                        if current_tab is not None:
+                            current_row.__exit__()
+                            current_tab.__exit__()
+
+                        gr.Group()
+                        current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text)
+                        current_tab.__enter__()
+                        current_row = gr.Column(variant='compact')
+                        current_row.__enter__()
+
+                        previous_section = item.section
+
+                    if k in self.quicksettings_names and not shared.cmd_opts.freeze_settings:
+                        self.quicksettings_list.append((i, k, item))
+                        self.components.append(dummy_component)
+                    elif section_must_be_skipped:
+                        self.components.append(dummy_component)
+                    else:
+                        component = create_setting_component(k)
+                        self.component_dict[k] = component
+                        self.components.append(component)
+
+                if current_tab is not None:
+                    current_row.__exit__()
+                    current_tab.__exit__()
+
+                with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
+                    loadsave.create_ui()
+
+                with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
+                    request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
+                    download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
+                    reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
+                    with gr.Row():
+                        unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
+                        reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
+
+                with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"):
+                    gr.HTML(shared.html("licenses.html"), elem_id="licenses")
+
+                gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
+
+                self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
+
+            unload_sd_model.click(
+                fn=sd_models.unload_model_weights,
+                inputs=[],
+                outputs=[]
+            )
+
+            reload_sd_model.click(
+                fn=sd_models.reload_model_weights,
+                inputs=[],
+                outputs=[]
+            )
+
+            request_notifications.click(
+                fn=lambda: None,
+                inputs=[],
+                outputs=[],
+                _js='function(){}'
+            )
+
+            download_localization.click(
+                fn=lambda: None,
+                inputs=[],
+                outputs=[],
+                _js='download_localization'
+            )
+
+            def reload_scripts():
+                scripts.reload_script_body_only()
+                reload_javascript()  # need to refresh the html page
+
+            reload_script_bodies.click(
+                fn=reload_scripts,
+                inputs=[],
+                outputs=[]
+            )
+
+            restart_gradio.click(
+                fn=shared.state.request_restart,
+                _js='restart_reload',
+                inputs=[],
+                outputs=[],
+            )
+
+        self.interface = settings_interface
+
+    def add_quicksettings(self):
+        with gr.Row(elem_id="quicksettings", variant="compact"):
+            for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])):
+                component = create_setting_component(k, is_quicksettings=True)
+                self.component_dict[k] = component
+
+    def add_functionality(self, demo):
+        self.submit.click(
+            fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]),
+            inputs=self.components,
+            outputs=[self.text_settings, self.result],
+        )
+
+        for _i, k, _item in self.quicksettings_list:
+            component = self.component_dict[k]
+            info = opts.data_labels[k]
+
+            change_handler = component.release if hasattr(component, 'release') else component.change
+            change_handler(
+                fn=lambda value, k=k: self.run_settings_single(value, key=k),
+                inputs=[component],
+                outputs=[component, self.text_settings],
+                show_progress=info.refresh is not None,
+            )
+
+        button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
+        button_set_checkpoint.click(
+            fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'),
+            _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
+            inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component],
+            outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings],
+        )
+
+        component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]
+
+        def get_settings_values():
+            return [get_value_for_setting(key) for key in component_keys]
+
+        demo.load(
+            fn=get_settings_values,
+            inputs=[],
+            outputs=[self.component_dict[k] for k in component_keys],
+            queue=False,
+        )