
remove Train/Preprocessing tab and put all its functionality into extras batch images mode

AUTOMATIC1111 · 1 year ago · commit 11d23e8ca5

+ 17 - 0
javascript/ui.js

@@ -170,6 +170,23 @@ function submit_img2img() {
     return res;
 }
 
+function submit_extras() {
+    showSubmitButtons('extras', false);
+
+    var id = randomId();
+
+    requestProgress(id, gradioApp().getElementById('extras_gallery_container'), gradioApp().getElementById('extras_gallery'), function() {
+        showSubmitButtons('extras', true);
+    });
+
+    var res = create_submit_args(arguments);
+
+    res[0] = id;
+
+    console.log(res);
+    return res;
+}
+
 function restoreProgressTxt2img() {
     showRestoreProgressButton("txt2img", false);
     var id = localGet("txt2img_task_id");

+ 0 - 15
modules/api/api.py

@@ -22,7 +22,6 @@ from modules.api import models
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
-from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
 from PIL import PngImagePlugin, Image
 from modules.sd_models_config import find_checkpoint_config_near_filename
@@ -235,7 +234,6 @@ class Api:
         self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
         self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
         self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
-        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
         self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
         self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
         self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
@@ -675,19 +673,6 @@ class Api:
         finally:
             shared.state.end()
 
-    def preprocess(self, args: dict):
-        try:
-            shared.state.begin(job="preprocess")
-            preprocess(**args) # quick operation unless blip/booru interrogation is enabled
-            shared.state.end()
-            return models.PreprocessResponse(info='preprocess complete')
-        except KeyError as e:
-            return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
-        except Exception as e:
-            return models.PreprocessResponse(info=f"preprocess error: {e}")
-        finally:
-            shared.state.end()
-
     def train_embedding(self, args: dict):
         try:
             shared.state.begin(job="train_embedding")

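Note: with the `/sdapi/v1/preprocess` route and its handler gone, API clients that posted to it will now get a 404. The cropping, flipping and captioning steps live in the extras/postprocessing pipeline instead, reachable through the pre-existing extras endpoints. A rough client-side sketch, assuming the `/sdapi/v1/extra-batch-images` route and its `imageList`/`images` fields stay as defined in modules/api/api.py and modules/api/models.py (nothing in this commit adds or changes them):

```python
import base64
import requests

# Hedged sketch only: endpoint and field names come from the existing extras API
# models and are assumptions here, not additions made by this commit.
with open("inputs/0001.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode()

payload = {
    "imageList": [{"data": encoded, "name": "0001.png"}],  # batch of images to postprocess
    "upscaling_resize": 2,                                  # resize factor for the upscale script
    "upscaler_1": "Lanczos",                                # any installed upscaler name
}

r = requests.post("http://127.0.0.1:7860/sdapi/v1/extra-batch-images", json=payload)
r.raise_for_status()
print(len(r.json()["images"]), "processed images returned")
```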
+ 0 - 3
modules/api/models.py

@@ -202,9 +202,6 @@ class TrainResponse(BaseModel):
 class CreateResponse(BaseModel):
     info: str = Field(title="Create info", description="Response string from create embedding or hypernetwork task.")
 
-class PreprocessResponse(BaseModel):
-    info: str = Field(title="Preprocess info", description="Response string from preprocessing task.")
-
 fields = {}
 for key, metadata in opts.data_labels.items():
     value = opts.data.get(key)

+ 69 - 23
modules/postprocessing.py

@@ -6,7 +6,7 @@ from modules import shared, images, devices, scripts, scripts_postprocessing, ui
 from modules.shared import opts
 
 
-def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
+def run_postprocessing(id_task, extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
     devices.torch_gc()
 
     shared.state.begin(job="extras")
@@ -29,11 +29,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
 
             image_list = shared.listfiles(input_dir)
             for filename in image_list:
-                try:
-                    image = Image.open(filename)
-                except Exception:
-                    continue
-                yield image, filename
+                yield filename, filename
         else:
             assert image, 'image not selected'
             yield image, None
@@ -45,37 +41,85 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
 
     infotext = ''
 
-    for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
+    data_to_process = list(get_images(extras_mode, image, image_folder, input_dir))
+    shared.state.job_count = len(data_to_process)
+
+    for image_placeholder, name in data_to_process:
         image_data: Image.Image
 
+        shared.state.nextjob()
         shared.state.textinfo = name
+        shared.state.skipped = False
+
+        if shared.state.interrupted:
+            break
+
+        if isinstance(image_placeholder, str):
+            try:
+                image_data = Image.open(image_placeholder)
+            except Exception:
+                continue
+        else:
+            image_data = image_placeholder
+
+        shared.state.assign_current_image(image_data)
 
         parameters, existing_pnginfo = images.read_info_from_image(image_data)
         if parameters:
             existing_pnginfo["parameters"] = parameters

-        pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))
+        initial_pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))

-        scripts.scripts_postproc.run(pp, args)
+        scripts.scripts_postproc.run(initial_pp, args)
 
-        if opts.use_original_name_batch and name is not None:
-            basename = os.path.splitext(os.path.basename(name))[0]
-            forced_filename = basename
-        else:
-            basename = ''
-            forced_filename = None
+        if shared.state.skipped:
+            continue
+
+        used_suffixes = {}
+        for pp in [initial_pp, *initial_pp.extra_images]:
+            suffix = pp.get_suffix(used_suffixes)
+
+            if opts.use_original_name_batch and name is not None:
+                basename = os.path.splitext(os.path.basename(name))[0]
+                forced_filename = basename + suffix
+            else:
+                basename = ''
+                forced_filename = None
+
+            infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
+
+            if opts.enable_pnginfo:
+                pp.image.info = existing_pnginfo
+                pp.image.info["postprocessing"] = infotext
+
+            if save_output:
+                fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)
 
-        infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
+                if pp.caption:
+                    caption_filename = os.path.splitext(fullfn)[0] + ".txt"
+                    if os.path.isfile(caption_filename):
+                        with open(caption_filename, encoding="utf8") as file:
+                            existing_caption = file.read().strip()
+                    else:
+                        existing_caption = ""
 
-        if opts.enable_pnginfo:
-            pp.image.info = existing_pnginfo
-            pp.image.info["postprocessing"] = infotext
+                    action = shared.opts.postprocessing_existing_caption_action
+                    if action == 'Prepend' and existing_caption:
+                        caption = f"{existing_caption} {pp.caption}"
+                    elif action == 'Append' and existing_caption:
+                        caption = f"{pp.caption} {existing_caption}"
+                    elif action == 'Keep' and existing_caption:
+                        caption = existing_caption
+                    else:
+                        caption = pp.caption
 
-        if save_output:
-            images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=forced_filename)
+                    caption = caption.strip()
+                    if caption:
+                        with open(caption_filename, "w", encoding="utf8") as file:
+                            file.write(caption)
 
-        if extras_mode != 2 or show_extras_results:
-            outputs.append(pp.image)
+            if extras_mode != 2 or show_extras_results:
+                outputs.append(pp.image)
 
         image_data.close()
 
@@ -99,9 +143,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
             "upscaler_2_visibility": extras_upscaler_2_visibility,
         },
         "GFPGAN": {
+            "enable": True,
             "gfpgan_visibility": gfpgan_visibility,
         },
         "CodeFormer": {
+            "enable": True,
             "codeformer_visibility": codeformer_visibility,
             "codeformer_weight": codeformer_weight,
         },

+ 81 - 5
modules/scripts_postprocessing.py

@@ -1,13 +1,56 @@
+import dataclasses
 import os
 import gradio as gr

 from modules import errors, shared
 
 
+@dataclasses.dataclass
+class PostprocessedImageSharedInfo:
+    target_width: int = None
+    target_height: int = None
+
+
 class PostprocessedImage:
     def __init__(self, image):
         self.image = image
         self.info = {}
+        self.shared = PostprocessedImageSharedInfo()
+        self.extra_images = []
+        self.nametags = []
+        self.disable_processing = False
+        self.caption = None
+
+    def get_suffix(self, used_suffixes=None):
+        used_suffixes = {} if used_suffixes is None else used_suffixes
+        suffix = "-".join(self.nametags)
+        if suffix:
+            suffix = "-" + suffix
+
+        if suffix not in used_suffixes:
+            used_suffixes[suffix] = 1
+            return suffix
+
+        for i in range(1, 100):
+            proposed_suffix = suffix + "-" + str(i)
+
+            if proposed_suffix not in used_suffixes:
+                used_suffixes[proposed_suffix] = 1
+                return proposed_suffix
+
+        return suffix
+
+    def create_copy(self, new_image, *, nametags=None, disable_processing=False):
+        pp = PostprocessedImage(new_image)
+        pp.shared = self.shared
+        pp.nametags = self.nametags.copy()
+        pp.info = self.info.copy()
+        pp.disable_processing = disable_processing
+
+        if nametags is not None:
+            pp.nametags += nametags
+
+        return pp
 
 
 class ScriptPostprocessing:
@@ -42,10 +85,17 @@ class ScriptPostprocessing:
 
         pass
 
-    def image_changed(self):
-        pass
+    def process_firstpass(self, pp: PostprocessedImage, **args):
+        """
+        Called for all scripts before calling process(). Scripts can examine the image here and set fields
+        of the pp object to communicate things to other scripts.
+        args contains a dictionary with all values returned by components from ui()
+        """
 
+        pass
 
+    def image_changed(self):
+        pass
 
 
 def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
@@ -118,16 +168,42 @@ class ScriptPostprocessingRunner:
         return inputs
 
     def run(self, pp: PostprocessedImage, args):
-        for script in self.scripts_in_preferred_order():
-            shared.state.job = script.name
+        scripts = []
 
+        for script in self.scripts_in_preferred_order():
             script_args = args[script.args_from:script.args_to]

             process_args = {}
             for (name, _component), value in zip(script.controls.items(), script_args):
                 process_args[name] = value
 
-            script.process(pp, **process_args)
+            scripts.append((script, process_args))
+
+        for script, process_args in scripts:
+            script.process_firstpass(pp, **process_args)
+
+        all_images = [pp]
+
+        for script, process_args in scripts:
+            if shared.state.skipped:
+                break
+
+            shared.state.job = script.name
+
+            for single_image in all_images.copy():
+
+                if not single_image.disable_processing:
+                    script.process(single_image, **process_args)
+
+                for extra_image in single_image.extra_images:
+                    if not isinstance(extra_image, PostprocessedImage):
+                        extra_image = single_image.create_copy(extra_image)
+
+                    all_images.append(extra_image)
+
+                single_image.extra_images.clear()
+
+        pp.extra_images = all_images[1:]
 
     def create_args_for_run(self, scripts_args):
         if not self.ui_created:

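The expanded PostprocessedImage (extra_images, nametags, caption, create_copy, the shared info object) plus the new process_firstpass() hook are what let the old preprocessing steps run as ordinary extras scripts. A minimal sketch of a hypothetical third-party script written against this API, modeled on the scripts added later in this commit (the file, class and option names are illustrative, not part of the commit):

```python
from PIL import ImageOps

from modules import scripts_postprocessing, ui_components
import gradio as gr


class ScriptPostprocessingGrayscaleCopy(scripts_postprocessing.ScriptPostprocessing):
    # hypothetical example, not part of this commit
    name = "Grayscale copy"
    order = 4500

    def ui(self):
        with ui_components.InputAccordion(False, label="Grayscale copy") as enable:
            keep_original = gr.Checkbox(label="Keep original image", value=True)

        return {
            "enable": enable,
            "keep_original": keep_original,
        }

    def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, enable, keep_original):
        # the first pass runs for every script before any process() call;
        # a script could publish sizes etc. through pp.shared here
        pass

    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, keep_original):
        if not enable:
            return

        gray = ImageOps.grayscale(pp.image).convert("RGB")

        if not keep_original:
            pp.image = gray
            return

        # extra images are picked up by ScriptPostprocessingRunner.run(); the nametag
        # becomes a "-grayscale" filename suffix via get_suffix(), and
        # disable_processing=True keeps later scripts from touching the copy
        pp.extra_images.append(pp.create_copy(gray, nametags=["grayscale"], disable_processing=True))
```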
+ 1 - 0
modules/shared_options.py

@@ -357,6 +357,7 @@ options_templates.update(options_section(('postprocessing', "Postprocessing", "p
     'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
     'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}),
     'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+    'postprocessing_existing_caption_action': OptionInfo("Ignore", "Action for existing captions", gr.Radio, {"choices": ["Ignore", "Keep", "Prepend", "Append"]}).info("when generating captions using postprocessing; Ignore = use generated; Keep = use original; Prepend/Append = combine both"),
 }))
 
 options_templates.update(options_section((None, "Hidden options"), {

+ 0 - 232
modules/textual_inversion/preprocess.py

@@ -1,232 +0,0 @@
-import os
-from PIL import Image, ImageOps
-import math
-import tqdm
-
-from modules import shared, images, deepbooru
-from modules.textual_inversion import autocrop
-
-
-def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.15, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
-    try:
-        if process_caption:
-            shared.interrogator.load()
-
-        if process_caption_deepbooru:
-            deepbooru.model.start()
-
-        preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
-
-    finally:
-
-        if process_caption:
-            shared.interrogator.send_blip_to_ram()
-
-        if process_caption_deepbooru:
-            deepbooru.model.stop()
-
-
-def listfiles(dirname):
-    return os.listdir(dirname)
-
-
-class PreprocessParams:
-    src = None
-    dstdir = None
-    subindex = 0
-    flip = False
-    process_caption = False
-    process_caption_deepbooru = False
-    preprocess_txt_action = None
-
-
-def save_pic_with_caption(image, index, params: PreprocessParams, existing_caption=None):
-    caption = ""
-
-    if params.process_caption:
-        caption += shared.interrogator.generate_caption(image)
-
-    if params.process_caption_deepbooru:
-        if caption:
-            caption += ", "
-        caption += deepbooru.model.tag_multi(image)
-
-    filename_part = params.src
-    filename_part = os.path.splitext(filename_part)[0]
-    filename_part = os.path.basename(filename_part)
-
-    basename = f"{index:05}-{params.subindex}-{filename_part}"
-    image.save(os.path.join(params.dstdir, f"{basename}.png"))
-
-    if params.preprocess_txt_action == 'prepend' and existing_caption:
-        caption = f"{existing_caption} {caption}"
-    elif params.preprocess_txt_action == 'append' and existing_caption:
-        caption = f"{caption} {existing_caption}"
-    elif params.preprocess_txt_action == 'copy' and existing_caption:
-        caption = existing_caption
-
-    caption = caption.strip()
-
-    if caption:
-        with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file:
-            file.write(caption)
-
-    params.subindex += 1
-
-
-def save_pic(image, index, params, existing_caption=None):
-    save_pic_with_caption(image, index, params, existing_caption=existing_caption)
-
-    if params.flip:
-        save_pic_with_caption(ImageOps.mirror(image), index, params, existing_caption=existing_caption)
-
-
-def split_pic(image, inverse_xy, width, height, overlap_ratio):
-    if inverse_xy:
-        from_w, from_h = image.height, image.width
-        to_w, to_h = height, width
-    else:
-        from_w, from_h = image.width, image.height
-        to_w, to_h = width, height
-    h = from_h * to_w // from_w
-    if inverse_xy:
-        image = image.resize((h, to_w))
-    else:
-        image = image.resize((to_w, h))
-
-    split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
-    y_step = (h - to_h) / (split_count - 1)
-    for i in range(split_count):
-        y = int(y_step * i)
-        if inverse_xy:
-            splitted = image.crop((y, 0, y + to_h, to_w))
-        else:
-            splitted = image.crop((0, y, to_w, y + to_h))
-        yield splitted
-
-# not using torchvision.transforms.CenterCrop because it doesn't allow float regions
-def center_crop(image: Image, w: int, h: int):
-    iw, ih = image.size
-    if ih / h < iw / w:
-        sw = w * ih / h
-        box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
-    else:
-        sh = h * iw / w
-        box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
-    return image.resize((w, h), Image.Resampling.LANCZOS, box)
-
-
-def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
-    iw, ih = image.size
-    err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h))
-    wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64)
-        if minarea <= w * h <= maxarea and err(w, h) <= threshold),
-        key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1],
-        default=None
-    )
-    return wh and center_crop(image, *wh)
-
-
-def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
-    width = process_width
-    height = process_height
-    src = os.path.abspath(process_src)
-    dst = os.path.abspath(process_dst)
-    split_threshold = max(0.0, min(1.0, split_threshold))
-    overlap_ratio = max(0.0, min(0.9, overlap_ratio))
-
-    assert src != dst, 'same directory specified as source and destination'
-
-    os.makedirs(dst, exist_ok=True)
-
-    files = listfiles(src)
-
-    shared.state.job = "preprocess"
-    shared.state.textinfo = "Preprocessing..."
-    shared.state.job_count = len(files)
-
-    params = PreprocessParams()
-    params.dstdir = dst
-    params.flip = process_flip
-    params.process_caption = process_caption
-    params.process_caption_deepbooru = process_caption_deepbooru
-    params.preprocess_txt_action = preprocess_txt_action
-
-    pbar = tqdm.tqdm(files)
-    for index, imagefile in enumerate(pbar):
-        params.subindex = 0
-        filename = os.path.join(src, imagefile)
-        try:
-            img = Image.open(filename)
-            img = ImageOps.exif_transpose(img)
-            img = img.convert("RGB")
-        except Exception:
-            continue
-
-        description = f"Preprocessing [Image {index}/{len(files)}]"
-        pbar.set_description(description)
-        shared.state.textinfo = description
-
-        params.src = filename
-
-        existing_caption = None
-        existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt"
-        if os.path.exists(existing_caption_filename):
-            with open(existing_caption_filename, 'r', encoding="utf8") as file:
-                existing_caption = file.read()
-
-        if shared.state.interrupted:
-            break
-
-        if img.height > img.width:
-            ratio = (img.width * height) / (img.height * width)
-            inverse_xy = False
-        else:
-            ratio = (img.height * width) / (img.width * height)
-            inverse_xy = True
-
-        process_default_resize = True
-
-        if process_split and ratio < 1.0 and ratio <= split_threshold:
-            for splitted in split_pic(img, inverse_xy, width, height, overlap_ratio):
-                save_pic(splitted, index, params, existing_caption=existing_caption)
-            process_default_resize = False
-
-        if process_focal_crop and img.height != img.width:
-
-            dnn_model_path = None
-            try:
-                dnn_model_path = autocrop.download_and_cache_models()
-            except Exception as e:
-                print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
-
-            autocrop_settings = autocrop.Settings(
-                crop_width = width,
-                crop_height = height,
-                face_points_weight = process_focal_crop_face_weight,
-                entropy_points_weight = process_focal_crop_entropy_weight,
-                corner_points_weight = process_focal_crop_edges_weight,
-                annotate_image = process_focal_crop_debug,
-                dnn_model_path = dnn_model_path,
-            )
-            for focal in autocrop.crop_image(img, autocrop_settings):
-                save_pic(focal, index, params, existing_caption=existing_caption)
-            process_default_resize = False
-
-        if process_multicrop:
-            cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
-            if cropped is not None:
-                save_pic(cropped, index, params, existing_caption=existing_caption)
-            else:
-                print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
-            process_default_resize = False
-
-        if process_keep_original_size:
-            save_pic(img, index, params, existing_caption=existing_caption)
-            process_default_resize = False
-
-        if process_default_resize:
-            img = images.resize_image(1, img, width, height)
-            save_pic(img, index, params, existing_caption=existing_caption)
-
-        shared.state.nextjob()

+ 0 - 7
modules/textual_inversion/ui.py

@@ -3,7 +3,6 @@ import html
 import gradio as gr
 
 import modules.textual_inversion.textual_inversion
-import modules.textual_inversion.preprocess
 from modules import sd_hijack, shared
 
 
@@ -15,12 +14,6 @@ def create_embedding(name, initialization_text, nvpt, overwrite_old):
     return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
 
 
-def preprocess(*args):
-    modules.textual_inversion.preprocess.preprocess(*args)
-
-    return f"Preprocessing {'interrupted' if shared.state.interrupted else 'finished'}.", ""
-
-
 def train_embedding(*args):
 
     assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'

+ 0 - 107
modules/ui.py

@@ -912,71 +912,6 @@ def create_ui():
                         with gr.Column():
                             create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
 
-                with gr.Tab(label="Preprocess images", id="preprocess_images"):
-                    process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
-                    process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
-                    process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
-                    process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height")
-                    preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
-
-                    with gr.Row():
-                        process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size")
-                        process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
-                        process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
-                        process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
-                        process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop")
-                        process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption")
-                        process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru")
-
-                    with gr.Row(visible=False) as process_split_extra_row:
-                        process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold")
-                        process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio")
-
-                    with gr.Row(visible=False) as process_focal_crop_row:
-                        process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight")
-                        process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
-                        process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
-                        process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
-
-                    with gr.Column(visible=False) as process_multicrop_col:
-                        gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
-                        with gr.Row():
-                            process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim")
-                            process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim")
-                        with gr.Row():
-                            process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea")
-                            process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea")
-                        with gr.Row():
-                            process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
-                            process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
-
-                    with gr.Row():
-                        with gr.Column(scale=3):
-                            gr.HTML(value="")
-
-                        with gr.Column():
-                            with gr.Row():
-                                interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing")
-                            run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess")
-
-                    process_split.change(
-                        fn=lambda show: gr_show(show),
-                        inputs=[process_split],
-                        outputs=[process_split_extra_row],
-                    )
-
-                    process_focal_crop.change(
-                        fn=lambda show: gr_show(show),
-                        inputs=[process_focal_crop],
-                        outputs=[process_focal_crop_row],
-                    )
-
-                    process_multicrop.change(
-                        fn=lambda show: gr_show(show),
-                        inputs=[process_multicrop],
-                        outputs=[process_multicrop_col],
-                    )
-
                 def get_textual_inversion_template_names():
                     return sorted(textual_inversion.textual_inversion_templates)
 
@@ -1077,42 +1012,6 @@ def create_ui():
             ]
         )
 
-        run_preprocess.click(
-            fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]),
-            _js="start_training_textual_inversion",
-            inputs=[
-                dummy_component,
-                process_src,
-                process_dst,
-                process_width,
-                process_height,
-                preprocess_txt_action,
-                process_keep_original_size,
-                process_flip,
-                process_split,
-                process_caption,
-                process_caption_deepbooru,
-                process_split_threshold,
-                process_overlap_ratio,
-                process_focal_crop,
-                process_focal_crop_face_weight,
-                process_focal_crop_entropy_weight,
-                process_focal_crop_edges_weight,
-                process_focal_crop_debug,
-                process_multicrop,
-                process_multicrop_mindim,
-                process_multicrop_maxdim,
-                process_multicrop_minarea,
-                process_multicrop_maxarea,
-                process_multicrop_objective,
-                process_multicrop_threshold,
-            ],
-            outputs=[
-                ti_output,
-                ti_outcome,
-            ],
-        )
-
         train_embedding.click(
             fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
             _js="start_training_textual_inversion",
@@ -1186,12 +1085,6 @@ def create_ui():
             outputs=[],
         )
 
-        interrupt_preprocessing.click(
-            fn=lambda: shared.state.interrupt(),
-            inputs=[],
-            outputs=[],
-        )
-
     loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)
 
     settings = ui_settings.UiSettings()

+ 11 - 5
modules/ui_postprocessing.py

@@ -1,9 +1,10 @@
 import gradio as gr
-from modules import scripts, shared, ui_common, postprocessing, call_queue
+from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow
 import modules.generation_parameters_copypaste as parameters_copypaste
 
 
 def create_ui():
+    dummy_component = gr.Label(visible=False)
     tab_index = gr.State(value=0)
 
     with gr.Row(equal_height=False, variant='compact'):
@@ -20,11 +21,13 @@ def create_ui():
                     extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
                     show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
 
-            submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
-
             script_inputs = scripts.scripts_postproc.setup_ui()
 
         with gr.Column():
+            toprow = ui_toprow.Toprow(is_compact=True, is_img2img=False, id_part="extras")
+            toprow.create_inline_toprow_image()
+            submit = toprow.submit
+
             result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
 
     tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
@@ -33,7 +36,9 @@ def create_ui():
 
     submit.click(
         fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
+        _js="submit_extras",
         inputs=[
+            dummy_component,
             tab_index,
             extras_image,
             image_batch,
@@ -45,8 +50,9 @@ def create_ui():
         outputs=[
             result_images,
             html_info_x,
-            html_info,
-        ]
+            html_log,
+        ],
+        show_progress=False,
     )
 
     parameters_copypaste.add_paste_fields("extras", extras_image, None)

+ 4 - 2
modules/ui_toprow.py

@@ -34,8 +34,10 @@ class Toprow:
 
     submit_box = None
 
-    def __init__(self, is_img2img, is_compact=False):
-        id_part = "img2img" if is_img2img else "txt2img"
+    def __init__(self, is_img2img, is_compact=False, id_part=None):
+        if id_part is None:
+            id_part = "img2img" if is_img2img else "txt2img"
+
         self.id_part = id_part
         self.is_img2img = is_img2img
         self.is_compact = is_compact

+ 30 - 0
scripts/postprocessing_caption.py

@@ -0,0 +1,30 @@
+from modules import scripts_postprocessing, ui_components, deepbooru, shared
+import gradio as gr
+
+
+class ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing):
+    name = "Caption"
+    order = 4000
+
+    def ui(self):
+        with ui_components.InputAccordion(False, label="Caption") as enable:
+            option = gr.CheckboxGroup(value=["Deepbooru"], choices=["Deepbooru", "BLIP"], show_label=False)
+
+        return {
+            "enable": enable,
+            "option": option,
+        }
+
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):
+        if not enable:
+            return
+
+        captions = [pp.caption]
+
+        if "Deepbooru" in option:
+            captions.append(deepbooru.model.tag(pp.image))
+
+        if "BLIP" in option:
+            captions.append(shared.interrogator.generate_caption(pp.image))
+
+        pp.caption = ", ".join([x for x in captions if x])

+ 8 - 8
scripts/postprocessing_codeformer.py

@@ -1,28 +1,28 @@
 from PIL import Image
 import numpy as np
 
-from modules import scripts_postprocessing, codeformer_model
+from modules import scripts_postprocessing, codeformer_model, ui_components
 import gradio as gr
 
-from modules.ui_components import FormRow
-
 
 class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):
     name = "CodeFormer"
     order = 3000
 
     def ui(self):
-        with FormRow():
-            codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility")
-            codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
+        with ui_components.InputAccordion(False, label="CodeFormer") as enable:
+            with gr.Row():
+                codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_codeformer_visibility")
+                codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
 
         return {
+            "enable": enable,
             "codeformer_visibility": codeformer_visibility,
             "codeformer_weight": codeformer_weight,
         }
 
-    def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight):
-        if codeformer_visibility == 0:
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, codeformer_visibility, codeformer_weight):
+        if codeformer_visibility == 0 or not enable:
             return
 
         restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)

+ 32 - 0
scripts/postprocessing_create_flipped_copies.py

@@ -0,0 +1,32 @@
+from PIL import ImageOps, Image
+
+from modules import scripts_postprocessing, ui_components
+import gradio as gr
+
+
+class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing):
+    name = "Create flipped copies"
+    order = 4000
+
+    def ui(self):
+        with ui_components.InputAccordion(False, label="Create flipped copies") as enable:
+            with gr.Row():
+                option = gr.CheckboxGroup(value=["Horizontal"], choices=["Horizontal", "Vertical", "Both"], show_label=False)
+
+        return {
+            "enable": enable,
+            "option": option,
+        }
+
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):
+        if not enable:
+            return
+
+        if "Horizontal" in option:
+            pp.extra_images.append(ImageOps.mirror(pp.image))
+
+        if "Vertical" in option:
+            pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM))
+
+        if "Both" in option:
+            pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).transpose(Image.Transpose.FLIP_LEFT_RIGHT))

+ 54 - 0
scripts/postprocessing_focal_crop.py

@@ -0,0 +1,54 @@
+
+from modules import scripts_postprocessing, ui_components, errors
+import gradio as gr
+
+from modules.textual_inversion import autocrop
+
+
+class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing):
+    name = "Auto focal point crop"
+    order = 4000
+
+    def ui(self):
+        with ui_components.InputAccordion(False, label="Auto focal point crop") as enable:
+            face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_face_weight")
+            entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_entropy_weight")
+            edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_edges_weight")
+            debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
+
+        return {
+            "enable": enable,
+            "face_weight": face_weight,
+            "entropy_weight": entropy_weight,
+            "edges_weight": edges_weight,
+            "debug": debug,
+        }
+
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug):
+        if not enable:
+            return
+
+        if not pp.shared.target_width or not pp.shared.target_height:
+            return
+
+        dnn_model_path = None
+        try:
+            dnn_model_path = autocrop.download_and_cache_models()
+        except Exception:
+            errors.report("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", exc_info=True)
+
+        autocrop_settings = autocrop.Settings(
+            crop_width=pp.shared.target_width,
+            crop_height=pp.shared.target_height,
+            face_points_weight=face_weight,
+            entropy_points_weight=entropy_weight,
+            corner_points_weight=edges_weight,
+            annotate_image=debug,
+            dnn_model_path=dnn_model_path,
+        )
+
+        result, *others = autocrop.crop_image(pp.image, autocrop_settings)
+
+        pp.image = result
+        pp.extra_images = [pp.create_copy(x, nametags=["focal-crop-debug"], disable_processing=True) for x in others]
+

+ 6 - 7
scripts/postprocessing_gfpgan.py

@@ -1,26 +1,25 @@
 from PIL import Image
 import numpy as np
 
-from modules import scripts_postprocessing, gfpgan_model
+from modules import scripts_postprocessing, gfpgan_model, ui_components
 import gradio as gr
 
-from modules.ui_components import FormRow
-
 
 class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
     name = "GFPGAN"
     order = 2000
 
     def ui(self):
-        with FormRow():
-            gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility")
+        with ui_components.InputAccordion(False, label="GFPGAN") as enable:
+            gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_gfpgan_visibility")
 
         return {
+            "enable": enable,
             "gfpgan_visibility": gfpgan_visibility,
         }
 
-    def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility):
-        if gfpgan_visibility == 0:
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, gfpgan_visibility):
+        if gfpgan_visibility == 0 or not enable:
             return
 
         restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))

+ 71 - 0
scripts/postprocessing_split_oversized.py

@@ -0,0 +1,71 @@
+import math
+
+from modules import scripts_postprocessing, ui_components
+import gradio as gr
+
+
+def split_pic(image, inverse_xy, width, height, overlap_ratio):
+    if inverse_xy:
+        from_w, from_h = image.height, image.width
+        to_w, to_h = height, width
+    else:
+        from_w, from_h = image.width, image.height
+        to_w, to_h = width, height
+    h = from_h * to_w // from_w
+    if inverse_xy:
+        image = image.resize((h, to_w))
+    else:
+        image = image.resize((to_w, h))
+
+    split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+    y_step = (h - to_h) / (split_count - 1)
+    for i in range(split_count):
+        y = int(y_step * i)
+        if inverse_xy:
+            splitted = image.crop((y, 0, y + to_h, to_w))
+        else:
+            splitted = image.crop((0, y, to_w, y + to_h))
+        yield splitted
+
+
+class ScriptPostprocessingSplitOversized(scripts_postprocessing.ScriptPostprocessing):
+    name = "Split oversized images"
+    order = 4000
+
+    def ui(self):
+        with ui_components.InputAccordion(False, label="Split oversized images") as enable:
+            with gr.Row():
+                split_threshold = gr.Slider(label='Threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_split_threshold")
+                overlap_ratio = gr.Slider(label='Overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="postprocess_overlap_ratio")
+
+        return {
+            "enable": enable,
+            "split_threshold": split_threshold,
+            "overlap_ratio": overlap_ratio,
+        }
+
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, split_threshold, overlap_ratio):
+        if not enable:
+            return
+
+        width = pp.shared.target_width
+        height = pp.shared.target_height
+
+        if not width or not height:
+            return
+
+        if pp.image.height > pp.image.width:
+            ratio = (pp.image.width * height) / (pp.image.height * width)
+            inverse_xy = False
+        else:
+            ratio = (pp.image.height * width) / (pp.image.width * height)
+            inverse_xy = True
+
+        if ratio >= 1.0 and ratio > split_threshold:
+            return
+
+        result, *others = split_pic(pp.image, inverse_xy, width, height, overlap_ratio)
+
+        pp.image = result
+        pp.extra_images = [pp.create_copy(x) for x in others]
+

+ 12 - 0
scripts/postprocessing_upscale.py

@@ -81,6 +81,14 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
 
         return image
 
+    def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
+        if upscale_mode == 1:
+            pp.shared.target_width = upscale_to_width
+            pp.shared.target_height = upscale_to_height
+        else:
+            pp.shared.target_width = int(pp.image.width * upscale_by)
+            pp.shared.target_height = int(pp.image.height * upscale_by)
+
     def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
         if upscaler_1_name == "None":
             upscaler_1_name = None
@@ -126,6 +134,10 @@ class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):
             "upscaler_name": upscaler_name,
         }
 
+    def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
+        pp.shared.target_width = int(pp.image.width * upscale_by)
+        pp.shared.target_height = int(pp.image.height * upscale_by)
+
     def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):
         if upscaler_name is None or upscaler_name == "None":
             return

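These process_firstpass() overrides are the producer side of the new two-pass protocol: the upscale scripts publish the intended output size through pp.shared, and later scripts such as "Split oversized images" and "Auto focal point crop" read it (and bail out if it was never set). A minimal sketch of that handshake outside the UI, assuming the webui modules are importable; the runner normally drives both passes itself:

```python
from PIL import Image

from modules import scripts_postprocessing

# stand-in for one image coming through the extras batch
pp = scripts_postprocessing.PostprocessedImage(Image.new("RGB", (512, 768)))

# pass 1: what ScriptPostprocessingUpscaleSimple.process_firstpass() does for "upscale by 2"
upscale_by = 2.0
pp.shared.target_width = int(pp.image.width * upscale_by)
pp.shared.target_height = int(pp.image.height * upscale_by)

# pass 2: the crop/split scripts only act when a target size exists
if pp.shared.target_width and pp.shared.target_height:
    print("crop/split scripts will aim for", pp.shared.target_width, "x", pp.shared.target_height)
```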
+ 64 - 0
scripts/processing_autosized_crop.py

@@ -0,0 +1,64 @@
+from PIL import Image
+
+from modules import scripts_postprocessing, ui_components
+import gradio as gr
+
+
+def center_crop(image: Image, w: int, h: int):
+    iw, ih = image.size
+    if ih / h < iw / w:
+        sw = w * ih / h
+        box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih
+    else:
+        sh = h * iw / w
+        box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2
+    return image.resize((w, h), Image.Resampling.LANCZOS, box)
+
+
+def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):
+    iw, ih = image.size
+    err = lambda w, h: 1 - (lambda x: x if x < 1 else 1 / x)(iw / ih / (w / h))
+    wh = max(((w, h) for w in range(mindim, maxdim + 1, 64) for h in range(mindim, maxdim + 1, 64)
+              if minarea <= w * h <= maxarea and err(w, h) <= threshold),
+             key=lambda wh: (wh[0] * wh[1], -err(*wh))[::1 if objective == 'Maximize area' else -1],
+             default=None
+             )
+    return wh and center_crop(image, *wh)
+
+
+class ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing):
+    name = "Auto-sized crop"
+    order = 4000
+
+    def ui(self):
+        with ui_components.InputAccordion(False, label="Auto-sized crop") as enable:
+            gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
+            with gr.Row():
+                mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="postprocess_multicrop_mindim")
+                maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="postprocess_multicrop_maxdim")
+            with gr.Row():
+                minarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area lower bound", value=64 * 64, elem_id="postprocess_multicrop_minarea")
+                maxarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label="Area upper bound", value=640 * 640, elem_id="postprocess_multicrop_maxarea")
+            with gr.Row():
+                objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="postprocess_multicrop_objective")
+                threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="postprocess_multicrop_threshold")
+
+        return {
+            "enable": enable,
+            "mindim": mindim,
+            "maxdim": maxdim,
+            "minarea": minarea,
+            "maxarea": maxarea,
+            "objective": objective,
+            "threshold": threshold,
+        }
+
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, mindim, maxdim, minarea, maxarea, objective, threshold):
+        if not enable:
+            return
+
+        cropped = multicrop_pic(pp.image, mindim, maxdim, minarea, maxarea, objective, threshold)
+        if cropped is not None:
+            pp.image = cropped
+        else:
+            print(f"skipped {pp.image.width}x{pp.image.height} image (can't find suitable size within error threshold)")