Эх сурвалжийг харах

big rework of progressbar/preview system to allow multiple users to prompts at the same time and do not get previews of each other

AUTOMATIC 2 жил өмнө
parent
commit
d8b90ac121

+ 159 - 90
javascript/progressbar.js

@@ -1,82 +1,25 @@
 // code related to showing and updating progressbar shown as the image is being made
-global_progressbars = {}
-galleries = {}
-galleryObservers = {}
-
-// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
-timeoutIds = {}
 
-function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
-    // gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
-    // every time you use gr.HTML(elem_id='xxx'), so we handle this here
-    var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
-    var progressbarParent
-    if(progressbar){
-        progressbarParent = gradioApp().querySelector("#"+id_progressbar)
-    } else{
-        progressbar = gradioApp().getElementById(id_progressbar)
-        progressbarParent = null
-    }
 
-    var skip = id_skip ? gradioApp().getElementById(id_skip) : null
-    var interrupt = gradioApp().getElementById(id_interrupt)
-
-    if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
-        if(progressbar.innerText){
-            let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion';
-            if(document.title != newtitle){
-                document.title =  newtitle;
-            }
-        }else{
-            let newtitle = 'Stable Diffusion'
-            if(document.title != newtitle){
-                document.title =  newtitle;
-            }
-        }
-    }
-
-	if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){
-	    global_progressbars[id_progressbar] = progressbar
-
-        var mutationObserver = new MutationObserver(function(m){
-            if(timeoutIds[id_part]) return;
-
-            preview = gradioApp().getElementById(id_preview)
-            gallery = gradioApp().getElementById(id_gallery)
+galleries = {}
+storedGallerySelections = {}
+galleryObservers = {}
 
-            if(preview != null && gallery != null){
-                preview.style.width = gallery.clientWidth + "px"
-                preview.style.height = gallery.clientHeight + "px"
-                if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"
+function rememberGallerySelection(id_gallery){
+    storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery)
+}
 
-				//only watch gallery if there is a generation process going on
-                check_gallery(id_gallery);
+function getGallerySelectedIndex(id_gallery){
+    let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
+    let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
 
-                var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
-                if(progressDiv){
-                    timeoutIds[id_part] = window.setTimeout(function() {
-                        timeoutIds[id_part] = null
-                        requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
-                    }, 500)
-                } else{
-                    if (skip) {
-                        skip.style.display = "none"
-                    }
-                    interrupt.style.display = "none"
+     let currentlySelectedIndex = -1
+     galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } })
 
-                    //disconnect observer once generation finished, so user can close selected image if they want
-                    if (galleryObservers[id_gallery]) {
-                        galleryObservers[id_gallery].disconnect();
-                        galleries[id_gallery] = null;
-                    }
-                }
-            }
-
-        });
-        mutationObserver.observe( progressbar, { childList:true, subtree:true })
-	}
+     return currentlySelectedIndex
 }
 
