浏览代码

work on startup profile display

AUTOMATIC 2 年之前
父节点
当前提交
0cc05fc492
共有 9 个文件被更改,包括 159 次插入14 次删除
  1. 2 0
      html/footer.html
  2. 91 0
      javascript/profilerVisualization.js
  3. 1 1
      javascript/ui_settings_hints.js
  4. 3 0
      modules/script_callbacks.py
  5. 2 1
      modules/scripts.py
  6. 42 4
      modules/timer.py
  7. 3 1
      modules/ui.py
  8. 6 2
      style.css
  9. 9 5
      webui.py

+ 2 - 0
html/footer.html

@@ -5,6 +5,8 @@
          • 
         <a href="https://gradio.app">Gradio</a>
          • 
+        <a href="#" onclick="showProfile('./internal/profile-startup'); return false;">Startup profile</a>
+         • 
         <a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
 </div>
 <br />

+ 91 - 0
javascript/profilerVisualization.js

@@ -0,0 +1,91 @@
+
+function createRow(table, cellName, items) {
+    var tr = document.createElement('tr');
+    var res = [];
+
+    items.forEach(function(x) {
+        var td = document.createElement(cellName);
+        td.textContent = x;
+        tr.appendChild(td);
+        res.push(td);
+    });
+
+    table.appendChild(tr);
+
+    return res;
+}
+
+function showProfile(path, cutoff = 0.0005) {
+    requestGet(path, {}, function(data) {
+        var table = document.createElement('table');
+        table.className = 'popup-table';
+
+        data.records['total'] = data.total;
+        var keys = Object.keys(data.records).sort(function(a, b) {
+            return data.records[b] - data.records[a];
+        });
+        var items = keys.map(function(x) {
+            return {key: x, parts: x.split('/'), time: data.records[x]};
+        });
+        var maxLength = items.reduce(function(a, b) {
+            return Math.max(a, b.parts.length);
+        }, 0);
+
+        var cols = createRow(table, 'th', ['record', 'seconds']);
+        cols[0].colSpan = maxLength;
+
+        function arraysEqual(a, b) {
+            return !(a < b || b < a);
+        }
+
+        var addLevel = function(level, parent) {
+            var matching = items.filter(function(x) {
+                return x.parts[level] && !x.parts[level + 1] && arraysEqual(x.parts.slice(0, level), parent);
+            });
+            var sorted = matching.sort(function(a, b) {
+                return b.time - a.time;
+            });
+            var othersTime = 0;
+            var othersList = [];
+            sorted.forEach(function(x) {
+                if (x.time < cutoff) {
+                    othersTime += x.time;
+                    othersList.push(x.parts[level]);
+                    return;
+                }
+
+                var cells = [];
+                for (var i = 0; i < maxLength; i++) {
+                    cells.push(x.parts[i]);
+                }
+                cells.push(x.time.toFixed(3));
+                var cols = createRow(table, 'td', cells);
+                for (i = 0; i < level; i++) {
+                    cols[i].className = 'muted';
+                }
+
+                addLevel(level + 1, parent.concat([x.parts[level]]));
+            });
+
+            if (othersTime > 0) {
+                var cells = [];
+                for (var i = 0; i < maxLength; i++) {
+                    cells.push(parent[i]);
+                }
+                cells.push(othersTime.toFixed(3));
+                var cols = createRow(table, 'td', cells);
+                for (i = 0; i < level; i++) {
+                    cols[i].className = 'muted';
+                }
+
+                cols[level].textContent = 'others';
+                cols[level].title = othersList.join(", ");
+            }
+        };
+
+        addLevel(0, []);
+
+        popup(table);
+    });
+}
+

+ 1 - 1
javascript/ui_settings_hints.js

