shared_state.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. import datetime
  2. import logging
  3. import threading
  4. import time
  5. from modules import errors, shared, devices
  6. from typing import Optional
  7. log = logging.getLogger(__name__)
  8. class State:
  9. skipped = False
  10. interrupted = False
  11. job = ""
  12. job_no = 0
  13. job_count = 0
  14. processing_has_refined_job_count = False
  15. job_timestamp = '0'
  16. sampling_step = 0
  17. sampling_steps = 0
  18. current_latent = None
  19. current_image = None
  20. current_image_sampling_step = 0
  21. id_live_preview = 0
  22. textinfo = None
  23. time_start = None
  24. server_start = None
  25. _server_command_signal = threading.Event()
  26. _server_command: Optional[str] = None
  27. def __init__(self):
  28. self.server_start = time.time()
  29. @property
  30. def need_restart(self) -> bool:
  31. # Compatibility getter for need_restart.
  32. return self.server_command == "restart"
  33. @need_restart.setter
  34. def need_restart(self, value: bool) -> None:
  35. # Compatibility setter for need_restart.
  36. if value:
  37. self.server_command = "restart"
  38. @property
  39. def server_command(self):
  40. return self._server_command
  41. @server_command.setter
  42. def server_command(self, value: Optional[str]) -> None:
  43. """
  44. Set the server command to `value` and signal that it's been set.
  45. """
  46. self._server_command = value
  47. self._server_command_signal.set()
  48. def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
  49. """
  50. Wait for server command to get set; return and clear the value and signal.
  51. """
  52. if self._server_command_signal.wait(timeout):
  53. self._server_command_signal.clear()
  54. req = self._server_command
  55. self._server_command = None
  56. return req
  57. return None
  58. def request_restart(self) -> None:
  59. self.interrupt()
  60. self.server_command = "restart"
  61. log.info("Received restart request")
  62. def skip(self):
  63. self.skipped = True
  64. log.info("Received skip request")
  65. def interrupt(self):
  66. self.interrupted = True
  67. log.info("Received interrupt request")
  68. def nextjob(self):
  69. if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
  70. self.do_set_current_image()
  71. self.job_no += 1
  72. self.sampling_step = 0
  73. self.current_image_sampling_step = 0
  74. def dict(self):
  75. obj = {
  76. "skipped": self.skipped,
  77. "interrupted": self.interrupted,
  78. "job": self.job,
  79. "job_count": self.job_count,
  80. "job_timestamp": self.job_timestamp,
  81. "job_no": self.job_no,
  82. "sampling_step": self.sampling_step,
  83. "sampling_steps": self.sampling_steps,
  84. }
  85. return obj
  86. def begin(self, job: str = "(unknown)"):
  87. self.sampling_step = 0
  88. self.time_start = time.time()
  89. self.job_count = -1
  90. self.processing_has_refined_job_count = False
  91. self.job_no = 0
  92. self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
  93. self.current_latent = None
  94. self.current_image = None
  95. self.current_image_sampling_step = 0
  96. self.id_live_preview = 0
  97. self.skipped = False
  98. self.interrupted = False
  99. self.textinfo = None
  100. self.job = job
  101. devices.torch_gc()
  102. log.info("Starting job %s", job)
  103. def end(self):
  104. duration = time.time() - self.time_start
  105. log.info("Ending job %s (%.2f seconds)", self.job, duration)
  106. self.job = ""
  107. self.job_count = 0
  108. devices.torch_gc()
  109. def set_current_image(self):
  110. """if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
  111. if not shared.parallel_processing_allowed:
  112. return
  113. if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
  114. self.do_set_current_image()
  115. def do_set_current_image(self):
  116. if self.current_latent is None:
  117. return
  118. import modules.sd_samplers
  119. try:
  120. if shared.opts.show_progress_grid:
  121. self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
  122. else:
  123. self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
  124. self.current_image_sampling_step = self.sampling_step
  125. except Exception:
  126. # when switching models during genration, VAE would be on CPU, so creating an image will fail.
  127. # we silently ignore this error
  128. errors.record_exception()
  129. def assign_current_image(self, image):
  130. self.current_image = image
  131. self.id_live_preview += 1