+// this is a workaround for https://github.com/gradio-app/gradio/issues/2984
 function check_gallery(id_gallery){
     let gallery = gradioApp().getElementById(id_gallery)
     // if gallery has no change, no need to setting up observer again.
@@ -85,10 +28,16 @@ function check_gallery(id_gallery){
         if(galleryObservers[id_gallery]){
             galleryObservers[id_gallery].disconnect();
         }
-        let prevSelectedIndex = selected_gallery_index();
+
+        storedGallerySelections[id_gallery] = -1
+
         galleryObservers[id_gallery] = new MutationObserver(function (){
             let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
             let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
+            let currentlySelectedIndex = getGallerySelectedIndex(id_gallery)
+            prevSelectedIndex = storedGallerySelections[id_gallery]
+            storedGallerySelections[id_gallery] = -1
+
             if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
                 // automatically re-open previously selected index (if exists)
                 activeElement = gradioApp().activeElement;
@@ -120,30 +69,150 @@ function check_gallery(id_gallery){
 }
 
 onUiUpdate(function(){
-    check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
-    check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
-    check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery')
+    check_gallery('txt2img_gallery')
+    check_gallery('img2img_gallery')
 })
 
-function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){
-    btn = gradioApp().getElementById(id_part+"_check_progress");
-    if(btn==null) return;
-
-    btn.click();
-    var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
-    var skip = id_skip ? gradioApp().getElementById(id_skip) : null
-    var interrupt = gradioApp().getElementById(id_interrupt)
-    if(progressDiv && interrupt){
-        if (skip) {
-            skip.style.display = "block"
+function request(url, data, handler, errorHandler){
+    var xhr = new XMLHttpRequest();
+    var url = url;
+    xhr.open("POST", url, true);
+    xhr.setRequestHeader("Content-Type", "application/json");
+    xhr.onreadystatechange = function () {
+        if (xhr.readyState === 4) {
+            if (xhr.status === 200) {
+                var js = JSON.parse(xhr.responseText);
+                handler(js)
+            } else{
+                errorHandler()
+            }
         }
-        interrupt.style.display = "block"
+    };
+    var js = JSON.stringify(data);
+    xhr.send(js);
+}
+
+function pad2(x){
+    return x<10 ? '0'+x : x
+}
+
+function formatTime(secs){
+    if(secs > 3600){
+        return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60)
+    } else if(secs > 60){
+        return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60)
+    } else{
+        return Math.floor(secs) + "s"
     }
 }
 
-function requestProgress(id_part){
-    btn = gradioApp().getElementById(id_part+"_check_progress_initial");
-    if(btn==null) return;
+function randomId(){
+    return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")"
+}
+
+// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
+// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
+// calls onProgress every time there is a progress update
+function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){
+    var dateStart = new Date()
+    var wasEverActive = false
+    var parentProgressbar = progressbarContainer.parentNode
+    var parentGallery = gallery.parentNode
+
+    var divProgress = document.createElement('div')
+    divProgress.className='progressDiv'
+    var divInner = document.createElement('div')
+    divInner.className='progress'
+
+    divProgress.appendChild(divInner)
+    parentProgressbar.insertBefore(divProgress, progressbarContainer)
+
+    var livePreview = document.createElement('div')
+    livePreview.className='livePreview'
+    parentGallery.insertBefore(livePreview, gallery)
+
+    var removeProgressBar = function(){
+        parentProgressbar.removeChild(divProgress)
+        parentGallery.removeChild(livePreview)
+        atEnd()
+    }
+
+    var fun = function(id_task, id_live_preview){
+        request("/internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){
+            console.log(res)
+
+            if(res.completed){
+                removeProgressBar()
+                return
+            }
+
+            var rect = progressbarContainer.getBoundingClientRect()
+
+            if(rect.width){
+                divProgress.style.width = rect.width + "px";
+            }
+
+            progressText = ""
+
+            divInner.style.width = ((res.progress || 0) * 100.0) + '%'
+
+            if(res.progress > 0){
+                progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%'
+            }
+
+            if(res.eta){
+                progressText += " ETA: " + formatTime(res.eta)
+            } else if(res.textinfo){
+                progressText += " " + res.textinfo
+            }
+
+            divInner.textContent = progressText
+
+            var elapsedFromStart = (new Date() - dateStart) / 1000
+
+            if(res.active) wasEverActive = true;
+
+            if(! res.active && wasEverActive){
+                removeProgressBar()
+                return
+            }
+
+            if(elapsedFromStart > 5 && !res.queued && !res.active){
+                removeProgressBar()
+                return
+            }
+
+
+            if(res.live_preview){
+                var img = new Image();
+                img.onload = function() {
+                    var rect = gallery.getBoundingClientRect()
+                    if(rect.width){
+                        livePreview.style.width = rect.width + "px"
+                        livePreview.style.height = rect.height + "px"
+                    }
+
+                    livePreview.innerHTML = ''
+                    livePreview.appendChild(img)
+                    if(livePreview.childElementCount > 2){
+                        livePreview.removeChild(livePreview.firstElementChild)
+                    }
+                }
+                img.src = res.live_preview;
+            }
+
+
+            if(onProgress){
+                onProgress(res)
+            }
+
+            setTimeout(() => {
+                fun(id_task, res.id_live_preview);
+            }, 500)
+        }, function(){
+            removeProgressBar()
+        })
+    }
 
-    btn.click();
+    fun(id_task, 0)
 }

+ 11 - 2
javascript/textualInversion.js

@@ -1,8 +1,17 @@
 
 
+
 function start_training_textual_inversion(){
-    requestProgress('ti')
     gradioApp().querySelector('#ti_error').innerHTML=''
 
-    return args_to_array(arguments)
+    var id = randomId()
+    requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){
+        gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo
+    })
+
+    var res = args_to_array(arguments)
+
+    res[0] = id
+
+    return res
 }

+ 28 - 5
javascript/ui.js

@@ -126,18 +126,41 @@ function create_submit_args(args){
     return res
 }
 
+function showSubmitButtons(tabname, show){
+    gradioApp().getElementById(tabname+'_interrupt').style.display = show ? "none" : "block"
+    gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
+}
+
 function submit(){
-    requestProgress('txt2img')
+    rememberGallerySelection('txt2img_gallery')
+    showSubmitButtons('txt2img', false)
+
+    var id = randomId()
+    requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
+        showSubmitButtons('txt2img', true)
+
+    })
 
