webui.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. from __future__ import annotations
  2. import os
  3. import sys
  4. import time
  5. import importlib
  6. import signal
  7. import re
  8. import warnings
  9. import json
  10. from threading import Thread
  11. from typing import Iterable
  12. from fastapi import FastAPI
  13. from fastapi.middleware.cors import CORSMiddleware
  14. from fastapi.middleware.gzip import GZipMiddleware
  15. import logging
  16. # We can't use cmd_opts for this because it will not have been initialized at this point.
  17. log_level = os.environ.get("SD_WEBUI_LOG_LEVEL")
  18. if log_level:
  19. log_level = getattr(logging, log_level.upper(), None) or logging.INFO
  20. logging.basicConfig(
  21. level=log_level,
  22. format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
  23. datefmt='%Y-%m-%d %H:%M:%S',
  24. )
  25. logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
  26. logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
from modules import timer

# Shared startup timer. Each record() call below labels the startup phase that
# just completed, which is why the heavy imports are interleaved with record()
# calls. Keep this ordering intact.
startup_timer = timer.startup_timer
startup_timer.record("launcher")

import torch
import pytorch_lightning  # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
# Ignore DeprecationWarnings from pytorch_lightning and UserWarnings from torchvision.
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
startup_timer.record("import torch")

import gradio  # noqa: F401
startup_timer.record("import gradio")

from modules import paths, timer, import_hook, errors, devices  # noqa: F401
startup_timer.record("setup paths")

import ldm.modules.encoders.modules  # noqa: F401
startup_timer.record("import ldm")

from modules import extra_networks
# Some of these names are unused in this file (noqa: F401); presumably they are
# imported so other code can pick them up from here — TODO confirm.
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock  # noqa: F401

# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors.
# The regex keeps the leading run of digits and dots that ends in a digit,
# e.g. "2.1.0.dev20230801" -> "2.1.0". The untruncated string is kept in torch.__long_version__.
if ".dev" in torch.__version__ or "+git" in torch.__version__:
    torch.__long_version__ = torch.__version__
    torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)

from modules import shared
if not shared.cmd_opts.skip_version_check:
    errors.check_versions()

import modules.codeformer_model as codeformer
import modules.gfpgan_model as gfpgan
from modules import sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
import modules.face_restoration
import modules.img2img

import modules.lowvram
import modules.scripts
import modules.sd_hijack
import modules.sd_hijack_optimizations
import modules.sd_models
import modules.sd_vae
import modules.sd_unet
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
import modules.progress

import modules.ui
from modules import modelloader
from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork

