initialize_util.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. import json
  2. import logging
  3. import os
  4. import signal
  5. import sys
  6. import re
  7. from modules.timer import startup_timer
  8. def setup_logging():
  9. # We can't use cmd_opts for this because it will not have been initialized at this point.
  10. log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
  11. if log_level:
  12. log_level = getattr(logging, log_level.upper(), None) or logging.INFO
  13. logging.basicConfig(
  14. level=log_level,
  15. format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
  16. datefmt='%Y-%m-%d %H:%M:%S',
  17. )
  18. def gradio_server_name():
  19. from modules.shared_cmd_options import cmd_opts
  20. if cmd_opts.server_name:
  21. return cmd_opts.server_name
  22. else:
  23. return "0.0.0.0" if cmd_opts.listen else None
  24. def fix_torch_version():
  25. import torch
  26. # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
  27. if ".dev" in torch.__version__ or "+git" in torch.__version__:
  28. torch.__long_version__ = torch.__version__
  29. torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
  30. def fix_asyncio_event_loop_policy():
  31. """
  32. The default `asyncio` event loop policy only automatically creates
  33. event loops in the main threads. Other threads must create event
  34. loops explicitly or `asyncio.get_event_loop` (and therefore
  35. `.IOLoop.current`) will fail. Installing this policy allows event
  36. loops to be created automatically on any thread, matching the
  37. behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
  38. """
  39. import asyncio
  40. if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
  41. # "Any thread" and "selector" should be orthogonal, but there's not a clean
  42. # interface for composing policies so pick the right base.
  43. _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
  44. else:
  45. _BasePolicy = asyncio.DefaultEventLoopPolicy
  46. class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
  47. """Event loop policy that allows loop creation on any thread.
  48. Usage::
  49. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
  50. """
  51. def get_event_loop(self) -> asyncio.AbstractEventLoop:
  52. try:
  53. return super().get_event_loop()
  54. except (RuntimeError, AssertionError):
  55. # This was an AssertionError in python 3.4.2 (which ships with debian jessie)
  56. # and changed to a RuntimeError in 3.4.3.
  57. # "There is no current event loop in thread %r"
  58. loop = self.new_event_loop()
  59. self.set_event_loop(loop)
  60. return loop
  61. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
  62. def restore_config_state_file():
  63. from modules import shared, config_states
  64. config_state_file = shared.opts.restore_config_state_file
  65. if config_state_file == "":
  66. return
  67. shared.opts.restore_config_state_file = ""
  68. shared.opts.save(shared.config_filename)
  69. if os.path.isfile(config_state_file):
  70. print(f"*** About to restore extension state from file: {config_state_file}")
  71. with open(config_state_file, "r", encoding="utf-8") as f:
  72. config_state = json.load(f)
  73. config_states.restore_extension_config(config_state)
  74. startup_timer.record("restore extension config")
  75. elif config_state_file:
  76. print(f"!!! Config state backup not found: {config_state_file}")
  77. def validate_tls_options():
  78. from modules.shared_cmd_options import cmd_opts
  79. if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
  80. return
  81. try:
  82. if not os.path.exists(cmd_opts.tls_keyfile):
  83. print("Invalid path to TLS keyfile given")
  84. if not os.path.exists(cmd_opts.tls_certfile):
  85. print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
  86. except TypeError:
  87. cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
  88. print("TLS setup invalid, running webui without TLS")
  89. else:
  90. print("Running with TLS")
  91. startup_timer.record("TLS")
  92. def get_gradio_auth_creds():
  93. """
  94. Convert the gradio_auth and gradio_auth_path commandline arguments into
  95. an iterable of (username, password) tuples.
  96. """
  97. from modules.shared_cmd_options import cmd_opts
  98. def process_credential_line(s):
  99. s = s.strip()
  100. if not s:
  101. return None
  102. return tuple(s.split(':', 1))
  103. if cmd_opts.gradio_auth:
  104. for cred in cmd_opts.gradio_auth.split(','):
  105. cred = process_credential_line(cred)
  106. if cred:
  107. yield cred
  108. if cmd_opts.gradio_auth_path:
  109. with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
  110. for line in file.readlines():
  111. for cred in line.strip().split(','):
  112. cred = process_credential_line(cred)
  113. if cred:
  114. yield cred
  115. def configure_sigint_handler():
  116. # make the program just exit at ctrl+c without waiting for anything
  117. def sigint_handler(sig, frame):
  118. print(f'Interrupted with signal {sig} in {frame}')
  119. os._exit(0)
  120. if not os.environ.get("COVERAGE_RUN"):
  121. # Don't install the immediate-quit handler when running under coverage,
  122. # as then the coverage report won't be generated.
  123. signal.signal(signal.SIGINT, sigint_handler)
  124. def configure_opts_onchange():
  125. from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
  126. from modules.call_queue import wrap_queued_call
  127. shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
  128. shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
  129. shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
  130. shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
  131. shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
  132. shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
  133. startup_timer.record("opts onchange")
  134. def setup_middleware(app):
  135. from starlette.middleware.gzip import GZipMiddleware
  136. app.middleware_stack = None # reset current middleware to allow modifying user provided list
  137. app.add_middleware(GZipMiddleware, minimum_size=1000)
  138. configure_cors_middleware(app)
  139. app.build_middleware_stack() # rebuild middleware stack on-the-fly
  140. def configure_cors_middleware(app):
  141. from starlette.middleware.cors import CORSMiddleware
  142. from modules.shared_cmd_options import cmd_opts
  143. cors_options = {
  144. "allow_methods": ["*"],
  145. "allow_headers": ["*"],
  146. "allow_credentials": True,
  147. }
  148. if cmd_opts.cors_allow_origins:
  149. cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
  150. if cmd_opts.cors_allow_origins_regex:
  151. cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
  152. app.add_middleware(CORSMiddleware, **cors_options)