-    return create_submit_args(arguments)
+    var res = create_submit_args(arguments)
+
+    res[0] = id
+
+    return res
 }
 
 function submit_img2img(){
-    requestProgress('img2img')
+    rememberGallerySelection('img2img_gallery')
+    showSubmitButtons('img2img', false)
+
+    var id = randomId()
+    requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
+        showSubmitButtons('img2img', true)
+    })
 
-    res = create_submit_args(arguments)
+    var res = create_submit_args(arguments)
 
-    res[0] = get_tab_index('mode_img2img')
+    res[0] = id
+    res[1] = get_tab_index('mode_img2img')
 
     return res
 }

+ 15 - 4
modules/call_queue.py

@@ -4,7 +4,7 @@ import threading
 import traceback
 import time
 
-from modules import shared
+from modules import shared, progress
 
 queue_lock = threading.Lock()
 
@@ -22,12 +22,23 @@ def wrap_queued_call(func):
 def wrap_gradio_gpu_call(func, extra_outputs=None):
     def f(*args, **kwargs):
 
-        shared.state.begin()
+        # if the first argument is a string that says "task(...)", it is treated as a job id
+        if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")":
+            id_task = args[0]
+            progress.add_task_to_queue(id_task)
+        else:
+            id_task = None
 
         with queue_lock:
-            res = func(*args, **kwargs)
+            shared.state.begin()
+            progress.start_task(id_task)
+
+            try:
+                res = func(*args, **kwargs)
+            finally:
+                progress.finish_task(id_task)
 
-        shared.state.end()
+            shared.state.end()
 
         return res
 

+ 3 - 3
modules/hypernetworks/hypernetwork.py

@@ -453,7 +453,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
     shared.reload_hypernetworks()
 
 
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
     from modules import images
 
@@ -629,7 +629,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
 
                 description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
                 pbar.set_description(description)
-                shared.state.textinfo = description
                 if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
                     # Before saving, change name to match current checkpoint.
                     hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
@@ -701,7 +700,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                         torch.cuda.set_rng_state_all(cuda_rng_state)
                     hypernetwork.train()
                     if image is not None:
-                        shared.state.current_image = image
+                        shared.state.assign_current_image(image)
+
                         last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                         last_saved_image += f", prompt: {preview_text}"
 

+ 1 - 1
modules/img2img.py

@@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args):
                 processed_image.save(os.path.join(output_dir, filename))
 
 
