webui.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. from __future__ import annotations
  2. import os
  3. import sys
  4. import time
  5. import importlib
  6. import signal
  7. import re
  8. import warnings
  9. import json
  10. from threading import Thread
  11. from typing import Iterable
  12. from fastapi import FastAPI
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from fastapi.middleware.gzip import GZipMiddleware
  15. import logging
  16. # We can't use cmd_opts for this because it will not have been initialized at this point.
  17. log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
  18. if log_level:
  19. log_level = getattr(logging, log_level.upper(), None) or logging.INFO
  20. logging.basicConfig(
  21. level=log_level,
  22. format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
  23. datefmt='%Y-%m-%d %H:%M:%S',
  24. )
  25. logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
  26. logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
# `timer` is imported before the heavy imports below so that each startup phase can be
# timed. The collected summary is printed once the UI/API has started.
from modules import timer
startup_timer = timer.startup_timer
startup_timer.record("launcher")

import torch
import pytorch_lightning  # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
# Silence warnings from these packages (installed after pytorch_lightning's import, see above).
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
startup_timer.record("import torch")

import gradio  # noqa: F401
startup_timer.record("import gradio")

from modules import paths, timer, import_hook, errors  # noqa: F401
startup_timer.record("setup paths")

import ldm.modules.encoders.modules  # noqa: F401
startup_timer.record("import ldm")

# shared_init.initialize() must run before the modules below are imported.
# NOTE(review): presumably those modules read shared state at import time — confirm.
from modules import shared_init, shared, shared_items
shared_init.initialize()
startup_timer.record("initialize shared")

from modules import extra_networks
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock  # noqa: F401

# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
# The untouched version string is kept in torch.__long_version__.
if ".dev" in torch.__version__ or "+git" in torch.__version__:
    torch.__long_version__ = torch.__version__
    torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)

# Dependency version check, skipped when cmd_opts.skip_version_check is set.
if not shared.cmd_opts.skip_version_check:
    errors.check_versions()

import modules.codeformer_model as codeformer
import modules.gfpgan_model as gfpgan
from modules import sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states

import modules.face_restoration
import modules.img2img

import modules.lowvram
import modules.scripts
import modules.sd_hijack
import modules.sd_hijack_optimizations
import modules.sd_models
import modules.sd_vae
import modules.sd_unet
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
import modules.progress

import modules.ui
from modules import modelloader, devices
from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork

