webui.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
# Startup imports for the web UI launcher. Heavy import groups are followed by a
# startup_timer.record() call so their cost shows up in the "Startup time"
# summary printed once the server is up.
from __future__ import annotations
import os
import sys
import time
import importlib
import signal
import re
import warnings
import json
from threading import Thread
from typing import Iterable
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from packaging import version
import logging

# Suppress xformers log records containing "A matching Triton is not available".
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

from modules import paths, timer, import_hook, errors, devices  # noqa: F401
startup_timer = timer.startup_timer

import torch
import pytorch_lightning  # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
# Silence noisy deprecation/user warnings coming from these libraries.
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
startup_timer.record("import torch")

import gradio
startup_timer.record("import gradio")

import ldm.modules.encoders.modules  # noqa: F401
startup_timer.record("import ldm")

from modules import extra_networks
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock  # noqa: F401
  31. # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
  32. if ".dev" in torch.__version__ or "+git" in torch.__version__:
  33. torch.__long_version__ = torch.__version__
  34. torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
# Remaining project modules; their combined import time is recorded below as "other imports".
from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
import modules.codeformer_model as codeformer
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img
import modules.lowvram
import modules.scripts
import modules.sd_hijack
import modules.sd_hijack_optimizations
import modules.sd_models
import modules.sd_vae
import modules.sd_unet
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
import modules.progress
import modules.ui
from modules import modelloader
from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork
startup_timer.record("other imports")
  56. if cmd_opts.server_name:
  57. server_name = cmd_opts.server_name
  58. else:
  59. server_name = "0.0.0.0" if cmd_opts.listen else None
  60. def fix_asyncio_event_loop_policy():
  61. """
  62. The default `asyncio` event loop policy only automatically creates
  63. event loops in the main threads. Other threads must create event
  64. loops explicitly or `asyncio.get_event_loop` (and therefore
  65. `.IOLoop.current`) will fail. Installing this policy allows event
  66. loops to be created automatically on any thread, matching the
  67. behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
  68. """
  69. import asyncio
  70. if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
  71. # "Any thread" and "selector" should be orthogonal, but there's not a clean
  72. # interface for composing policies so pick the right base.
  73. _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
  74. else:
  75. _BasePolicy = asyncio.DefaultEventLoopPolicy
  76. class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
  77. """Event loop policy that allows loop creation on any thread.
  78. Usage::
  79. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
  80. """
  81. def get_event_loop(self) -> asyncio.AbstractEventLoop:
  82. try:
  83. return super().get_event_loop()
  84. except (RuntimeError, AssertionError):
  85. # This was an AssertionError in python 3.4.2 (which ships with debian jessie)
  86. # and changed to a RuntimeError in 3.4.3.
  87. # "There is no current event loop in thread %r"
  88. loop = self.new_event_loop()
  89. self.set_event_loop(loop)
  90. return loop
  91. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
  92. def check_versions():
  93. if shared.cmd_opts.skip_version_check:
  94. return
  95. expected_torch_version = "2.0.0"
  96. if version.parse(torch.__version__) < version.parse(expected_torch_version):
  97. errors.print_error_explanation(f"""
  98. You are running torch {torch.__version__}.
  99. The program is tested to work with torch {expected_torch_version}.
  100. To reinstall the desired version, run with commandline flag --reinstall-torch.
  101. Beware that this will cause a lot of large files to be downloaded, as well as
  102. there are reports of issues with training tab on the latest version.
  103. Use --skip-version-check commandline argument to disable this check.
  104. """.strip())
  105. expected_xformers_version = "0.0.20"
  106. if shared.xformers_available:
  107. import xformers
  108. if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
  109. errors.print_error_explanation(f"""
  110. You are running xformers {xformers.__version__}.
  111. The program is tested to work with xformers {expected_xformers_version}.
  112. To reinstall the desired version, run with commandline flag --reinstall-xformers.
  113. Use --skip-version-check commandline argument to disable this check.
  114. """.strip())
  115. def restore_config_state_file():
  116. config_state_file = shared.opts.restore_config_state_file
  117. if config_state_file == "":
  118. return
  119. shared.opts.restore_config_state_file = ""
  120. shared.opts.save(shared.config_filename)
  121. if os.path.isfile(config_state_file):
  122. print(f"*** About to restore extension state from file: {config_state_file}")
  123. with open(config_state_file, "r", encoding="utf-8") as f:
  124. config_state = json.load(f)
  125. config_states.restore_extension_config(config_state)
  126. startup_timer.record("restore extension config")
  127. elif config_state_file:
  128. print(f"!!! Config state backup not found: {config_state_file}")
  129. def validate_tls_options():
  130. if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
  131. return
  132. try:
  133. if not os.path.exists(cmd_opts.tls_keyfile):
  134. print("Invalid path to TLS keyfile given")
  135. if not os.path.exists(cmd_opts.tls_certfile):
  136. print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
  137. except TypeError:
  138. cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
  139. print("TLS setup invalid, running webui without TLS")
  140. else:
  141. print("Running with TLS")
  142. startup_timer.record("TLS")
  143. def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
  144. """
  145. Convert the gradio_auth and gradio_auth_path commandline arguments into
  146. an iterable of (username, password) tuples.
  147. """
  148. def process_credential_line(s) -> tuple[str, ...] | None:
  149. s = s.strip()
  150. if not s:
  151. return None
  152. return tuple(s.split(':', 1))
  153. if cmd_opts.gradio_auth:
  154. for cred in cmd_opts.gradio_auth.split(','):
  155. cred = process_credential_line(cred)
  156. if cred:
  157. yield cred
  158. if cmd_opts.gradio_auth_path:
  159. with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
  160. for line in file.readlines():
  161. for cred in line.strip().split(','):
  162. cred = process_credential_line(cred)
  163. if cred:
  164. yield cred
  165. def configure_sigint_handler():
  166. # make the program just exit at ctrl+c without waiting for anything
  167. def sigint_handler(sig, frame):
  168. print(f'Interrupted with signal {sig} in {frame}')
  169. os._exit(0)
  170. if not os.environ.get("COVERAGE_RUN"):
  171. # Don't install the immediate-quit handler when running under coverage,
  172. # as then the coverage report won't be generated.
  173. signal.signal(signal.SIGINT, sigint_handler)
  174. def configure_opts_onchange():
  175. shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
  176. shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
  177. shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
  178. shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
  179. shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
  180. shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
  181. startup_timer.record("opts onchange")
  182. def initialize():
  183. fix_asyncio_event_loop_policy()
  184. validate_tls_options()
  185. configure_sigint_handler()
  186. check_versions()
  187. modelloader.cleanup_models()
  188. configure_opts_onchange()
  189. modules.sd_models.setup_model()
  190. startup_timer.record("setup SD model")
  191. codeformer.setup_model(cmd_opts.codeformer_models_path)
  192. startup_timer.record("setup codeformer")
  193. gfpgan.setup_model(cmd_opts.gfpgan_models_path)
  194. startup_timer.record("setup gfpgan")
  195. initialize_rest(reload_script_modules=False)
  196. def initialize_rest(*, reload_script_modules=False):
  197. """
  198. Called both from initialize() and when reloading the webui.
  199. """
  200. sd_samplers.set_samplers()
  201. extensions.list_extensions()
  202. startup_timer.record("list extensions")
  203. restore_config_state_file()
  204. if cmd_opts.ui_debug_mode:
  205. shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
  206. modules.scripts.load_scripts()
  207. return
  208. modules.sd_models.list_models()
  209. startup_timer.record("list SD models")
  210. localization.list_localizations(cmd_opts.localizations_dir)
  211. with startup_timer.subcategory("load scripts"):
  212. modules.scripts.load_scripts()
  213. if reload_script_modules:
  214. for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
  215. importlib.reload(module)
  216. startup_timer.record("reload script modules")
  217. modelloader.load_upscalers()
  218. startup_timer.record("load upscalers")
  219. modules.sd_vae.refresh_vae_list()
  220. startup_timer.record("refresh VAE")
  221. modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
  222. startup_timer.record("refresh textual inversion templates")
  223. modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers)
  224. modules.sd_hijack.list_optimizers()
  225. startup_timer.record("scripts list_optimizers")
  226. modules.sd_unet.list_unets()
  227. startup_timer.record("scripts list_unets")
  228. def load_model():
  229. """
  230. Accesses shared.sd_model property to load model.
  231. After it's available, if it has been loaded before this access by some extension,
  232. its optimization may be None because the list of optimizaers has neet been filled
  233. by that time, so we apply optimization again.
  234. """
  235. shared.sd_model # noqa: B018
  236. if modules.sd_hijack.current_optimizer is None:
  237. modules.sd_hijack.apply_optimizations()
  238. Thread(target=load_model).start()
  239. Thread(target=devices.first_time_calculation).start()
  240. shared.reload_hypernetworks()
  241. startup_timer.record("reload hypernetworks")
  242. ui_extra_networks.initialize()
  243. ui_extra_networks.register_default_pages()
  244. extra_networks.initialize()
  245. extra_networks.register_default_extra_networks()
  246. startup_timer.record("initialize extra networks")
  247. def setup_middleware(app):
  248. app.middleware_stack = None # reset current middleware to allow modifying user provided list
  249. app.add_middleware(GZipMiddleware, minimum_size=1000)
  250. configure_cors_middleware(app)
  251. app.build_middleware_stack() # rebuild middleware stack on-the-fly
  252. def configure_cors_middleware(app):
  253. cors_options = {
  254. "allow_methods": ["*"],
  255. "allow_headers": ["*"],
  256. "allow_credentials": True,
  257. }
  258. if cmd_opts.cors_allow_origins:
  259. cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
  260. if cmd_opts.cors_allow_origins_regex:
  261. cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
  262. app.add_middleware(CORSMiddleware, **cors_options)
  263. def create_api(app):
  264. from modules.api.api import Api
  265. api = Api(app, queue_lock)
  266. return api
  267. def api_only():
  268. initialize()
  269. app = FastAPI()
  270. setup_middleware(app)
  271. api = create_api(app)
  272. modules.script_callbacks.app_started_callback(None, app)
  273. print(f"Startup time: {startup_timer.summary()}.")
  274. api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
  275. def stop_route(request):
  276. shared.state.server_command = "stop"
  277. return Response("Stopping.")
  278. def webui():
  279. launch_api = cmd_opts.api
  280. initialize()
  281. while 1:
  282. if shared.opts.clean_temp_dir_at_start:
  283. ui_tempdir.cleanup_tmpdr()
  284. startup_timer.record("cleanup temp dir")
  285. modules.script_callbacks.before_ui_callback()
  286. startup_timer.record("scripts before_ui_callback")
  287. shared.demo = modules.ui.create_ui()
  288. startup_timer.record("create ui")
  289. if not cmd_opts.no_gradio_queue:
  290. shared.demo.queue(64)
  291. gradio_auth_creds = list(get_gradio_auth_creds()) or None
  292. app, local_url, share_url = shared.demo.launch(
  293. share=cmd_opts.share,
  294. server_name=server_name,
  295. server_port=cmd_opts.port,
  296. ssl_keyfile=cmd_opts.tls_keyfile,
  297. ssl_certfile=cmd_opts.tls_certfile,
  298. ssl_verify=cmd_opts.disable_tls_verify,
  299. debug=cmd_opts.gradio_debug,
  300. auth=gradio_auth_creds,
  301. inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING ') != '1',
  302. prevent_thread_lock=True,
  303. allowed_paths=cmd_opts.gradio_allowed_path,
  304. app_kwargs={
  305. "docs_url": "/docs",
  306. "redoc_url": "/redoc",
  307. },
  308. )
  309. if cmd_opts.add_stop_route:
  310. app.add_route("/_stop", stop_route, methods=["POST"])
  311. # after initial launch, disable --autolaunch for subsequent restarts
  312. cmd_opts.autolaunch = False
  313. startup_timer.record("gradio launch")
  314. # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
  315. # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
  316. # running web ui and do whatever the attacker wants, including installing an extension and
  317. # running its code. We disable this here. Suggested by RyotaK.
  318. app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
  319. setup_middleware(app)
  320. modules.progress.setup_progress_api(app)
  321. modules.ui.setup_ui_api(app)
  322. if launch_api:
  323. create_api(app)
  324. ui_extra_networks.add_pages_to_demo(app)
  325. startup_timer.record("add APIs")
  326. with startup_timer.subcategory("app_started_callback"):
  327. modules.script_callbacks.app_started_callback(shared.demo, app)
  328. timer.startup_record = startup_timer.dump()
  329. print(f"Startup time: {startup_timer.summary()}.")
  330. if cmd_opts.subpath:
  331. redirector = FastAPI()
  332. redirector.get("/")
  333. gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
  334. try:
  335. while True:
  336. server_command = shared.state.wait_for_server_command(timeout=5)
  337. if server_command:
  338. if server_command in ("stop", "restart"):
  339. break
  340. else:
  341. print(f"Unknown server command: {server_command}")
  342. except KeyboardInterrupt:
  343. print('Caught KeyboardInterrupt, stopping...')
  344. server_command = "stop"
  345. if server_command == "stop":
  346. print("Stopping server...")
  347. # If we catch a keyboard interrupt, we want to stop the server and exit.
  348. shared.demo.close()
  349. break
  350. print('Restarting UI...')
  351. shared.demo.close()
  352. time.sleep(0.5)
  353. startup_timer.reset()
  354. modules.script_callbacks.app_reload_callback()
  355. startup_timer.record("app reload callback")
  356. modules.script_callbacks.script_unloaded_callback()
  357. startup_timer.record("scripts unloaded callback")
  358. initialize_rest(reload_script_modules=True)
  359. if __name__ == "__main__":
  360. if cmd_opts.nowebui:
  361. api_only()
  362. else:
  363. webui()