|
@@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread):
|
|
self.data = defaultdict(int)
|
|
self.data = defaultdict(int)
|
|
|
|
|
|
try:
|
|
try:
|
|
- torch.cuda.mem_get_info()
|
|
|
|
|
|
+ self.cuda_mem_get_info()
|
|
torch.cuda.memory_stats(self.device)
|
|
torch.cuda.memory_stats(self.device)
|
|
except Exception as e: # AMD or whatever
|
|
except Exception as e: # AMD or whatever
|
|
print(f"Warning: caught exception '{e}', memory monitor disabled")
|
|
print(f"Warning: caught exception '{e}', memory monitor disabled")
|
|
self.disabled = True
|
|
self.disabled = True
|
|
|
|
|
|
|
|
+ def cuda_mem_get_info(self):
|
|
|
|
+ index = self.device.index if self.device.index is not None else torch.cuda.current_device()
|
|
|
|
+ return torch.cuda.mem_get_info(index)
|
|
|
|
+
|
|
def run(self):
|
|
def run(self):
|
|
if self.disabled:
|
|
if self.disabled:
|
|
return
|
|
return
|
|
@@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread):
|
|
self.run_flag.clear()
|
|
self.run_flag.clear()
|
|
continue
|
|
continue
|
|
|
|
|
|
- self.data["min_free"] = torch.cuda.mem_get_info()[0]
|
|
|
|
|
|
+ self.data["min_free"] = self.cuda_mem_get_info()[0]
|
|
|
|
|
|
while self.run_flag.is_set():
|
|
while self.run_flag.is_set():
|
|
- free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
|
|
|
|
|
|
+ free, total = self.cuda_mem_get_info()
|
|
self.data["min_free"] = min(self.data["min_free"], free)
|
|
self.data["min_free"] = min(self.data["min_free"], free)
|
|
|
|
|
|
time.sleep(1 / self.opts.memmon_poll_rate)
|
|
time.sleep(1 / self.opts.memmon_poll_rate)
|
|
@@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread):
|
|
|
|
|
|
def read(self):
|
|
def read(self):
|
|
if not self.disabled:
|
|
if not self.disabled:
|
|
- free, total = torch.cuda.mem_get_info()
|
|
|
|
|
|
+ free, total = self.cuda_mem_get_info()
|
|
self.data["free"] = free
|
|
self.data["free"] = free
|
|
self.data["total"] = total
|
|
self.data["total"] = total
|
|
|
|
|