startup_timer.record("other imports")
  70. from modules.shared import cmd_opts
  71. import modules.hypernetworks.hypernetwork
  72. startup_timer.record("other imports")
  73. if cmd_opts.server_name:
  74. server_name = cmd_opts.server_name
  75. else:
  76. server_name = "0.0.0.0" if cmd_opts.listen else None
  77. def fix_asyncio_event_loop_policy():
  78. """
  79. The default `asyncio` event loop policy only automatically creates
  80. event loops in the main threads. Other threads must create event
  81. loops explicitly or `asyncio.get_event_loop` (and therefore
  82. `.IOLoop.current`) will fail. Installing this policy allows event
  83. loops to be created automatically on any thread, matching the
  84. behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
  85. """
  86. import asyncio
  87. if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
  88. # "Any thread" and "selector" should be orthogonal, but there's not a clean
  89. # interface for composing policies so pick the right base.
  90. _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
  91. else:
  92. _BasePolicy = asyncio.DefaultEventLoopPolicy
  93. class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
  94. """Event loop policy that allows loop creation on any thread.
  95. Usage::
  96. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
  97. """
  98. def get_event_loop(self) -> asyncio.AbstractEventLoop:
  99. try:
  100. return super().get_event_loop()
  101. except (RuntimeError, AssertionError):
  102. # This was an AssertionError in python 3.4.2 (which ships with debian jessie)
  103. # and changed to a RuntimeError in 3.4.3.
  104. # "There is no current event loop in thread %r"
  105. loop = self.new_event_loop()
  106. self.set_event_loop(loop)
  107. return loop
  108. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
  109. def restore_config_state_file():
  110. config_state_file = shared.opts.restore_config_state_file
  111. if config_state_file == "":
  112. return
  113. shared.opts.restore_config_state_file = ""
  114. shared.opts.save(shared.config_filename)
  115. if os.path.isfile(config_state_file):
  116. print(f"*** About to restore extension state from file: {config_state_file}")
  117. with open(config_state_file, "r", encoding="utf-8") as f:
  118. config_state = json.load(f)
  119. config_states.restore_extension_config(config_state)
  120. startup_timer.record("restore extension config")
  121. elif config_state_file:
  122. print(f"!!! Config state backup not found: {config_state_file}")
  123. def validate_tls_options():
  124. if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
  125. return
  126. try:
  127. if not os.path.exists(cmd_opts.tls_keyfile):
  128. print("Invalid path to TLS keyfile given")
  129. if not os.path.exists(cmd_opts.tls_certfile):
  130. print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
  131. except TypeError:
  132. cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
  133. print("TLS setup invalid, running webui without TLS")
  134. else:
  135. print("Running with TLS")
  136. startup_timer.record("TLS")
  137. def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
  138. """
  139. Convert the gradio_auth and gradio_auth_path commandline arguments into
  140. an iterable of (username, password) tuples.
  141. """
  142. def process_credential_line(s) -> tuple[str, ...] | None:
  143. s = s.strip()
  144. if not s:
  145. return None
  146. return tuple(s.split(':', 1))
  147. if cmd_opts.gradio_auth:
  148. for cred in cmd_opts.gradio_auth.split(','):
  149. cred = process_credential_line(cred)
  150. if cred:
  151. yield cred
  152. if cmd_opts.gradio_auth_path:
  153. with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
  154. for line in file.readlines():
  155. for cred in line.strip().split(','):
  156. cred = process_credential_line(cred)
  157. if cred:
  158. yield cred
  159. def configure_sigint_handler():
  160. # make the program just exit at ctrl+c without waiting for anything
  161. def sigint_handler(sig, frame):
  162. print(f'Interrupted with signal {sig} in {frame}')
  163. os._exit(0)
  164. if not os.environ.get("COVERAGE_RUN"):
  165. # Don't install the immediate-quit handler when running under coverage,
  166. # as then the coverage report won't be generated.
  167. signal.signal(signal.SIGINT, sigint_handler)
  168. def configure_opts_onchange():
  169. shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
  170. shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
  171. shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
  172. shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
  173. shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
  174. shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
  175. startup_timer.record("opts onchange")
def initialize():
    """Run the one-time startup sequence used by both webui() and api_only().

    Steps, in order:
    1. Process-level setup: asyncio policy, TLS option check, Ctrl+C handler.
    2. Model-directory cleanup and registration of settings-change callbacks.
    3. Setup of the SD, CodeFormer and GFPGAN models.
    4. initialize_rest(). That step is also re-run by webui() on UI restart; the
       steps above are not.
    """
    fix_asyncio_event_loop_policy()
    validate_tls_options()
    configure_sigint_handler()
    modelloader.cleanup_models()
    configure_opts_onchange()

    modules.sd_models.setup_model()
    startup_timer.record("setup SD model")

    codeformer.setup_model(cmd_opts.codeformer_models_path)
    startup_timer.record("setup codeformer")

    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    startup_timer.record("setup gfpgan")

    initialize_rest(reload_script_modules=False)
