progress.py

import base64
import io
import time

import gradio as gr
from pydantic import BaseModel, Field

from modules.shared import opts
import modules.shared as shared
from collections import OrderedDict
import string
import random
from typing import List
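
# Module-level bookkeeping for the task queue: the id of the currently running
# task, queued task ids mapped to the time they were enqueued, recently
# finished task ids, and the last few recorded results.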
current_task = None
pending_tasks = OrderedDict()
finished_tasks = []
recorded_results = []
recorded_results_limit = 2


def start_task(id_task):
    global current_task

    current_task = id_task
    pending_tasks.pop(id_task, None)


def finish_task(id_task):
    global current_task

    if current_task == id_task:
        current_task = None

    finished_tasks.append(id_task)
    if len(finished_tasks) > 16:
        finished_tasks.pop(0)


def create_task_id(task_type):
    N = 7
    res = ''.join(random.choices(string.ascii_uppercase + string.digits, k=N))
    return f"task({task_type}-{res})"


def record_results(id_task, res):
    recorded_results.append((id_task, res))
    if len(recorded_results) > recorded_results_limit:
        recorded_results.pop(0)


def add_task_to_queue(id_job):
    pending_tasks[id_job] = time.time()
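
# Pydantic models describing the request/response payloads for the
# /internal/pending-tasks and /internal/progress API routes registered below.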
class PendingTasksResponse(BaseModel):
    size: int = Field(title="Pending task size")
    tasks: List[str] = Field(title="Pending task ids")

class ProgressRequest(BaseModel):
    id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
    id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of the last received live preview image")
    live_preview: bool = Field(default=True, title="Include live preview", description="boolean flag indicating whether to include the live preview image")

class ProgressResponse(BaseModel):
    active: bool = Field(title="Whether the task is being worked on right now")
    queued: bool = Field(title="Whether the task is in queue")
    completed: bool = Field(title="Whether the task has already finished")
    progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
    eta: float = Field(default=None, title="ETA in secs")
    live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
    id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
    textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")


def setup_progress_api(app):
    app.add_api_route("/internal/pending-tasks", get_pending_tasks, methods=["GET"])
    return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)


def get_pending_tasks():
    pending_tasks_ids = list(pending_tasks)
    pending_len = len(pending_tasks_ids)
    return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids)


def progressapi(req: ProgressRequest):
    active = req.id_task == current_task
    queued = req.id_task in pending_tasks
    completed = req.id_task in finished_tasks

    if not active:
        textinfo = "Waiting..."
        if queued:
            sorted_queued = sorted(pending_tasks.keys(), key=lambda x: pending_tasks[x])
            queue_index = sorted_queued.index(req.id_task)
            textinfo = "In queue: {}/{}".format(queue_index + 1, len(sorted_queued))

        return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo=textinfo)
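
    # Overall progress: fraction of jobs already finished plus the fraction of
    # the current job, measured by sampling steps completed.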
    progress = 0

    job_count, job_no = shared.state.job_count, shared.state.job_no
    sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step

    if job_count > 0:
        progress += job_no / job_count
    if sampling_steps > 0 and job_count > 0:
        progress += 1 / job_count * sampling_step / sampling_steps

    progress = min(progress, 1)
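
    # ETA by linear extrapolation: predicted total duration is elapsed time
    # divided by the progress fraction; the remainder is what is still left.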
    elapsed_since_start = time.time() - shared.state.time_start
    predicted_duration = elapsed_since_start / progress if progress > 0 else None
    eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None

    live_preview = None
    id_live_preview = req.id_live_preview
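
    # Only re-encode and send a preview image when the server-side preview id
    # differs from the one the client reports having already received.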
    if opts.live_previews_enable and req.live_preview:
        shared.state.set_current_image()
        if shared.state.id_live_preview != req.id_live_preview:
            image = shared.state.current_image
            if image is not None:
                buffered = io.BytesIO()

                if opts.live_previews_image_format == "png":
                    # using optimize for large images takes an enormous amount of time
                    if max(*image.size) <= 256:
                        save_kwargs = {"optimize": True}
                    else:
                        save_kwargs = {"optimize": False, "compress_level": 1}
                else:
                    save_kwargs = {}

                image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
                base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
                live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
                id_live_preview = shared.state.id_live_preview

    return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
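
# Block until the task is no longer queued or running, then return its recorded
# result if it is still in the short-lived recorded_results buffer.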
def restore_progress(id_task):
    while id_task == current_task or id_task in pending_tasks:
        time.sleep(0.1)

    res = next(iter([x[1] for x in recorded_results if id_task == x[0]]), None)
    if res is not None:
        return res

    return gr.update(), gr.update(), gr.update(), f"Couldn't restore progress for {id_task}: results either have been discarded or never were obtained"
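

if __name__ == "__main__":
    # Illustrative sketch only, not part of the original module: a minimal
    # client-side polling loop for the /internal/progress route defined above.
    # It assumes a WebUI instance listening at http://127.0.0.1:7860 and the
    # third-party `requests` package; the task id must come from the client
    # that submitted the job.
    import sys
    import requests

    if len(sys.argv) < 2:
        raise SystemExit("usage: progress.py <id_task>")
    id_task = sys.argv[1]

    while True:
        resp = requests.post(
            "http://127.0.0.1:7860/internal/progress",
            json={"id_task": id_task, "id_live_preview": -1, "live_preview": False},
            timeout=30,
        ).json()
        print(resp.get("textinfo"), resp.get("progress"), resp.get("eta"))
        if resp.get("completed") or not (resp.get("active") or resp.get("queued")):
            break
        time.sleep(1)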