Browse Source

added a button to run hires fix on selected image in the gallery

AUTOMATIC1111 1 year ago
parent
commit
501993eb
6 changed files with 160 additions and 86 deletions
  1. 8 0
      javascript/ui.js
  2. 37 9
      modules/processing.py
  3. 17 2
      modules/txt2img.py
  4. 59 49
      modules/ui.py
  5. 35 22
      modules/ui_common.py
  6. 4 4
      modules/ui_postprocessing.py

+ 8 - 0
javascript/ui.js

@@ -150,6 +150,14 @@ function submit() {
     return res;
 }
 
+function submit_txt2img_upscale() {
+    res = submit.apply(null, arguments);
+
+    res[2] = selected_gallery_index();
+
+    return res;
+}
+
 function submit_img2img() {
     showSubmitButtons('img2img', false);
 

+ 37 - 9
modules/processing.py

@@ -179,6 +179,7 @@ class StableDiffusionProcessing:
     token_merging_ratio = 0
     token_merging_ratio_hr = 0
     disable_extra_networks: bool = False
+    firstpass_image: Image = None
 
     scripts_value: scripts.ScriptRunner = field(default=None, init=False)
     script_args_value: list = field(default=None, init=False)
@@ -1238,18 +1239,45 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
 
-        x = self.rng.next()
-        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
-        del x
+        if self.firstpass_image is not None and self.enable_hr:
+            # here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix
 
-        if not self.enable_hr:
-            return samples
-        devices.torch_gc()
+            if self.latent_scale_mode is None:
+                image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0
+                image = np.moveaxis(image, 2, 0)
+
+                samples = None
+                decoded_samples = torch.asarray(np.expand_dims(image, 0))
+
+            else:
+                image = np.array(self.firstpass_image).astype(np.float32) / 255.0
+                image = np.moveaxis(image, 2, 0)
+                image = torch.from_numpy(np.expand_dims(image, axis=0))
+                image = image.to(shared.device, dtype=devices.dtype_vae)
+
+                if opts.sd_vae_encode_method != 'Full':
+                    self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+
+                samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
+                decoded_samples = None
+                devices.torch_gc()
 
-        if self.latent_scale_mode is None:
-            decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
         else:
-            decoded_samples = None
+            # here we generate an image normally
+
+            x = self.rng.next()
+            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+            del x
+
+            if not self.enable_hr:
+                return samples
+
+            devices.torch_gc()
+
+            if self.latent_scale_mode is None:
+                decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+            else:
+                decoded_samples = None
 
         with sd_models.SkipWritingToConfig():
             sd_models.reload_model_weights(info=self.hr_checkpoint_info)

+ 17 - 2
modules/txt2img.py

@@ -1,7 +1,7 @@
 from contextlib import closing
 
 import modules.scripts
-from modules import processing
+from modules import processing, infotext_utils
 from modules.infotext_utils import create_override_settings_dict
 from modules.shared import opts
 import modules.shared as shared
@@ -9,9 +9,23 @@ from modules.ui import plaintext_to_html
 import gradio as gr
 
 
-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
+def txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, *args):
+    assert len(gallery) > 0, 'No image to upscale'
+
+    image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0]
+    image = infotext_utils.image_from_url_text(image_info)
+
+    return txt2img(id_task, request, *args, firstpass_image=image)
+
+
+def txt2img(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, firstpass_image=None):
     override_settings = create_override_settings_dict(override_settings_texts)
 
+    if firstpass_image is not None:
+        enable_hr = True
+        batch_size = 1
+        n_iter = 1
+
     p = processing.StableDiffusionProcessingTxt2Img(
         sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
@@ -38,6 +52,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
         hr_prompt=hr_prompt,
         hr_negative_prompt=hr_negative_prompt,
         override_settings=override_settings,
+        firstpass_image=firstpass_image,
     )
 
     p.scripts = modules.scripts.scripts_txt2img

+ 59 - 49
modules/ui.py

@@ -375,50 +375,60 @@ def create_ui():
                     show_progress=False,
                 )
 
