initialize_util.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. import json
  2. import os
  3. import signal
  4. import sys
  5. import re
  6. from modules.timer import startup_timer
  7. from modules import patches
  8. from functools import wraps
  9. def gradio_server_name():
  10. from modules.shared_cmd_options import cmd_opts
  11. if cmd_opts.server_name:
  12. return cmd_opts.server_name
  13. else:
  14. return "0.0.0.0" if cmd_opts.listen else None
  15. def fix_torch_version():
  16. import torch
  17. # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
  18. if ".dev" in torch.__version__ or "+git" in torch.__version__:
  19. torch.__long_version__ = torch.__version__
  20. torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
  21. def fix_pytorch_lightning():
  22. # Checks if pytorch_lightning.utilities.distributed already exists in the sys.modules cache
  23. if 'pytorch_lightning.utilities.distributed' not in sys.modules:
  24. import pytorch_lightning
  25. # Lets the user know that the library was not found and then will set it to pytorch_lightning.utilities.rank_zero
  26. print("Pytorch_lightning.distributed not found, attempting pytorch_lightning.rank_zero")
  27. sys.modules["pytorch_lightning.utilities.distributed"] = pytorch_lightning.utilities.rank_zero
  28. def fix_asyncio_event_loop_policy():
  29. """
  30. The default `asyncio` event loop policy only automatically creates
  31. event loops in the main threads. Other threads must create event
  32. loops explicitly or `asyncio.get_event_loop` (and therefore
  33. `.IOLoop.current`) will fail. Installing this policy allows event
  34. loops to be created automatically on any thread, matching the
  35. behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
  36. """
  37. import asyncio
  38. if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
  39. # "Any thread" and "selector" should be orthogonal, but there's not a clean
  40. # interface for composing policies so pick the right base.
  41. _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
  42. else:
  43. _BasePolicy = asyncio.DefaultEventLoopPolicy
  44. class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
  45. """Event loop policy that allows loop creation on any thread.
  46. Usage::
  47. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
  48. """
  49. def get_event_loop(self) -> asyncio.AbstractEventLoop:
  50. try:
  51. return super().get_event_loop()
  52. except (RuntimeError, AssertionError):
  53. # This was an AssertionError in python 3.4.2 (which ships with debian jessie)
  54. # and changed to a RuntimeError in 3.4.3.
  55. # "There is no current event loop in thread %r"
  56. loop = self.new_event_loop()
  57. self.set_event_loop(loop)
  58. return loop
  59. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
  60. def restore_config_state_file():
  61. from modules import shared, config_states
  62. config_state_file = shared.opts.restore_config_state_file
  63. if config_state_file == "":
  64. return
  65. shared.opts.restore_config_state_file = ""
  66. shared.opts.save(shared.config_filename)
  67. if os.path.isfile(config_state_file):
  68. print(f"*** About to restore extension state from file: {config_state_file}")
  69. with open(config_state_file, "r", encoding="utf-8") as f:
  70. config_state = json.load(f)
  71. config_states.restore_extension_config(config_state)
  72. startup_timer.record("restore extension config")
  73. elif config_state_file:
  74. print(f"!!! Config state backup not found: {config_state_file}")
  75. def validate_tls_options():
  76. from modules.shared_cmd_options import cmd_opts
  77. if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
  78. return
  79. try:
  80. if not os.path.exists(cmd_opts.tls_keyfile):
  81. print("Invalid path to TLS keyfile given")
  82. if not os.path.exists(cmd_opts.tls_certfile):
  83. print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
  84. except TypeError:
  85. cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
  86. print("TLS setup invalid, running webui without TLS")
  87. else:
  88. print("Running with TLS")
  89. startup_timer.record("TLS")
  90. def get_gradio_auth_creds():
  91. """
  92. Convert the gradio_auth and gradio_auth_path commandline arguments into
  93. an iterable of (username, password) tuples.
  94. """
  95. from modules.shared_cmd_options import cmd_opts
  96. def process_credential_line(s):
  97. s = s.strip()
  98. if not s:
  99. return None
  100. return tuple(s.split(':', 1))
  101. if cmd_opts.gradio_auth:
  102. for cred in cmd_opts.gradio_auth.split(','):
  103. cred = process_credential_line(cred)
  104. if cred:
  105. yield cred
  106. if cmd_opts.gradio_auth_path:
  107. with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
  108. for line in file.readlines():
  109. for cred in line.strip().split(','):
  110. cred = process_credential_line(cred)
  111. if cred:
  112. yield cred
  113. def dumpstacks():
  114. import threading
  115. import traceback
  116. id2name = {th.ident: th.name for th in threading.enumerate()}
  117. code = []
  118. for threadId, stack in sys._current_frames().items():
  119. code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})")
  120. for filename, lineno, name, line in traceback.extract_stack(stack):
  121. code.append(f"""File: "{filename}", line {lineno}, in {name}""")
  122. if line:
  123. code.append(" " + line.strip())
  124. print("\n".join(code))
  125. def configure_sigint_handler():
  126. # make the program just exit at ctrl+c without waiting for anything
  127. from modules import shared
  128. def sigint_handler(sig, frame):
  129. print(f'Interrupted with signal {sig} in {frame}')
  130. if shared.opts.dump_stacks_on_signal:
  131. dumpstacks()
  132. os._exit(0)
  133. if not os.environ.get("COVERAGE_RUN"):
  134. # Don't install the immediate-quit handler when running under coverage,
  135. # as then the coverage report won't be generated.
  136. signal.signal(signal.SIGINT, sigint_handler)
  137. def configure_opts_onchange():
  138. from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
  139. from modules.call_queue import wrap_queued_call
  140. shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
  141. shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
  142. shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
  143. shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
  144. shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
  145. shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
  146. shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
  147. shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
  148. startup_timer.record("opts onchange")
  149. def setup_middleware(app):
  150. from starlette.middleware.gzip import GZipMiddleware
  151. app.add_middleware(GZipMiddleware, minimum_size=1000)
  152. configure_cors_middleware(app)
  153. def configure_cors_middleware(app):
  154. from starlette.middleware.cors import CORSMiddleware
  155. from modules.shared_cmd_options import cmd_opts
  156. cors_options = {
  157. "allow_methods": ["*"],
  158. "allow_headers": ["*"],
  159. "allow_credentials": True,
  160. }
  161. if cmd_opts.cors_allow_origins:
  162. cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
  163. if cmd_opts.cors_allow_origins_regex:
  164. cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
  165. app.add_middleware(CORSMiddleware, **cors_options)
  166. def allow_add_middleware_after_start():
  167. from starlette.applications import Starlette
  168. def add_middleware_wrapper(func):
  169. """Patch Starlette.add_middleware to allow for middleware to be added after the app has started
  170. Starlette.add_middleware raises RuntimeError("Cannot add middleware after an application has started") if middleware_stack is not None.
  171. We can force add new middleware by first setting middleware_stack to None, then adding the middleware.
  172. When middleware_stack is None, it will rebuild the middleware_stack on the next request (Lazily build middleware stack).
  173. If packages are updated in the future, things may break, so we have two ways to add middleware after the app has started:
  174. the first way is to just set middleware_stack to None and then retry
  175. the second manually insert the middleware into the user_middleware list without calling add_middleware
  176. """
  177. @wraps(func)
  178. def wrapper(self, *args, **kwargs):
  179. res = None
  180. try:
  181. res = func(self, *args, **kwargs)
  182. except RuntimeError as _:
  183. try:
  184. self.middleware_stack = None
  185. res = func(self, *args, **kwargs)
  186. except RuntimeError as e:
  187. print(f'Warning: "{e}", Retrying...')
  188. from starlette.middleware import Middleware
  189. self.user_middleware.insert(0, Middleware(*args, **kwargs))
  190. self.middleware_stack = None # ensure middleware_stack in the event of concurrent requests
  191. return res
  192. return wrapper
  193. patches.patch(__name__, obj=Starlette, field="add_middleware", replacement=add_middleware_wrapper(Starlette.add_middleware))