Преглед на файлове

added tool for profiling code

AUTOMATIC1111 преди 1 година
родител
ревизия
57e6d05a43
променени са 5 файла, в които са добавени 78 реда и са изтрити 5 реда
  1. 8 2
      modules/call_queue.py
  2. 3 2
      modules/processing.py
  3. 46 0
      modules/profiling.py
  4. 16 0
      modules/shared_options.py
  5. 5 1
      style.css

+ 8 - 2
modules/call_queue.py

@@ -1,8 +1,9 @@
+import os.path
 from functools import wraps
 import html
 import time
 
-from modules import shared, progress, errors, devices, fifo_lock
+from modules import shared, progress, errors, devices, fifo_lock, profiling
 
 queue_lock = fifo_lock.FIFOLock()
 
@@ -111,8 +112,13 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
         else:
             vram_html = ''
 
+        if shared.opts.profiling_enable and os.path.exists(shared.opts.profiling_filename):
+            profiling_html = f"<p class='profile'> [ <a href='{profiling.webpath()}' download>Profile</a> ] </p>"
+        else:
+            profiling_html = ''
+
         # last item is always HTML
-        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}</div>"
+        res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}{profiling_html}</div>"
 
         return tuple(res)
 

+ 3 - 2
modules/processing.py

@@ -16,7 +16,7 @@ from skimage import exposure
 from typing import Any
 
 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
 from modules.rng import slerp # noqa: F401
 from modules.sd_hijack import model_hijack
 from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
@@ -843,7 +843,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
         # backwards compatibility, fix sampler and scheduler if invalid
         sd_samplers.fix_p_invalid_sampler_and_scheduler(p)
 
-        res = process_images_inner(p)
+        with profiling.Profiler():
+            res = process_images_inner(p)
 
     finally:
         sd_models.apply_token_merging(p.sd_model, 0)

+ 46 - 0
modules/profiling.py

@@ -0,0 +1,46 @@
+import torch
+
+from modules import shared, ui_gradio_extensions
+
+
+class Profiler:
+    def __init__(self):
+        if not shared.opts.profiling_enable:
+            self.profiler = None
+            return
+
+        activities = []
+        if "CPU" in shared.opts.profiling_activities:
+            activities.append(torch.profiler.ProfilerActivity.CPU)
+        if "CUDA" in shared.opts.profiling_activities:
+            activities.append(torch.profiler.ProfilerActivity.CUDA)
+
+        if not activities:
+            self.profiler = None
+            return
+
+        self.profiler = torch.profiler.profile(
+            activities=activities,
+            record_shapes=shared.opts.profiling_record_shapes,
+            profile_memory=shared.opts.profiling_profile_memory,
+            with_stack=shared.opts.profiling_with_stack
+        )
+
+    def __enter__(self):
+        if self.profiler:
+            self.profiler.__enter__()
+
+        return self
+
+    def __exit__(self, exc_type, exc, exc_tb):
+        if self.profiler:
+            shared.state.textinfo = "Finishing profile..."
+
+            self.profiler.__exit__(exc_type, exc, exc_tb)
+
+            self.profiler.export_chrome_trace(shared.opts.profiling_filename)
+
+
+def webpath():
+    return ui_gradio_extensions.webpath(shared.opts.profiling_filename)
+

+ 16 - 0
modules/shared_options.py

@@ -129,6 +129,22 @@ options_templates.update(options_section(('system', "System", "system"), {
     "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
 }))
 
+options_templates.update(options_section(('profiler', "Profiler", "system"), {
+    "profiling_explanation": OptionHTML("""
+Those settings allow you to enable torch profiler when generating pictures.
+Profiling allows you to see which code uses how much of computer's resources during generation.
+Each generation writes its own profile to one file, overwriting previous.
+The file can be viewed in <a href="chrome:tracing">Chrome</a>, or on a <a href="https://ui.perfetto.dev/">Perfetto</a> web site.
+Warning: writing profile can take a lot of time, up to 30 seconds, and the file itelf can be around 500MB in size.
+"""),
+    "profiling_enable": OptionInfo(False, "Enable profiling"),
+    "profiling_activities": OptionInfo(["CPU"], "Activities", gr.CheckboxGroup, {"choices": ["CPU", "CUDA"]}),
+    "profiling_record_shapes": OptionInfo(True, "Record shapes"),
+    "profiling_profile_memory": OptionInfo(True, "Profile memory"),
+    "profiling_with_stack": OptionInfo(True, "Include python stack"),
+    "profiling_filename": OptionInfo("trace.json", "Profile filename"),
+}))
+
 options_templates.update(options_section(('API', "API", "system"), {
     "api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
     "api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),

+ 5 - 1
style.css

@@ -279,7 +279,7 @@ input[type="checkbox"].input-accordion-checkbox{
     display: inline-block;
 }
 
-.html-log .performance p.time, .performance p.vram, .performance p.time abbr, .performance p.vram abbr {
+.html-log .performance p.time, .performance p.vram, .performance p.profile, .performance p.time abbr, .performance p.vram abbr {
     margin-bottom: 0;
     color: var(--block-title-text-color);
 }
@@ -291,6 +291,10 @@ input[type="checkbox"].input-accordion-checkbox{
     margin-left: auto;
 }
 
+.html-log .performance p.profile {
+    margin-left: 0.5em;
+}
+
 .html-log .performance .measurement{
     color: var(--body-text-color);
     font-weight: bold;