startup_timer.record("other imports")
  71. if cmd_opts.server_name:
  72. server_name = cmd_opts.server_name
  73. else:
  74. server_name = "0.0.0.0" if cmd_opts.listen else None
  75. def fix_asyncio_event_loop_policy():
  76. """
  77. The default `asyncio` event loop policy only automatically creates
  78. event loops in the main threads. Other threads must create event
  79. loops explicitly or `asyncio.get_event_loop` (and therefore
  80. `.IOLoop.current`) will fail. Installing this policy allows event
  81. loops to be created automatically on any thread, matching the
  82. behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
  83. """
  84. import asyncio
  85. if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
  86. # "Any thread" and "selector" should be orthogonal, but there's not a clean
  87. # interface for composing policies so pick the right base.
  88. _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
  89. else:
  90. _BasePolicy = asyncio.DefaultEventLoopPolicy
  91. class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
  92. """Event loop policy that allows loop creation on any thread.
  93. Usage::
  94. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
  95. """
  96. def get_event_loop(self) -> asyncio.AbstractEventLoop:
  97. try:
  98. return super().get_event_loop()
  99. except (RuntimeError, AssertionError):
  100. # This was an AssertionError in python 3.4.2 (which ships with debian jessie)
  101. # and changed to a RuntimeError in 3.4.3.
  102. # "There is no current event loop in thread %r"
  103. loop = self.new_event_loop()
  104. self.set_event_loop(loop)
  105. return loop
  106. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
  107. def restore_config_state_file():
  108. config_state_file = shared.opts.restore_config_state_file
  109. if config_state_file == "":
  110. return
  111. shared.opts.restore_config_state_file = ""
  112. shared.opts.save(shared.config_filename)
  113. if os.path.isfile(config_state_file):
  114. print(f"*** About to restore extension state from file: {config_state_file}")
  115. with open(config_state_file, "r", encoding="utf-8") as f:
  116. config_state = json.load(f)
  117. config_states.restore_extension_config(config_state)
  118. startup_timer.record("restore extension config")
  119. elif config_state_file:
  120. print(f"!!! Config state backup not found: {config_state_file}")
  121. def validate_tls_options():
  122. if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
  123. return
  124. try:
  125. if not os.path.exists(cmd_opts.tls_keyfile):
  126. print("Invalid path to TLS keyfile given")
  127. if not os.path.exists(cmd_opts.tls_certfile):
  128. print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
  129. except TypeError:
  130. cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
  131. print("TLS setup invalid, running webui without TLS")
  132. else:
  133. print("Running with TLS")
  134. startup_timer.record("TLS")
  135. def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
  136. """
  137. Convert the gradio_auth and gradio_auth_path commandline arguments into
  138. an iterable of (username, password) tuples.
  139. """
  140. def process_credential_line(s) -> tuple[str, ...] | None:
  141. s = s.strip()
  142. if not s:
  143. return None
  144. return tuple(s.split(':', 1))
  145. if cmd_opts.gradio_auth:
  146. for cred in cmd_opts.gradio_auth.split(','):
  147. cred = process_credential_line(cred)
  148. if cred:
  149. yield cred
  150. if cmd_opts.gradio_auth_path:
  151. with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
  152. for line in file.readlines():
  153. for cred in line.strip().split(','):
  154. cred = process_credential_line(cred)
  155. if cred:
  156. yield cred
  157. def configure_sigint_handler():
  158. # make the program just exit at ctrl+c without waiting for anything
  159. def sigint_handler(sig, frame):
  160. print(f'Interrupted with signal {sig} in {frame}')
  161. os._exit(0)
  162. if not os.environ.get("COVERAGE_RUN"):
  163. # Don't install the immediate-quit handler when running under coverage,
  164. # as then the coverage report won't be generated.
  165. signal.signal(signal.SIGINT, sigint_handler)
  166. def configure_opts_onchange():
  167. shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
  168. shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
  169. shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
  170. shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
  171. shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
  172. shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
  173. startup_timer.record("opts onchange")
def initialize():
    """
    Run the one-time part of startup, then the reload-safe part.

    In order, this:
    - installs the asyncio policy;
    - validates the TLS options;
    - installs the Ctrl+C handler;
    - cleans up models via modelloader;
    - registers the settings-change hooks;
    - sets up the SD, CodeFormer and GFPGAN models;
    - finally calls initialize_rest(), which webui() also calls again on UI restarts.
    """
    fix_asyncio_event_loop_policy()
    validate_tls_options()
    configure_sigint_handler()
    modelloader.cleanup_models()
    configure_opts_onchange()

    modules.sd_models.setup_model()
    startup_timer.record("setup SD model")

    codeformer.setup_model(cmd_opts.codeformer_models_path)
    startup_timer.record("setup codeformer")

    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    startup_timer.record("setup gfpgan")

    initialize_rest(reload_script_modules=False)