-            txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
+            output_panel = create_output_panel("txt2img", opts.outdir_txt2img_samples, toprow)
+
+            txt2img_inputs = [
+                dummy_component,
+                toprow.prompt,
+                toprow.negative_prompt,
+                toprow.ui_styles.dropdown,
+                steps,
+                sampler_name,
+                batch_count,
+                batch_size,
+                cfg_scale,
+                height,
+                width,
+                enable_hr,
+                denoising_strength,
+                hr_scale,
+                hr_upscaler,
+                hr_second_pass_steps,
+                hr_resize_x,
+                hr_resize_y,
+                hr_checkpoint_name,
+                hr_sampler_name,
+                hr_prompt,
+                hr_negative_prompt,
+                override_settings,
+            ] + custom_inputs
+
+            txt2img_outputs = [
+                output_panel.gallery,
+                output_panel.infotext,
+                output_panel.html_info,
+                output_panel.html_log,
+            ]
 
             txt2img_args = dict(
                 fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
                 _js="submit",
-                inputs=[
-                    dummy_component,
-                    toprow.prompt,
-                    toprow.negative_prompt,
-                    toprow.ui_styles.dropdown,
-                    steps,
-                    sampler_name,
-                    batch_count,
-                    batch_size,
-                    cfg_scale,
-                    height,
-                    width,
-                    enable_hr,
-                    denoising_strength,
-                    hr_scale,
-                    hr_upscaler,
-                    hr_second_pass_steps,
-                    hr_resize_x,
-                    hr_resize_y,
-                    hr_checkpoint_name,
-                    hr_sampler_name,
-                    hr_prompt,
-                    hr_negative_prompt,
-                    override_settings,
-
-                ] + custom_inputs,
-
-                outputs=[
-                    txt2img_gallery,
-                    generation_info,
-                    html_info,
-                    html_log,
-                ],
+                inputs=txt2img_inputs,
+                outputs=txt2img_outputs,
                 show_progress=False,
             )
 
             toprow.prompt.submit(**txt2img_args)
             toprow.submit.click(**txt2img_args)
 
+            output_panel.button_upscale.click(
+                fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']),
+                _js="submit_txt2img_upscale",
+                inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component] + txt2img_inputs[1:],
+                outputs=txt2img_outputs,
+                show_progress=False,
+            )
+
             res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
 
             toprow.restore_progress_button.click(
@@ -426,10 +436,10 @@ def create_ui():
                 _js="restoreProgressTxt2img",
                 inputs=[dummy_component],
                 outputs=[
-                    txt2img_gallery,
-                    generation_info,
-                    html_info,
-                    html_log,
+                    output_panel.gallery,
+                    output_panel.infotext,
+                    output_panel.html_info,
+                    output_panel.html_log,
                 ],
                 show_progress=False,
             )
@@ -479,7 +489,7 @@ def create_ui():
             toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
 
         extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
-        ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
+        ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery)
 
         extra_tabs.__exit__()
 
@@ -710,7 +720,7 @@ def create_ui():
                     outputs=[inpaint_controls, mask_alpha],
                 )
 
-            img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
+            output_panel = create_output_panel("img2img", opts.outdir_img2img_samples, toprow)
 
             img2img_args = dict(
                 fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
@@ -755,10 +765,10 @@ def create_ui():
                     img2img_batch_png_info_dir,
                 ] + custom_inputs,
                 outputs=[
-                    img2img_gallery,
-                    generation_info,
-                    html_info,
-                    html_log,
+                    output_panel.gallery,
+                    output_panel.infotext,
+                    output_panel.html_info,
+                    output_panel.html_log,
                 ],
                 show_progress=False,
             )
@@ -796,10 +806,10 @@ def create_ui():
                 _js="restoreProgressImg2img",
                 inputs=[dummy_component],
                 outputs=[
-                    img2img_gallery,
-                    generation_info,
-                    html_info,
-                    html_log,
+                    output_panel.gallery,
+                    output_panel.infotext,
+                    output_panel.html_info,
+                    output_panel.html_log,
                 ],
                 show_progress=False,
             )
@@ -839,7 +849,7 @@ def create_ui():
             ))
 
         extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
-        ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
+        ui_extra_networks.setup_ui(extra_networks_ui_img2img, output_panel.gallery)
 
         extra_tabs.__exit__()
 

