123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- import datetime
- import logging
- import threading
- import time
- from modules import errors, shared, devices
- from typing import Optional
- log = logging.getLogger(__name__)
class State:
    """Mutable record of the current job's progress plus a server-command mailbox.

    The class-level values below are the defaults; ``begin()`` resets most of
    them on the instance at the start of every job.
    """

    # Flags raised by skip() / interrupt() and cleared again by begin().
    skipped = False
    interrupted = False

    # Name of the running job ("" once end() has run) and job counters.
    job = ""
    job_no = 0
    job_count = 0
    processing_has_refined_job_count = False
    job_timestamp = '0'

    # Sampling progress counters for the current job.
    sampling_step = 0
    sampling_steps = 0

    # Live preview state: latest latent, the image produced from it, the
    # sampling step at which that image was produced, and a counter that
    # assign_current_image() bumps for every new image.
    current_latent = None
    current_image = None
    current_image_sampling_step = 0
    id_live_preview = 0

    textinfo = None
    time_start = None
    server_start = None

    # NOTE: the Event is created once at class level, so every State instance
    # shares the same signal object.
    _server_command_signal = threading.Event()
    _server_command: Optional[str] = None

    def __init__(self):
        """Record the wall-clock time at which this instance was created."""
        self.server_start = time.time()

    @property
    def need_restart(self) -> bool:
        """Compatibility getter: True when the pending server command is "restart"."""
        return self.server_command == "restart"

    @need_restart.setter
    def need_restart(self, value: bool) -> None:
        """Compatibility setter: a truthy value requests a restart.

        Falsy values are ignored; they do not clear a pending command.
        """
        if value:
            self.server_command = "restart"

    @property
    def server_command(self):
        """Return the pending server command, or None when there is none."""
        return self._server_command

    @server_command.setter
    def server_command(self, value: Optional[str]) -> None:
        """
        Set the server command to `value` and signal that it's been set.
        """
        self._server_command = value
        self._server_command_signal.set()

    def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
        """
        Wait for server command to get set; return and clear the value and signal.

        Returns None when `timeout` elapses before any command is signalled.
        """
        if not self._server_command_signal.wait(timeout):
            return None

        self._server_command_signal.clear()
        command = self._server_command
        # Written to the backing field directly so the signal is not set again.
        self._server_command = None
        return command

    def request_restart(self) -> None:
        """Interrupt the current job and post a "restart" server command."""
        self.interrupt()
        self.server_command = "restart"
        log.info("Received restart request")

    def skip(self):
        """Raise the `skipped` flag."""
        self.skipped = True
        log.info("Received skip request")

    def interrupt(self):
        """Raise the `interrupted` flag."""
        self.interrupted = True
        log.info("Received interrupt request")

    def nextjob(self):
        """Advance to the next job and reset the per-job sampling counters.

        When live previews are enabled and `show_progress_every_n_steps` is -1,
        the preview image is refreshed first.
        """
        if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
            self.do_set_current_image()

        self.job_no += 1
        self.sampling_step = 0
        self.current_image_sampling_step = 0

    def dict(self):
        """Return a plain-dict snapshot of the flags, job info and sampling counters."""
        return {
            "skipped": self.skipped,
            "interrupted": self.interrupted,
            "job": self.job,
            "job_count": self.job_count,
            "job_timestamp": self.job_timestamp,
            "job_no": self.job_no,
            "sampling_step": self.sampling_step,
            "sampling_steps": self.sampling_steps,
        }

    def begin(self, job: str = "(unknown)"):
        """Reset all per-job state, stamp the start time and mark `job` as running."""
        self.sampling_step = 0
        self.time_start = time.time()
        self.job_count = -1
        self.processing_has_refined_job_count = False
        self.job_no = 0
        self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        self.current_latent = None
        self.current_image = None
        self.current_image_sampling_step = 0
        self.id_live_preview = 0
        self.skipped = False
        self.interrupted = False
        self.textinfo = None
        self.job = job
        devices.torch_gc()
        log.info("Starting job %s", job)

    def end(self):
        """Log how long the job ran, clear the job name/count and call `devices.torch_gc()`.

        `time_start` stays None until `begin()` has run; in that case the
        duration is reported as 0.0 instead of failing with a TypeError on
        `float - None`.
        """
        started = self.time_start
        duration = 0.0 if started is None else time.time() - started
        log.info("Ending job %s (%.2f seconds)", self.job, duration)
        self.job = ""
        self.job_count = 0
        devices.torch_gc()

    def set_current_image(self):
        """if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly"""
        if not shared.parallel_processing_allowed:
            return

        if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
            self.do_set_current_image()

    def do_set_current_image(self):
        """Convert `current_latent` into a preview image and store it.

        Does nothing when there is no latent. On success the sampling step at
        which the preview was made is remembered; on failure the exception is
        handed to `errors.record_exception()` and not re-raised.
        """
        if self.current_latent is None:
            return

        # Imported here rather than at module level (presumably to avoid an
        # import cycle - confirm before moving it).
        import modules.sd_samplers

        try:
            if shared.opts.show_progress_grid:
                self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
            else:
                self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))

            self.current_image_sampling_step = self.sampling_step

        except Exception:
            # when switching models during generation, VAE would be on CPU, so creating an image will fail.
            # the error is not propagated; it is only passed to errors.record_exception()
            errors.record_exception()

    def assign_current_image(self, image):
        """Store `image` as the current preview and bump the live-preview id."""
        self.current_image = image
        self.id_live_preview += 1
|