Browse Source

Add some error handling for VRAM monitor

EyeDeck 2 năm trước cách đây
mục cha
commit
fabaf4bddb
2 tập tin đã thay đổi với 31 bổ sung19 xóa
  1. 15 7
      modules/memmon.py
  2. 16 12
      modules/ui.py

+ 15 - 7
modules/memmon.py

@@ -22,6 +22,13 @@ class MemUsageMonitor(threading.Thread):
         self.run_flag = threading.Event()
         self.data = defaultdict(int)
 
+        try:
+            torch.cuda.mem_get_info()
+            torch.cuda.memory_stats(self.device)
+        except Exception as e:  # AMD or whatever
+            print(f"Warning: caught exception '{e}', memory monitor disabled")
+            self.disabled = True
+
     def run(self):
         if self.disabled:
             return
@@ -62,13 +69,14 @@ class MemUsageMonitor(threading.Thread):
         self.run_flag.set()
 
     def read(self):
-        free, total = torch.cuda.mem_get_info()
-        self.data["total"] = total
-
-        torch_stats = torch.cuda.memory_stats(self.device)
-        self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
-        self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
-        self.data["system_peak"] = total - self.data["min_free"]
+        if not self.disabled:
+            free, total = torch.cuda.mem_get_info()
+            self.data["total"] = total
+
+            torch_stats = torch.cuda.memory_stats(self.device)
+            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
+            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
+            self.data["system_peak"] = total - self.data["min_free"]
 
         return self.data
 

+ 16 - 12
modules/ui.py

@@ -119,7 +119,8 @@ def save_files(js_data, images, index):
 
 def wrap_gradio_call(func):
     def f(*args, **kwargs):
-        shared.mem_mon.monitor()
+        if opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled:
+            shared.mem_mon.monitor()
         t = time.perf_counter()
 
         try:
@@ -136,17 +137,20 @@ def wrap_gradio_call(func):
 
         elapsed = time.perf_counter() - t
 
-        mem_stats = {k: -(v//-(1024*1024)) for k,v in shared.mem_mon.stop().items()}
-        active_peak = mem_stats['active_peak']
-        reserved_peak = mem_stats['reserved_peak']
-        sys_peak = '?' if opts.memmon_poll_rate <= 0 else mem_stats['system_peak']
-        sys_total = mem_stats['total']
-        sys_pct = '?' if opts.memmon_poll_rate <= 0 else round(sys_peak/sys_total * 100, 2)
-        vram_tooltip = "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.&#013;" \
-                       "Torch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.&#013;" \
-                       "Sys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%)."
-
-        vram_html = '' if opts.memmon_poll_rate == 0 else f"<p class='vram' title='{vram_tooltip}'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
+        if opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled:
+            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}
+            active_peak = mem_stats['active_peak']
+            reserved_peak = mem_stats['reserved_peak']
+            sys_peak = mem_stats['system_peak']
+            sys_total = mem_stats['total']
+            sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2)
+            vram_tooltip = "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.&#013;" \
+                           "Torch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.&#013;" \
+                           "Sys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%)."
+
+            vram_html = f"<p class='vram' title='{vram_tooltip}'>Torch active/reserved: {active_peak}/{reserved_peak} MiB, <wbr>Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)</p>"
+        else:
+            vram_html = ''
 
         # last item is always HTML
         res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"