+ 35 - 22
modules/ui_common.py

@@ -1,3 +1,4 @@
+import dataclasses
 import json
 import html
 import os
@@ -104,7 +105,17 @@ def save_files(js_data, images, do_make_zip, index):
     return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
 
 
+@dataclasses.dataclass
+class OutputPanel:
+    gallery = None
+    infotext = None
+    html_info = None
+    html_log = None
+    button_upscale = None
+
+
 def create_output_panel(tabname, outdir, toprow=None):
+    res = OutputPanel()
 
     def open_folder(f):
         if not os.path.exists(f):
@@ -136,9 +147,8 @@ Requested path was: {f}
 
         with gr.Column(variant='panel', elem_id=f"{tabname}_results_panel"):
             with gr.Group(elem_id=f"{tabname}_gallery_container"):
-                result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
+                res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4, preview=True, height=shared.opts.gallery_height or None)
 
-            generation_info = None
             with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"):
                 open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip="Open images output directory.")
 
@@ -152,6 +162,9 @@ Requested path was: {f}
                     'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.")
                 }
 
+                if tabname == 'txt2img':
+                    res.button_upscale = ToolButton('✨', elem_id=f'{tabname}_upscale', tooltip="Create an upscaled version of the current image using hires fix settings.")
+
             open_folder_button.click(
                 fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
                 inputs=[],
@@ -162,17 +175,17 @@ Requested path was: {f}
                 download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
 
                 with gr.Group():
-                    html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
-                    html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
+                    res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
+                    res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes="html-log")
 
-                    generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
+                    res.infotext = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
                     if tabname == 'txt2img' or tabname == 'img2img':
                         generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
                         generation_info_button.click(
                             fn=update_generation_info,
                             _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
-                            inputs=[generation_info, html_info, html_info],
-                            outputs=[html_info, html_info],
+                            inputs=[res.infotext, res.html_info, res.html_info],
+                            outputs=[res.html_info, res.html_info],
                             show_progress=False,
                         )
 
@@ -180,14 +193,14 @@ Requested path was: {f}
                         fn=call_queue.wrap_gradio_call(save_files),
                         _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
                         inputs=[
-                            generation_info,
-                            result_gallery,
-                            html_info,
-                            html_info,
+                            res.infotext,
+                            res.gallery,
+                            res.html_info,
+                            res.html_info,
                         ],
                         outputs=[
                             download_files,
-                            html_log,
+                            res.html_log,
                         ],
                         show_progress=False,
                     )
@@ -196,21 +209,21 @@ Requested path was: {f}
                         fn=call_queue.wrap_gradio_call(save_files),
                         _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
                         inputs=[
-                            generation_info,
-                            result_gallery,
-                            html_info,
-                            html_info,
+                            res.infotext,
+                            res.gallery,
+                            res.html_info,
+                            res.html_info,
                         ],
                         outputs=[
                             download_files,
-                            html_log,
+                            res.html_log,
                         ]
                     )
 
             else:
-                html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
-                html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
-                html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+                res.infotext = gr.HTML(elem_id=f'html_info_x_{tabname}')
+                res.html_info = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes="infotext")
+                res.html_log = gr.HTML(elem_id=f'html_log_{tabname}')
 
             paste_field_names = []
             if tabname == "txt2img":
@@ -220,11 +233,11 @@ Requested path was: {f}
 
             for paste_tabname, paste_button in buttons.items():
                 parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
-                    paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery,
+                    paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=res.gallery,
                     paste_field_names=paste_field_names
                 ))
 
-            return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+    return res
 
 
 def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):

+ 4 - 4
modules/ui_postprocessing.py

@@ -28,7 +28,7 @@ def create_ui():
             toprow.create_inline_toprow_image()
             submit = toprow.submit
 
-            result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
+            output_panel = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
 
     tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
     tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
@@ -48,9 +48,9 @@ def create_ui():
             *script_inputs
         ],
         outputs=[
-            result_images,
-            html_info_x,
-            html_log,
+            output_panel.gallery,
+            output_panel.infotext,
+            output_panel.html_log,
         ],
         show_progress=False,
     )