-def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
+def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
     is_batch = mode == 5
 
     if mode == 0:  # img2img

+ 96 - 0
modules/progress.py

@@ -0,0 +1,96 @@
+import base64
+import io
+import time
+
+import gradio as gr
+from pydantic import BaseModel, Field
+
+from modules.shared import opts
+
+import modules.shared as shared
+
+
+current_task = None
+pending_tasks = {}
+finished_tasks = []
+
+
+def start_task(id_task):
+    global current_task
+
+    current_task = id_task
+    pending_tasks.pop(id_task, None)
+
+
+def finish_task(id_task):
+    global current_task
+
+    if current_task == id_task:
+        current_task = None
+
+    finished_tasks.append(id_task)
+    if len(finished_tasks) > 16:
+        finished_tasks.pop(0)
+
+
+def add_task_to_queue(id_job):
+    pending_tasks[id_job] = time.time()
+
+
+class ProgressRequest(BaseModel):
+    id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
+    id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image")
+
+
+class ProgressResponse(BaseModel):
+    active: bool = Field(title="Whether the task is being worked on right now")
+    queued: bool = Field(title="Whether the task is in queue")
+    completed: bool = Field(title="Whether the task has already finished")
+    progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
+    eta: float = Field(default=None, title="ETA in secs")
+    live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
+    id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
+    textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
+
+
+def setup_progress_api(app):
+    return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
+
+
+def progressapi(req: ProgressRequest):
+    active = req.id_task == current_task
+    queued = req.id_task in pending_tasks
+    completed = req.id_task in finished_tasks
+
+    if not active:
+        return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
+
+    progress = 0
+
+    if shared.state.job_count > 0:
+        progress += shared.state.job_no / shared.state.job_count
+    if shared.state.sampling_steps > 0:
+        progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+
+    progress = min(progress, 1)
+
+    elapsed_since_start = time.time() - shared.state.time_start
+    predicted_duration = elapsed_since_start / progress if progress > 0 else None
+    eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
+
+    id_live_preview = req.id_live_preview
+    shared.state.set_current_image()
+    if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
+        image = shared.state.current_image
+        if image is not None:
+            buffered = io.BytesIO()
+            image.save(buffered, format="png")
+            live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
+            id_live_preview = shared.state.id_live_preview
+        else:
+            live_preview = None
+    else:
+        live_preview = None
+
+    return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
+

+ 1 - 1
modules/sd_samplers.py

@@ -140,7 +140,7 @@ def store_latent(decoded):
 
     if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
         if not shared.parallel_processing_allowed:
-            shared.state.current_image = sample_to_image(decoded)
+            shared.state.assign_current_image(sample_to_image(decoded))
 
 
 class InterruptedException(BaseException):

+ 11 - 5
modules/shared.py

@@ -152,6 +152,7 @@ def reload_hypernetworks():
     hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
 
 
+
 class State:
     skipped = False
     interrupted = False
@@ -165,6 +166,7 @@ class State:
     current_latent = None
     current_image = None
     current_image_sampling_step = 0
+    id_live_preview = 0
     textinfo = None
     time_start = None
     need_restart = False
@@ -207,6 +209,7 @@ class State:
         self.current_latent = None
         self.current_image = None
         self.current_image_sampling_step = 0
+        self.id_live_preview = 0
         self.skipped = False
         self.interrupted = False
         self.textinfo = None
@@ -220,8 +223,8 @@ class State:
 
         devices.torch_gc()
 
-    """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
     def set_current_image(self):
+        """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
         if not parallel_processing_allowed:
             return
 
@@ -234,12 +237,16 @@ class State:
 
         import modules.sd_samplers
         if opts.show_progress_grid:
-            self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
+            self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
         else:
-            self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
+            self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
 
         self.current_image_sampling_step = self.sampling_step
 
+    def assign_current_image(self, image):
+        self.current_image = image
+        self.id_live_preview += 1
+
 
 state = State()
 state.server_start = time.time()
