# profiling.py (1.3 KB)
  1. import torch
  2. from modules import shared, ui_gradio_extensions
  3. class Profiler:
  4. def __init__(self):
  5. if not shared.opts.profiling_enable:
  6. self.profiler = None
  7. return
  8. activities = []
  9. if "CPU" in shared.opts.profiling_activities:
  10. activities.append(torch.profiler.ProfilerActivity.CPU)
  11. if "CUDA" in shared.opts.profiling_activities:
  12. activities.append(torch.profiler.ProfilerActivity.CUDA)
  13. if not activities:
  14. self.profiler = None
  15. return
  16. self.profiler = torch.profiler.profile(
  17. activities=activities,
  18. record_shapes=shared.opts.profiling_record_shapes,
  19. profile_memory=shared.opts.profiling_profile_memory,
  20. with_stack=shared.opts.profiling_with_stack
  21. )
  22. def __enter__(self):
  23. if self.profiler:
  24. self.profiler.__enter__()
  25. return self
  26. def __exit__(self, exc_type, exc, exc_tb):
  27. if self.profiler:
  28. shared.state.textinfo = "Finishing profile..."
  29. self.profiler.__exit__(exc_type, exc, exc_tb)
  30. self.profiler.export_chrome_trace(shared.opts.profiling_filename)
  31. def webpath():
  32. return ui_gradio_extensions.webpath(shared.opts.profiling_filename)