@@ -42,7 +42,7 @@ onOptionsChanged(function() {
 function settingsHintsShowQuicksettings() {
     requestGet("./internal/quicksettings-hint", {}, function(data) {
         var table = document.createElement('table');
-        table.className = 'settings-value-table';
+        table.className = 'popup-table';
 
         data.forEach(function(obj) {
             var tr = document.createElement('tr');

+ 3 - 0
modules/script_callbacks.py

@@ -7,6 +7,8 @@ from typing import Optional, Dict, Any
 from fastapi import FastAPI
 from gradio import Blocks
 
+from modules import timer
+
 
 def report_exception(c, job):
     print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
@@ -123,6 +125,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
     for c in callback_map['callbacks_app_started']:
         try:
             c.callback(demo, app)
+            timer.startup_timer.record(c.script)
         except Exception:
             report_exception(c, 'app_started_callback')
 

+ 2 - 1
modules/scripts.py

@@ -6,7 +6,7 @@ from collections import namedtuple
 
 import gradio as gr
 
-from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
+from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, timer
 
 AlwaysVisible = object()
 
@@ -270,6 +270,7 @@ def load_scripts():
         finally:
             sys.path = syspath
             current_basedir = paths.script_path
+            timer.startup_timer.record(scriptfile.filename)
 
     global scripts_txt2img, scripts_img2img, scripts_postproc
 

+ 42 - 4
modules/timer.py

@@ -1,11 +1,30 @@
 import time
 
 
+class TimerSubcategory:
+    def __init__(self, timer, category):
+        self.timer = timer
+        self.category = category
+        self.start = None
+        self.original_base_category = timer.base_category
+
+    def __enter__(self):
+        self.start = time.time()
+        self.timer.base_category = self.original_base_category + self.category + "/"
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        elapsed_for_subcategroy = time.time() - self.start
+        self.timer.base_category = self.original_base_category
+        self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy)
+        self.timer.record(self.category)
+
+
 class Timer:
     def __init__(self):
         self.start = time.time()
         self.records = {}
         self.total = 0
+        self.base_category = ''
 
     def elapsed(self):
         end = time.time()
@@ -13,18 +32,29 @@ class Timer:
         self.start = end
         return res
 
-    def record(self, category, extra_time=0):
-        e = self.elapsed()
+    def add_time_to_record(self, category, amount):
         if category not in self.records:
             self.records[category] = 0
 
-        self.records[category] += e + extra_time
+        self.records[category] += amount
+
+    def record(self, category, extra_time=0):
+        e = self.elapsed()
+
+        self.add_time_to_record(self.base_category + category, e + extra_time)
+
         self.total += e + extra_time
 
+    def subcategory(self, name):
+        self.elapsed()
+
+        subcat = TimerSubcategory(self, name)
+        return subcat
+
     def summary(self):
         res = f"{self.total:.1f}s"
 
-        additions = [x for x in self.records.items() if x[1] >= 0.1]
+        additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category]
         if not additions:
             return res
 
@@ -34,5 +64,13 @@ class Timer:
 
         return res
 
+    def dump(self):
+        return {'total': self.total, 'records': self.records}
+
     def reset(self):
         self.__init__()
+
+
+startup_timer = Timer()
+
+startup_record = None

+ 3 - 1
modules/ui.py

@@ -13,7 +13,7 @@ import numpy as np
 from PIL import Image, PngImagePlugin  # noqa: F401
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
 
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, timer
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.paths import script_path, data_path
 
@@ -1901,3 +1901,5 @@ def setup_ui_api(app):
     app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
 
     app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])
+
+    app.add_api_route("/internal/profile-startup", lambda: timer.startup_record, methods=["GET"])

+ 6 - 2
style.css

@@ -403,19 +403,23 @@ div#extras_scale_to_tab div.form{
     margin: 0 1.2em;
 }
 
-table.settings-value-table{
+table.popup-table{
     background: white;
     border-collapse: collapse;
     margin: 1em;
     border: 4px solid white;
 }
 
-table.settings-value-table td{
+table.popup-table td{
     padding: 0.4em;
     border: 1px solid #ccc;
     max-width: 36em;
 }
 
+table.popup-table .muted{
+    color: #aaa;
+}
+
 .ui-defaults-none{
     color: #aaa !important;
 }

+ 9 - 5
webui.py

@@ -20,7 +20,7 @@ logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not
 
 from modules import paths, timer, import_hook, errors  # noqa: F401
 
-startup_timer = timer.Timer()
+startup_timer = timer.startup_timer
 
 import torch
 import pytorch_lightning   # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
@@ -269,8 +269,8 @@ def initialize_rest(*, reload_script_modules=False):
 
     localization.list_localizations(cmd_opts.localizations_dir)
 
-    modules.scripts.load_scripts()
-    startup_timer.record("load scripts")
+    with startup_timer.subcategory("load scripts"):
+        modules.scripts.load_scripts()
 
     if reload_script_modules:
         for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
@@ -416,9 +416,12 @@ def webui():
 
         ui_extra_networks.add_pages_to_demo(app)
 
-        modules.script_callbacks.app_started_callback(shared.demo, app)
-        startup_timer.record("scripts app_started_callback")
+        startup_timer.record("add APIs")
+
+        with startup_timer.subcategory("app_started_callback"):
+            modules.script_callbacks.app_started_callback(shared.demo, app)
 
+        timer.startup_record = startup_timer.dump()
         print(f"Startup time: {startup_timer.summary()}.")
 
         if cmd_opts.subpath:
@@ -443,6 +446,7 @@ def webui():
             # If we catch a keyboard interrupt, we want to stop the server and exit.
             shared.demo.close()
             break
+
         print('Restarting UI...')
         shared.demo.close()
         time.sleep(0.5)