
attempt to fix memory monitor with multiple CUDA devices

AUTOMATIC 2 years ago
parent commit a00cd8b9c1
1 changed file with 8 additions and 4 deletions:
      modules/memmon.py

+ 8 - 4
modules/memmon.py

@@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread):
         self.data = defaultdict(int)
 
         try:
-            torch.cuda.mem_get_info()
+            self.cuda_mem_get_info()
             torch.cuda.memory_stats(self.device)
         except Exception as e:  # AMD or whatever
             print(f"Warning: caught exception '{e}', memory monitor disabled")
             self.disabled = True
 
+    def cuda_mem_get_info(self):
+        index = self.device.index if self.device.index is not None else torch.cuda.current_device()
+        return torch.cuda.mem_get_info(index)
+
     def run(self):
         if self.disabled:
             return
@@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread):
                 self.run_flag.clear()
                 continue
 
-            self.data["min_free"] = torch.cuda.mem_get_info()[0]
+            self.data["min_free"] = self.cuda_mem_get_info()[0]
 
             while self.run_flag.is_set():
-                free, total = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
+                free, total = self.cuda_mem_get_info()
                 self.data["min_free"] = min(self.data["min_free"], free)
 
                 time.sleep(1 / self.opts.memmon_poll_rate)
@@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread):
 
     def read(self):
         if not self.disabled:
-            free, total = torch.cuda.mem_get_info()
+            free, total = self.cuda_mem_get_info()
            self.data["free"] = free
            self.data["total"] = total