def initialize_rest(*, reload_script_modules=False):
    """
    Called both from initialize() and when reloading the webui.

    Lists samplers, extensions, models, localizations, scripts, upscalers, VAEs,
    textual inversion templates, optimizers and UNets. It then starts loading
    the SD model in a background thread and initializes the extra networks.
    In ui_debug_mode only samplers, extensions, the Lanczos upscaler and the
    scripts are set up.

    Args:
        reload_script_modules: when True (used on UI restart), every already
            imported module whose name starts with "modules.ui" is reloaded
            with importlib.reload after the scripts are loaded.
    """
    sd_samplers.set_samplers()
    extensions.list_extensions()
    startup_timer.record("list extensions")

    # Runs after list_extensions(); presumably the restore needs the extension list — TODO confirm.
    restore_config_state_file()

    if cmd_opts.ui_debug_mode:
        shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
        modules.scripts.load_scripts()
        return

    modules.sd_models.list_models()
    startup_timer.record("list SD models")

    localization.list_localizations(cmd_opts.localizations_dir)

    with startup_timer.subcategory("load scripts"):
        modules.scripts.load_scripts()

    if reload_script_modules:
        # Snapshot the matching modules into a list first: reloading may import
        # modules and thereby change sys.modules while we iterate.
        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
            importlib.reload(module)
        startup_timer.record("reload script modules")

    modelloader.load_upscalers()
    startup_timer.record("load upscalers")

    modules.sd_vae.refresh_vae_list()
    startup_timer.record("refresh VAE")
    modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
    startup_timer.record("refresh textual inversion templates")

    # Register the built-in optimizers through the on_list_optimizers callback,
    # then have sd_hijack build its optimizer list.
    modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
    modules.sd_hijack.list_optimizers()
    startup_timer.record("scripts list_optimizers")

    modules.sd_unet.list_unets()
    startup_timer.record("scripts list_unets")

    def load_model():
        """
        Accesses shared.sd_model property to load model.
        After it's available, if it has been loaded before this access by some extension,
        its optimization may be None because the list of optimizers had not been filled
        by that time, so we apply optimization again.
        """

        shared.sd_model  # noqa: B018

        if modules.sd_hijack.current_optimizer is None:
            modules.sd_hijack.apply_optimizations()

        devices.first_time_calculation()

    # Load the model in a background thread (not joined here), so the rest of
    # startup continues without waiting for it.
    Thread(target=load_model).start()

    shared.reload_hypernetworks()
    startup_timer.record("reload hypernetworks")

    ui_extra_networks.initialize()
    ui_extra_networks.register_default_pages()

    extra_networks.initialize()
    extra_networks.register_default_extra_networks()
    startup_timer.record("initialize extra networks")
  238. def setup_middleware(app):
  239. app.middleware_stack = None # reset current middleware to allow modifying user provided list
  240. app.add_middleware(GZipMiddleware, minimum_size=1000)
  241. configure_cors_middleware(app)
  242. app.build_middleware_stack() # rebuild middleware stack on-the-fly
  243. def configure_cors_middleware(app):
  244. cors_options = {
  245. "allow_methods": ["*"],
  246. "allow_headers": ["*"],
  247. "allow_credentials": True,
  248. }
  249. if cmd_opts.cors_allow_origins:
  250. cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
  251. if cmd_opts.cors_allow_origins_regex:
  252. cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
  253. app.add_middleware(CORSMiddleware, **cors_options)
  254. def create_api(app):
  255. from modules.api.api import Api
  256. api = Api(app, queue_lock)
  257. return api
  258. def api_only():
  259. initialize()
  260. app = FastAPI()
  261. setup_middleware(app)
  262. api = create_api(app)
  263. modules.script_callbacks.app_started_callback(None, app)
  264. print(f"Startup time: {startup_timer.summary()}.")
  265. api.launch(
  266. server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
  267. port=cmd_opts.port if cmd_opts.port else 7861,
  268. root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""
  269. )
