webui.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. import os
  2. import sys
  3. import time
  4. import importlib
  5. import signal
  6. import re
  7. import warnings
  8. import json
  9. from threading import Thread
  10. from fastapi import FastAPI, Response
  11. from fastapi.middleware.cors import CORSMiddleware
  12. from fastapi.middleware.gzip import GZipMiddleware
  13. from packaging import version
  14. import logging
  15. logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
  16. from modules import paths, timer, import_hook, errors # noqa: F401
  17. startup_timer = timer.Timer()
  18. import torch
  19. import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
  20. warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
  21. warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
  22. startup_timer.record("import torch")
  23. import gradio
  24. startup_timer.record("import gradio")
  25. import ldm.modules.encoders.modules # noqa: F401
  26. startup_timer.record("import ldm")
  27. from modules import extra_networks, ui_extra_networks_checkpoints
  28. from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
  29. from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
  30. # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
  31. if ".dev" in torch.__version__ or "+git" in torch.__version__:
  32. torch.__long_version__ = torch.__version__
  33. torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
  34. from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
  35. import modules.codeformer_model as codeformer
  36. import modules.face_restoration
  37. import modules.gfpgan_model as gfpgan
  38. import modules.img2img
  39. import modules.lowvram
  40. import modules.scripts
  41. import modules.sd_hijack
  42. import modules.sd_models
  43. import modules.sd_vae
  44. import modules.txt2img
  45. import modules.script_callbacks
  46. import modules.textual_inversion.textual_inversion
  47. import modules.progress
  48. import modules.ui
  49. from modules import modelloader
  50. from modules.shared import cmd_opts
  51. import modules.hypernetworks.hypernetwork
  52. startup_timer.record("other imports")
  53. if cmd_opts.server_name:
  54. server_name = cmd_opts.server_name
  55. else:
  56. server_name = "0.0.0.0" if cmd_opts.listen else None
  57. def fix_asyncio_event_loop_policy():
  58. """
  59. The default `asyncio` event loop policy only automatically creates
  60. event loops in the main threads. Other threads must create event
  61. loops explicitly or `asyncio.get_event_loop` (and therefore
  62. `.IOLoop.current`) will fail. Installing this policy allows event
  63. loops to be created automatically on any thread, matching the
  64. behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
  65. """
  66. import asyncio
  67. if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
  68. # "Any thread" and "selector" should be orthogonal, but there's not a clean
  69. # interface for composing policies so pick the right base.
  70. _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
  71. else:
  72. _BasePolicy = asyncio.DefaultEventLoopPolicy
  73. class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
  74. """Event loop policy that allows loop creation on any thread.
  75. Usage::
  76. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
  77. """
  78. def get_event_loop(self) -> asyncio.AbstractEventLoop:
  79. try:
  80. return super().get_event_loop()
  81. except (RuntimeError, AssertionError):
  82. # This was an AssertionError in python 3.4.2 (which ships with debian jessie)
  83. # and changed to a RuntimeError in 3.4.3.
  84. # "There is no current event loop in thread %r"
  85. loop = self.new_event_loop()
  86. self.set_event_loop(loop)
  87. return loop
  88. asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
  89. def check_versions():
  90. if shared.cmd_opts.skip_version_check:
  91. return
  92. expected_torch_version = "2.0.0"
  93. if version.parse(torch.__version__) < version.parse(expected_torch_version):
  94. errors.print_error_explanation(f"""
  95. You are running torch {torch.__version__}.
  96. The program is tested to work with torch {expected_torch_version}.
  97. To reinstall the desired version, run with commandline flag --reinstall-torch.
  98. Beware that this will cause a lot of large files to be downloaded, as well as
  99. there are reports of issues with training tab on the latest version.
  100. Use --skip-version-check commandline argument to disable this check.
  101. """.strip())
  102. expected_xformers_version = "0.0.17"
  103. if shared.xformers_available:
  104. import xformers
  105. if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
  106. errors.print_error_explanation(f"""
  107. You are running xformers {xformers.__version__}.
  108. The program is tested to work with xformers {expected_xformers_version}.
  109. To reinstall the desired version, run with commandline flag --reinstall-xformers.
  110. Use --skip-version-check commandline argument to disable this check.
  111. """.strip())
  112. def restore_config_state_file():
  113. config_state_file = shared.opts.restore_config_state_file
  114. if config_state_file == "":
  115. return
  116. shared.opts.restore_config_state_file = ""
  117. shared.opts.save(shared.config_filename)
  118. if os.path.isfile(config_state_file):
  119. print(f"*** About to restore extension state from file: {config_state_file}")
  120. with open(config_state_file, "r", encoding="utf-8") as f:
  121. config_state = json.load(f)
  122. config_states.restore_extension_config(config_state)
  123. startup_timer.record("restore extension config")
  124. elif config_state_file:
  125. print(f"!!! Config state backup not found: {config_state_file}")
  126. def initialize():
  127. fix_asyncio_event_loop_policy()
  128. check_versions()
  129. extensions.list_extensions()
  130. localization.list_localizations(cmd_opts.localizations_dir)
  131. startup_timer.record("list extensions")
  132. restore_config_state_file()
  133. if cmd_opts.ui_debug_mode:
  134. shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
  135. modules.scripts.load_scripts()
  136. return
  137. modelloader.cleanup_models()
  138. modules.sd_models.setup_model()
  139. startup_timer.record("list SD models")
  140. codeformer.setup_model(cmd_opts.codeformer_models_path)
  141. startup_timer.record("setup codeformer")
  142. gfpgan.setup_model(cmd_opts.gfpgan_models_path)
  143. startup_timer.record("setup gfpgan")
  144. modules.scripts.load_scripts()
  145. startup_timer.record("load scripts")
  146. modelloader.load_upscalers()
  147. startup_timer.record("load upscalers")
  148. modules.sd_vae.refresh_vae_list()
  149. startup_timer.record("refresh VAE")
  150. modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
  151. startup_timer.record("refresh textual inversion templates")
  152. # load model in parallel to other startup stuff
  153. Thread(target=lambda: shared.sd_model).start()
  154. shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
  155. shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
  156. shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
  157. shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
  158. shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
  159. startup_timer.record("opts onchange")
  160. shared.reload_hypernetworks()
  161. startup_timer.record("reload hypernets")
  162. ui_extra_networks.intialize()
  163. ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
  164. ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
  165. ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
  166. extra_networks.initialize()
  167. extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
  168. startup_timer.record("extra networks")
  169. if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
  170. try:
  171. if not os.path.exists(cmd_opts.tls_keyfile):
  172. print("Invalid path to TLS keyfile given")
  173. if not os.path.exists(cmd_opts.tls_certfile):
  174. print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
  175. except TypeError:
  176. cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
  177. print("TLS setup invalid, running webui without TLS")
  178. else:
  179. print("Running with TLS")
  180. startup_timer.record("TLS")
  181. # make the program just exit at ctrl+c without waiting for anything
  182. def sigint_handler(sig, frame):
  183. print(f'Interrupted with signal {sig} in {frame}')
  184. os._exit(0)
  185. if not os.environ.get("COVERAGE_RUN"):
  186. # Don't install the immediate-quit handler when running under coverage,
  187. # as then the coverage report won't be generated.
  188. signal.signal(signal.SIGINT, sigint_handler)
  189. def setup_middleware(app):
  190. app.middleware_stack = None # reset current middleware to allow modifying user provided list
  191. app.add_middleware(GZipMiddleware, minimum_size=1000)
  192. if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
  193. app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
  194. elif cmd_opts.cors_allow_origins:
  195. app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
  196. elif cmd_opts.cors_allow_origins_regex:
  197. app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
  198. app.build_middleware_stack() # rebuild middleware stack on-the-fly
  199. def create_api(app):
  200. from modules.api.api import Api
  201. api = Api(app, queue_lock)
  202. return api
  203. def api_only():
  204. initialize()
  205. app = FastAPI()
  206. setup_middleware(app)
  207. api = create_api(app)
  208. modules.script_callbacks.app_started_callback(None, app)
  209. print(f"Startup time: {startup_timer.summary()}.")
  210. api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
  211. def stop_route(request):
  212. shared.state.server_command = "stop"
  213. return Response("Stopping.")
  214. def webui():
  215. launch_api = cmd_opts.api
  216. initialize()
  217. while 1:
  218. if shared.opts.clean_temp_dir_at_start:
  219. ui_tempdir.cleanup_tmpdr()
  220. startup_timer.record("cleanup temp dir")
  221. modules.script_callbacks.before_ui_callback()
  222. startup_timer.record("scripts before_ui_callback")
  223. shared.demo = modules.ui.create_ui()
  224. startup_timer.record("create ui")
  225. if not cmd_opts.no_gradio_queue:
  226. shared.demo.queue(64)
  227. gradio_auth_creds = []
  228. if cmd_opts.gradio_auth:
  229. gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()]
  230. if cmd_opts.gradio_auth_path:
  231. with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
  232. for line in file.readlines():
  233. gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
  234. # this restores the missing /docs endpoint
  235. if launch_api and not hasattr(FastAPI, 'original_setup'):
  236. def fastapi_setup(self):
  237. self.docs_url = "/docs"
  238. self.redoc_url = "/redoc"
  239. self.original_setup()
  240. FastAPI.original_setup = FastAPI.setup
  241. FastAPI.setup = fastapi_setup
  242. app, local_url, share_url = shared.demo.launch(
  243. share=cmd_opts.share,
  244. server_name=server_name,
  245. server_port=cmd_opts.port,
  246. ssl_keyfile=cmd_opts.tls_keyfile,
  247. ssl_certfile=cmd_opts.tls_certfile,
  248. ssl_verify=cmd_opts.disable_tls_verify,
  249. debug=cmd_opts.gradio_debug,
  250. auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
  251. inbrowser=cmd_opts.autolaunch,
  252. prevent_thread_lock=True,
  253. allowed_paths=cmd_opts.gradio_allowed_path,
  254. )
  255. if cmd_opts.add_stop_route:
  256. app.add_route("/_stop", stop_route, methods=["POST"])
  257. # after initial launch, disable --autolaunch for subsequent restarts
  258. cmd_opts.autolaunch = False
  259. startup_timer.record("gradio launch")
  260. # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
  261. # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
  262. # running web ui and do whatever the attacker wants, including installing an extension and
  263. # running its code. We disable this here. Suggested by RyotaK.
  264. app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
  265. setup_middleware(app)
  266. modules.progress.setup_progress_api(app)
  267. modules.ui.setup_ui_api(app)
  268. if launch_api:
  269. create_api(app)
  270. ui_extra_networks.add_pages_to_demo(app)
  271. modules.script_callbacks.app_started_callback(shared.demo, app)
  272. startup_timer.record("scripts app_started_callback")
  273. print(f"Startup time: {startup_timer.summary()}.")
  274. if cmd_opts.subpath:
  275. redirector = FastAPI()
  276. redirector.get("/")
  277. gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
  278. try:
  279. while True:
  280. server_command = shared.state.wait_for_server_command(timeout=5)
  281. if server_command:
  282. if server_command in ("stop", "restart"):
  283. break
  284. else:
  285. print(f"Unknown server command: {server_command}")
  286. except KeyboardInterrupt:
  287. print('Caught KeyboardInterrupt, stopping...')
  288. server_command = "stop"
  289. if server_command == "stop":
  290. print("Stopping server...")
  291. # If we catch a keyboard interrupt, we want to stop the server and exit.
  292. shared.demo.close()
  293. break
  294. print('Restarting UI...')
  295. shared.demo.close()
  296. time.sleep(0.5)
  297. modules.script_callbacks.app_reload_callback()
  298. startup_timer.reset()
  299. sd_samplers.set_samplers()
  300. modules.script_callbacks.script_unloaded_callback()
  301. extensions.list_extensions()
  302. startup_timer.record("list extensions")
  303. restore_config_state_file()
  304. localization.list_localizations(cmd_opts.localizations_dir)
  305. modules.scripts.reload_scripts()
  306. startup_timer.record("load scripts")
  307. modules.script_callbacks.model_loaded_callback(shared.sd_model)
  308. startup_timer.record("model loaded callback")
  309. modelloader.load_upscalers()
  310. startup_timer.record("load upscalers")
  311. for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
  312. importlib.reload(module)
  313. startup_timer.record("reload script modules")
  314. modules.sd_models.list_models()
  315. startup_timer.record("list SD models")
  316. shared.reload_hypernetworks()
  317. startup_timer.record("reload hypernetworks")
  318. ui_extra_networks.intialize()
  319. ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
  320. ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
  321. ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
  322. extra_networks.initialize()
  323. extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
  324. startup_timer.record("initialize extra networks")
  325. if __name__ == "__main__":
  326. if cmd_opts.nowebui:
  327. api_only()
  328. else:
  329. webui()