def initialize_rest(*, reload_script_modules=False):
    """
    Called both from initialize() and when reloading the webui.

    Lists samplers, extensions, models and localizations. (Re)loads scripts,
    upscalers, the VAE list, textual-inversion templates, optimizers and UNets.
    Starts loading the SD model in a background thread, then reloads
    hypernetworks and registers extra networks.

    With cmd_opts.ui_debug_mode only the Lanczos upscaler and the scripts are
    loaded, and the function returns early; no models are listed or loaded.

    Args:
        reload_script_modules: when True (UI restart), every already-imported
            module whose name starts with "modules.ui" is re-imported after the
            scripts are loaded.
    """
    sd_samplers.set_samplers()
    extensions.list_extensions()
    startup_timer.record("list extensions")

    restore_config_state_file()

    if cmd_opts.ui_debug_mode:
        shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
        modules.scripts.load_scripts()
        return

    modules.sd_models.list_models()
    startup_timer.record("list SD models")

    localization.list_localizations(cmd_opts.localizations_dir)

    with startup_timer.subcategory("load scripts"):
        modules.scripts.load_scripts()

    if reload_script_modules:
        # Snapshot the matching modules into a list first. Reloading can import new
        # modules, and sys.modules must not change size while it is being iterated.
        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
            importlib.reload(module)
        startup_timer.record("reload script modules")

    modelloader.load_upscalers()
    startup_timer.record("load upscalers")

    modules.sd_vae.refresh_vae_list()
    startup_timer.record("refresh VAE")
    modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
    startup_timer.record("refresh textual inversion templates")

    modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
    modules.sd_hijack.list_optimizers()
    startup_timer.record("scripts list_optimizers")

    modules.sd_unet.list_unets()
    startup_timer.record("scripts list_unets")

    def load_model():
        """
        Accesses shared.sd_model property to load model.
        After it's available, if it has been loaded before this access by some extension,
        its optimization may be None because the list of optimizers had not been filled
        by that time, so we apply optimization again.
        """

        shared.sd_model  # noqa: B018

        if modules.sd_hijack.current_optimizer is None:
            modules.sd_hijack.apply_optimizations()

        devices.first_time_calculation()

    # Started without join(): the rest of startup proceeds while the model loads.
    Thread(target=load_model).start()

    shared_items.reload_hypernetworks()
    startup_timer.record("reload hypernetworks")

    ui_extra_networks.initialize()
    ui_extra_networks.register_default_pages()

    extra_networks.initialize()
    extra_networks.register_default_extra_networks()
    startup_timer.record("initialize extra networks")
def setup_middleware(app):
    """Install the webui's own middleware (GZip and CORS) on a FastAPI/Starlette app.

    Used for the API-only FastAPI app, and for the Gradio app after webui() has
    removed gradio's own CORSMiddleware from `app.user_middleware`.
    """
    app.middleware_stack = None  # reset current middleware to allow modifying user provided list
    app.add_middleware(GZipMiddleware, minimum_size=1000)  # responses smaller than 1000 bytes are not gzipped
    configure_cors_middleware(app)
    # NOTE(review): the value returned by build_middleware_stack() is discarded here; whether this
    # eagerly installs the new stack or it is rebuilt lazily depends on the Starlette version — confirm.
    app.build_middleware_stack()  # rebuild middleware stack on-the-fly
  245. def configure_cors_middleware(app):
  246. cors_options = {
  247. "allow_methods": ["*"],
  248. "allow_headers": ["*"],
  249. "allow_credentials": True,
  250. }
  251. if cmd_opts.cors_allow_origins:
  252. cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
  253. if cmd_opts.cors_allow_origins_regex:
  254. cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
  255. app.add_middleware(CORSMiddleware, **cors_options)
  256. def create_api(app):
  257. from modules.api.api import Api
  258. api = Api(app, queue_lock)
  259. return api
  260. def api_only():
  261. initialize()
  262. app = FastAPI()
  263. setup_middleware(app)
  264. api = create_api(app)
  265. modules.script_callbacks.before_ui_callback()
  266. modules.script_callbacks.app_started_callback(None, app)
  267. print(f"Startup time: {startup_timer.summary()}.")
  268. api.launch(
  269. server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
  270. port=cmd_opts.port if cmd_opts.port else 7861,
  271. root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""
  272. )