def webui():
    """
    Run the full gradio web UI and keep serving until a stop is requested.

    Calls initialize() once, then loops:
    - build the UI and launch gradio without blocking;
    - attach the webui's own middleware and API routes;
    - fire the app_started callbacks;
    - poll shared.state for a server command.

    "stop" (or a KeyboardInterrupt) closes the demo and leaves the loop.
    "restart" closes the demo, runs the reload/unload callbacks and
    initialize_rest(reload_script_modules=True), then rebuilds the UI on the
    next iteration.
    """
    launch_api = cmd_opts.api
    initialize()

    while 1:
        if shared.opts.clean_temp_dir_at_start:
            ui_tempdir.cleanup_tmpdr()
            startup_timer.record("cleanup temp dir")

        modules.script_callbacks.before_ui_callback()
        startup_timer.record("scripts before_ui_callback")

        shared.demo = modules.ui.create_ui()
        startup_timer.record("create ui")

        if not cmd_opts.no_gradio_queue:
            shared.demo.queue(64)

        # An empty credential list becomes None, i.e. no auth argument for gradio.
        gradio_auth_creds = list(get_gradio_auth_creds()) or None

        app, local_url, share_url = shared.demo.launch(
            share=cmd_opts.share,
            server_name=server_name,
            server_port=cmd_opts.port,
            ssl_keyfile=cmd_opts.tls_keyfile,
            ssl_certfile=cmd_opts.tls_certfile,
            # NOTE(review): the option is passed straight through as ssl_verify. That is only
            # correct if disable_tls_verify is declared with action="store_false" — confirm.
            ssl_verify=cmd_opts.disable_tls_verify,
            debug=cmd_opts.gradio_debug,
            auth=gradio_auth_creds,
            # Don't open the browser when the process was restarted (SD_WEBUI_RESTARTING=1).
            inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING') != '1',
            # Don't block in launch(); the command-polling loop below keeps this thread alive.
            prevent_thread_lock=True,
            allowed_paths=cmd_opts.gradio_allowed_path,
            app_kwargs={
                "docs_url": "/docs",
                "redoc_url": "/redoc",
            },
            root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
        )

        # after initial launch, disable --autolaunch for subsequent restarts
        cmd_opts.autolaunch = False

        startup_timer.record("gradio launch")

        # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
        # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
        # running web ui and do whatever the attacker wants, including installing an extension and
        # running its code. We disable this here. Suggested by RyotaK.
        app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']

        setup_middleware(app)

        modules.progress.setup_progress_api(app)
        modules.ui.setup_ui_api(app)

        if launch_api:
            create_api(app)

        ui_extra_networks.add_pages_to_demo(app)

        startup_timer.record("add APIs")

        with startup_timer.subcategory("app_started_callback"):
            modules.script_callbacks.app_started_callback(shared.demo, app)

        timer.startup_record = startup_timer.dump()
        print(f"Startup time: {startup_timer.summary()}.")

        try:
            # Poll for a server command every 5 seconds; only "stop" and "restart" end the wait.
            while True:
                server_command = shared.state.wait_for_server_command(timeout=5)
                if server_command:
                    if server_command in ("stop", "restart"):
                        break
                    else:
                        print(f"Unknown server command: {server_command}")
        except KeyboardInterrupt:
            # Only reachable when configure_sigint_handler() did not install its
            # immediate-exit handler (i.e. COVERAGE_RUN is set).
            print('Caught KeyboardInterrupt, stopping...')
            server_command = "stop"

        if server_command == "stop":
            print("Stopping server...")
            # If we catch a keyboard interrupt, we want to stop the server and exit.
            shared.demo.close()
            break

        # "restart": tear down the current UI and re-run the reload-safe initialization.
        print('Restarting UI...')
        shared.demo.close()
        time.sleep(0.5)
        startup_timer.reset()
        modules.script_callbacks.app_reload_callback()
        startup_timer.record("app reload callback")
        modules.script_callbacks.script_unloaded_callback()
        startup_timer.record("scripts unloaded callback")
        initialize_rest(reload_script_modules=True)
  346. if __name__ == "__main__":
  347. if cmd_opts.nowebui:
  348. api_only()
  349. else:
  350. webui()