shared_state.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import datetime
  2. import logging
  3. import threading
  4. import time
  5. from modules import errors, shared, devices
  6. from typing import Optional
  7. log = logging.getLogger(__name__)
  8. class State:
  9. skipped = False
  10. interrupted = False
  11. stopping_generation = False
  12. job = ""
  13. job_no = 0
  14. job_count = 0
  15. processing_has_refined_job_count = False
  16. job_timestamp = '0'
  17. sampling_step = 0
  18. sampling_steps = 0
  19. current_latent = None
  20. current_image = None
  21. current_image_sampling_step = 0
  22. id_live_preview = 0
  23. textinfo = None
  24. time_start = None
  25. server_start = None
  26. _server_command_signal = threading.Event()
  27. _server_command: Optional[str] = None
  28. def __init__(self):
  29. self.server_start = time.time()
  30. @property
  31. def need_restart(self) -> bool:
  32. # Compatibility getter for need_restart.
  33. return self.server_command == "restart"
  34. @need_restart.setter
  35. def need_restart(self, value: bool) -> None:
  36. # Compatibility setter for need_restart.
  37. if value:
  38. self.server_command = "restart"
  39. @property
  40. def server_command(self):
  41. return self._server_command
  42. @server_command.setter
  43. def server_command(self, value: Optional[str]) -> None:
  44. """
  45. Set the server command to `value` and signal that it's been set.
  46. """
  47. self._server_command = value
  48. self._server_command_signal.set()
  49. def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
  50. """
  51. Wait for server command to get set; return and clear the value and signal.
  52. """
  53. if self._server_command_signal.wait(timeout):
  54. self._server_command_signal.clear()
  55. req = self._server_command
  56. self._server_command = None
  57. return req
  58. return None
  59. def request_restart(self) -> None:
  60. self.interrupt()
  61. self.server_command = "restart"
  62. log.info("Received restart request")
  63. def skip(self):
  64. self.skipped = True
  65. log.info("Received skip request")
  66. def interrupt(self):
  67. self.interrupted = True
  68. log.info("Received interrupt request")
  69. def stop_generating(self):
  70. self.stopping_generation = True
  71. log.info("Received stop generating request")
  72. def nextjob(self):
  73. if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
  74. self.do_set_current_image()
  75. self.job_no += 1
  76. self.sampling_step = 0
  77. self.current_image_sampling_step = 0
  78. def dict(self):
  79. obj = {
  80. "skipped": self.skipped,
  81. "interrupted": self.interrupted,
  82. "stopping_generation": self.stopping_generation,
  83. "job": self.job,
  84. "job_count": self.job_count,
  85. "job_timestamp": self.job_timestamp,
  86. "job_no": self.job_no,
  87. "sampling_step": self.sampling_step,
  88. "sampling_steps": self.sampling_steps,
  89. }
  90. return obj
  91. def begin(self, job: str = "(unknown)"):
  92. self.sampling_step = 0
  93. self.time_start = time.time()
  94. self.job_count = -1
  95. self.processing_has_refined_job_count = False
  96. self.job_no = 0
  97. self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
  98. self.current_latent = None
  99. self.current_image = None
  100. self.current_image_sampling_step = 0
  101. self.id_live_preview = 0
  102. self.skipped = False
  103. self.interrupted = False
  104. self.stopping_generation = False
  105. self.textinfo = None
  106. self.job = job
  107. devices.torch_gc()
  108. log.info("Starting job %s", job)
  109. def end(self):
  110. duration = time.time() - self.time_start
  111. log.info("Ending job %s (%.2f seconds)", self.job, duration)
  112. self.job = ""
  113. self.job_count = 0
  114. devices.torch_gc()
  115. def set_current_image(self):
  116. """if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
  117. if not shared.parallel_processing_allowed:
  118. return
  119. if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
  120. self.do_set_current_image()
  121. def do_set_current_image(self):
  122. if self.current_latent is None:
  123. return
  124. import modules.sd_samplers
  125. try:
  126. if shared.opts.show_progress_grid:
  127. self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
  128. else:
  129. self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
  130. self.current_image_sampling_step = self.sampling_step
  131. except Exception:
  132. # when switching models during genration, VAE would be on CPU, so creating an image will fail.
  133. # we silently ignore this error
  134. errors.record_exception()
  135. def assign_current_image(self, image):
  136. self.current_image = image
  137. self.id_live_preview += 1