def webui():
    """Run the full Gradio web UI until a "stop" server command (or KeyboardInterrupt).

    initialize() runs once. Each pass of the outer loop then builds the UI,
    launches it without blocking, adds the webui's middleware and APIs, and polls
    for server commands. On "restart" the demo is closed, reload callbacks run,
    initialize_rest(reload_script_modules=True) is called and the loop builds the
    UI again. On "stop" the demo is closed and the function returns.
    """
    launch_api = cmd_opts.api
    initialize()

    while 1:
        if shared.opts.clean_temp_dir_at_start:
            ui_tempdir.cleanup_tmpdr()
            startup_timer.record("cleanup temp dir")

        modules.script_callbacks.before_ui_callback()
        startup_timer.record("scripts before_ui_callback")

        shared.demo = modules.ui.create_ui()
        startup_timer.record("create ui")

        if not cmd_opts.no_gradio_queue:
            shared.demo.queue(64)

        # An empty credential list becomes None (no auth).
        gradio_auth_creds = list(get_gradio_auth_creds()) or None

        # Only consider opening a browser on the first launch: SD_WEBUI_RESTARTING is
        # set to '1' (via setdefault below) before the UI is restarted.
        auto_launch_browser = False
        if os.getenv('SD_WEBUI_RESTARTING') != '1':
            if shared.opts.auto_launch_browser == "Remote" or cmd_opts.autolaunch:
                auto_launch_browser = True
            elif shared.opts.auto_launch_browser == "Local":
                # "Local" opens the browser only when the server is not exposed via --listen/--share/--ngrok.
                auto_launch_browser = not any([cmd_opts.listen, cmd_opts.share, cmd_opts.ngrok])

        # prevent_thread_lock=True makes launch() return; this function then blocks in the command loop below.
        # NOTE(review): disable_tls_verify is passed straight through as ssl_verify, which is only correct if
        # that option is declared with action="store_false" — confirm in the command-line argument definitions.
        app, local_url, share_url = shared.demo.launch(
            share=cmd_opts.share,
            server_name=server_name,
            server_port=cmd_opts.port,
            ssl_keyfile=cmd_opts.tls_keyfile,
            ssl_certfile=cmd_opts.tls_certfile,
            ssl_verify=cmd_opts.disable_tls_verify,
            debug=cmd_opts.gradio_debug,
            auth=gradio_auth_creds,
            inbrowser=auto_launch_browser,
            prevent_thread_lock=True,
            allowed_paths=cmd_opts.gradio_allowed_path,
            app_kwargs={
                "docs_url": "/docs",
                "redoc_url": "/redoc",
            },
            root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
        )

        startup_timer.record("gradio launch")

        # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
        # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
        # running web ui and do whatever the attacker wants, including installing an extension and
        # running its code. We disable this here. Suggested by RyotaK.
        app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']

        setup_middleware(app)

        modules.progress.setup_progress_api(app)
        modules.ui.setup_ui_api(app)

        if launch_api:
            create_api(app)

        ui_extra_networks.add_pages_to_demo(app)

        startup_timer.record("add APIs")

        with startup_timer.subcategory("app_started_callback"):
            modules.script_callbacks.app_started_callback(shared.demo, app)

        # Keep this launch's timing data on the timer module before printing the summary.
        timer.startup_record = startup_timer.dump()
        print(f"Startup time: {startup_timer.summary()}.")

        # Poll for server commands (5 s timeout per wait) until "stop" or "restart".
        try:
            while True:
                server_command = shared.state.wait_for_server_command(timeout=5)
                if server_command:
                    if server_command in ("stop", "restart"):
                        break
                    else:
                        print(f"Unknown server command: {server_command}")
        except KeyboardInterrupt:
            # Only reachable when configure_sigint_handler() did not install its
            # immediate-exit handler (i.e. COVERAGE_RUN is set).
            print('Caught KeyboardInterrupt, stopping...')
            server_command = "stop"

        if server_command == "stop":
            print("Stopping server...")
            # If we catch a keyboard interrupt, we want to stop the server and exit.
            shared.demo.close()
            break

        # disable auto launch webui in browser for subsequent UI Reload
        os.environ.setdefault('SD_WEBUI_RESTARTING', '1')

        print('Restarting UI...')
        shared.demo.close()
        time.sleep(0.5)
        startup_timer.reset()
        modules.script_callbacks.app_reload_callback()
        startup_timer.record("app reload callback")
        modules.script_callbacks.script_unloaded_callback()
        startup_timer.record("scripts unloaded callback")
        initialize_rest(reload_script_modules=True)
  355. if __name__ == "__main__":
  356. if cmd_opts.nowebui:
  357. api_only()
  358. else:
  359. webui()