@@ -424,8 +431,6 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
 }))
 
 options_templates.update(options_section(('ui', "User interface"), {
-    "show_progressbar": OptionInfo(True, "Show progressbar"),
-    "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
     "return_grid": OptionInfo(True, "Show grid in results for web"),
     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
@@ -446,6 +451,7 @@ options_templates.update(options_section(('ui', "User interface"), {
 
 options_templates.update(options_section(('ui', "Live previews"), {
     "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
+    "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
     "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
     "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
     "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),

+ 1 - 1
modules/textual_inversion/preprocess.py

@@ -12,7 +12,7 @@ from modules.shared import opts, cmd_opts
 from modules.textual_inversion import autocrop
 
 
-def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
+def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False):
     try:
         if process_caption:
             shared.interrogator.load()

+ 3 - 3
modules/textual_inversion/textual_inversion.py

@@ -345,7 +345,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
         assert log_directory, "Log directory is empty"
 
 
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
     save_embedding_every = save_embedding_every or 0
     create_image_every = create_image_every or 0
     template_file = textual_inversion_templates.get(template_filename, None)
@@ -510,7 +510,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
 
                 description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
                 pbar.set_description(description)
-                shared.state.textinfo = description
                 if embedding_dir is not None and steps_done % save_embedding_every == 0:
                     # Before saving, change name to match current checkpoint.
                     embedding_name_every = f'{embedding_name}-{steps_done}'
@@ -560,7 +559,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                         shared.sd_model.first_stage_model.to(devices.cpu)
 
                     if image is not None:
-                        shared.state.current_image = image
+                        shared.state.assign_current_image(image)
+
                         last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                         last_saved_image += f", prompt: {preview_text}"
 

+ 1 - 1
modules/txt2img.py

@@ -8,7 +8,7 @@ import modules.processing as processing
 from modules.ui import plaintext_to_html
 
 
-def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args):
     p = StableDiffusionProcessingTxt2Img(
         sd_model=shared.sd_model,
         outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,

+ 11 - 30
modules/ui.py

@@ -356,7 +356,7 @@ def create_toprow(is_img2img):
                 button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
 
         with gr.Column(scale=1):
-            with gr.Row():
+            with gr.Row(elem_id=f"{id_part}_generate_box"):
                 skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
                 interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
                 submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
@@ -384,9 +384,7 @@ def create_toprow(is_img2img):
 
 
 def setup_progressbar(*args, **kwargs):
-    import modules.ui_progress
-
-    modules.ui_progress.setup_progressbar(*args, **kwargs)
+    pass
 
 
 def apply_setting(key, value):
@@ -479,8 +477,8 @@ Requested path was: {f}
             else:
                 sp.Popen(["xdg-open", path])
 
-    with gr.Column(variant='panel'):
-            with gr.Group():
+    with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
+            with gr.Group(elem_id=f"{tabname}_gallery_container"):
                 result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
 
             generation_info = None
@@ -595,15 +593,6 @@ def create_ui():
         dummy_component = gr.Label(visible=False)
         txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False)
 
-        with gr.Row(elem_id='txt2img_progress_row'):
-            with gr.Column(scale=1):
-                pass
-
-            with gr.Column(scale=1):
-                progressbar = gr.HTML(elem_id="txt2img_progressbar")
-                txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
-                setup_progressbar(progressbar, txt2img_preview, 'txt2img')
-
         with gr.Row().style(equal_height=False):
             with gr.Column(variant='panel', elem_id="txt2img_settings"):
                 for category in ordered_ui_categories():
@@ -682,6 +671,7 @@ def create_ui():
                 fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
                 _js="submit",
                 inputs=[
+                    dummy_component,
                     txt2img_prompt,
                     txt2img_negative_prompt,
                     txt2img_prompt_style,
@@ -782,16 +772,7 @@ def create_ui():
     with gr.Blocks(analytics_enabled=False) as img2img_interface:
         img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True)
 
-        with gr.Row(elem_id='img2img_progress_row'):
-            img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
-
-            with gr.Column(scale=1):
-                pass
-
-            with gr.Column(scale=1):
-                progressbar = gr.HTML(elem_id="img2img_progressbar")
-                img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
-                setup_progressbar(progressbar, img2img_preview, 'img2img')
+        img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False)
 
         with FormRow().style(equal_height=False):
             with gr.Column(variant='panel', elem_id="img2img_settings"):
@@ -958,6 +939,7 @@ def create_ui():
                 fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
                 _js="submit_img2img",
                 inputs=[
+                    dummy_component,
                     dummy_component,
                     img2img_prompt,
                     img2img_negative_prompt,
@@ -1335,15 +1317,11 @@ def create_ui():
 
                 script_callbacks.ui_train_tabs_callback(params)
 
-            with gr.Column():
-                progressbar = gr.HTML(elem_id="ti_progressbar")
+            with gr.Column(elem_id='ti_gallery_container'):
                 ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
-
                 ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
-                ti_preview = gr.Image(elem_id='ti_preview', visible=False)
                 ti_progress = gr.HTML(elem_id="ti_progress", value="")
                 ti_outcome = gr.HTML(elem_id="ti_error", value="")
-                setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
 
         create_embedding.click(
             fn=modules.textual_inversion.ui.create_embedding,
@@ -1384,6 +1362,7 @@ def create_ui():
             fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
             _js="start_training_textual_inversion",
             inputs=[
+                dummy_component,
                 process_src,
                 process_dst,
                 process_width,
@@ -1411,6 +1390,7 @@ def create_ui():
             fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
             _js="start_training_textual_inversion",
             inputs=[
+                dummy_component,
                 train_embedding_name,
                 embedding_learn_rate,
                 batch_size,
@@ -1443,6 +1423,7 @@ def create_ui():
             fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
             _js="start_training_textual_inversion",
             inputs=[
+                dummy_component,
                 train_hypernetwork_name,
                 hypernetwork_learn_rate,
                 batch_size,

+ 0 - 101
modules/ui_progress.py

@@ -1,101 +0,0 @@
-import time
-
-import gradio as gr
-
-from modules.shared import opts
-
-import modules.shared as shared
-
-
-def calc_time_left(progress, threshold, label, force_display, show_eta):
-    if progress == 0:
-        return ""
-    else:
-        time_since_start = time.time() - shared.state.time_start
-        eta = (time_since_start/progress)
-        eta_relative = eta-time_since_start
-        if (eta_relative > threshold and show_eta) or force_display:
-            if eta_relative > 3600:
-                return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
-            elif eta_relative > 60:
-                return label + time.strftime('%M:%S',  time.gmtime(eta_relative))
-            else:
-                return label + time.strftime('%Ss',  time.gmtime(eta_relative))
-        else:
-            return ""
-
-
-def check_progress_call(id_part):
-    if shared.state.job_count == 0:
-        return "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
-
-    progress = 0
-
-    if shared.state.job_count > 0:
-        progress += shared.state.job_no / shared.state.job_count
-    if shared.state.sampling_steps > 0:
-        progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
-
-    # Show progress percentage and time left at the same moment, and base it also on steps done
-    show_eta = progress >= 0.01 or shared.state.sampling_step >= 10
-
-    time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta)
-    if time_left != "":
-        shared.state.time_left_force_display = True
-
-    progress = min(progress, 1)
-
-    progressbar = ""
-    if opts.show_progressbar:
-        progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}</div></div>"""
-
-    image = gr.update(visible=False)
-    preview_visibility = gr.update(visible=False)
-
-    if opts.live_previews_enable:
-        shared.state.set_current_image()
-        image = shared.state.current_image
-
-        if image is None:
-            image = gr.update(value=None)
-        else:
-            preview_visibility = gr.update(visible=True)
-
-    if shared.state.textinfo is not None:
-        textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
-    else:
-        textinfo_result = gr.update(visible=False)
-
-    return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
-
-
-def check_progress_call_initial(id_part):
-    shared.state.job_count = -1
-    shared.state.current_latent = None
-    shared.state.current_image = None
-    shared.state.textinfo = None
-    shared.state.time_start = time.time()
-    shared.state.time_left_force_display = False
-
-    return check_progress_call(id_part)
-
-
-def setup_progressbar(progressbar, preview, id_part, textinfo=None):
-    if textinfo is None:
-        textinfo = gr.HTML(visible=False)
-
-    check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
-    check_progress.click(
-        fn=lambda: check_progress_call(id_part),
-        show_progress=False,
-        inputs=[],
-        outputs=[progressbar, preview, preview, textinfo],
-    )
-
-    check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
-    check_progress_initial.click(
-        fn=lambda: check_progress_call_initial(id_part),
-        show_progress=False,
-        inputs=[],
-        outputs=[progressbar, preview, preview, textinfo],
-    )

+ 46 - 28
style.css

@@ -305,26 +305,42 @@ input[type="range"]{
 }
 
 .progressDiv{
-  width: 100%;
-  height: 20px;
-  background: #b4c0cc;
-  border-radius: 8px;
+    position: absolute;
+    height: 20px;
+    top: -20px;
+    background: #b4c0cc;
+    border-radius: 8px !important;
 }
 
 .dark .progressDiv{
-  background: #424c5b;
+    background: #424c5b;
 }
 
 .progressDiv .progress{
-  width: 0%;
-  height: 20px;
-  background: #0060df;
-  color: white;
-  font-weight: bold;
-  line-height: 20px;
-  padding: 0 8px 0 0;
-  text-align: right;
-  border-radius: 8px;
+    width: 0%;
+    height: 20px;
+    background: #0060df;
+    color: white;
+    font-weight: bold;
+    line-height: 20px;
+    padding: 0 8px 0 0;
+    text-align: right;
+    border-radius: 8px;
+    overflow: visible;
+    white-space: nowrap;
+}
+
+.livePreview{
+    position: absolute;
+    z-index: 300;
+    background-color: white;
+    margin: -4px;
+}
+
+.livePreview img{
+    object-fit: contain;
+    width: 100%;
+    height: 100%;
 }
 
 #lightboxModal{
@@ -450,23 +466,25 @@ input[type="range"]{
     display:none
 }
 
-#txt2img_interrupt, #img2img_interrupt{
-  position: absolute;
-  width: 50%;
-  height: 72px;
-  background: #b4c0cc;
-  border-radius: 0px;
-  display: none;
+#txt2img_generate_box, #img2img_generate_box{
+    position: relative;
+}
+
+#txt2img_interrupt, #img2img_interrupt, #txt2img_skip, #img2img_skip{
+    position: absolute;
+    width: 50%;
+    height: 100%;
+    background: #b4c0cc;
+    display: none;
 }
 
+#txt2img_interrupt, #img2img_interrupt{
+    right: 0;
+    border-radius: 0 0.5rem 0.5rem 0;
+}
 #txt2img_skip, #img2img_skip{
-  position: absolute;
-  width: 50%;
-  right: 0px;
-  height: 72px;
-  background: #b4c0cc;
-  border-radius: 0px;
-  display: none;
+    left: 0;
+    border-radius: 0.5rem 0 0 0.5rem;
 }
 
 .red {

+ 3 - 0
webui.py

@@ -34,6 +34,7 @@ import modules.sd_vae
 import modules.txt2img
 import modules.script_callbacks
 import modules.textual_inversion.textual_inversion
+import modules.progress
 
 import modules.ui
 from modules import modelloader
@@ -181,6 +182,8 @@ def webui():
 
         app.add_middleware(GZipMiddleware, minimum_size=1000)
 
+        modules.progress.setup_progress_api(app)
+
         if launch_api:
             create_api(app)