launch.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
# This script installs the necessary requirements and launches the main program in webui.py
  2. import subprocess
  3. import os
  4. import sys
  5. import importlib.util
  6. import shlex
  7. import platform
  8. import argparse
  9. import json
  10. dir_repos = "repositories"
  11. dir_extensions = "extensions"
  12. python = sys.executable
  13. git = os.environ.get('GIT', "git")
  14. index_url = os.environ.get('INDEX_URL', "")
  15. def extract_arg(args, name):
  16. return [x for x in args if x != name], name in args
  17. def extract_opt(args, name):
  18. opt = None
  19. is_present = False
  20. if name in args:
  21. is_present = True
  22. idx = args.index(name)
  23. del args[idx]
  24. if idx < len(args) and args[idx][0] != "-":
  25. opt = args[idx]
  26. del args[idx]
  27. return args, is_present, opt
  28. def run(command, desc=None, errdesc=None, custom_env=None):
  29. if desc is not None:
  30. print(desc)
  31. result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
  32. if result.returncode != 0:
  33. message = f"""{errdesc or 'Error running command'}.
  34. Command: {command}
  35. Error code: {result.returncode}
  36. stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
  37. stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
  38. """
  39. raise RuntimeError(message)
  40. return result.stdout.decode(encoding="utf8", errors="ignore")
  41. def check_run(command):
  42. result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
  43. return result.returncode == 0
  44. def is_installed(package):
  45. try:
  46. spec = importlib.util.find_spec(package)
  47. except ModuleNotFoundError:
  48. return False
  49. return spec is not None
  50. def repo_dir(name):
  51. return os.path.join(dir_repos, name)
  52. def run_python(code, desc=None, errdesc=None):
  53. return run(f'"{python}" -c "{code}"', desc, errdesc)
  54. def run_pip(args, desc=None):
  55. index_url_line = f' --index-url {index_url}' if index_url != '' else ''
  56. return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
  57. def check_run_python(code):
  58. return check_run(f'"{python}" -c "{code}"')
  59. def git_clone(url, dir, name, commithash=None):
  60. # TODO clone into temporary dir and move if successful
  61. if os.path.exists(dir):
  62. if commithash is None:
  63. return
  64. current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
  65. if current_hash == commithash:
  66. return
  67. run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
  68. run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
  69. return
  70. run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
  71. if commithash is not None:
  72. run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
  73. def version_check(commit):
  74. try:
  75. import requests
  76. commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()
  77. if commit != "<none>" and commits['commit']['sha'] != commit:
  78. print("--------------------------------------------------------")
  79. print("| You are not up to date with the most recent release. |")
  80. print("| Consider running `git pull` to update. |")
  81. print("--------------------------------------------------------")
  82. elif commits['commit']['sha'] == commit:
  83. print("You are up to date with the most recent release.")
  84. else:
  85. print("Not a git clone, can't perform version check.")
  86. except Exception as e:
  87. print("version check failed", e)
  88. def run_extension_installer(extension_dir):
  89. path_installer = os.path.join(extension_dir, "install.py")
  90. if not os.path.isfile(path_installer):
  91. return
  92. try:
  93. env = os.environ.copy()
  94. env['PYTHONPATH'] = os.path.abspath(".")
  95. print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
  96. except Exception as e:
  97. print(e, file=sys.stderr)
  98. def list_extensions(settings_file):
  99. settings = {}
  100. try:
  101. if os.path.isfile(settings_file):
  102. with open(settings_file, "r", encoding="utf8") as file:
  103. settings = json.load(file)
  104. except Exception as e:
  105. print(e, file=sys.stderr)
  106. disabled_extensions = set(settings.get('disabled_extensions', []))
  107. return [x for x in os.listdir(dir_extensions) if x not in disabled_extensions]
  108. def run_extensions_installers(settings_file):
  109. if not os.path.isdir(dir_extensions):
  110. return
  111. for dirname_extension in list_extensions(settings_file):
  112. run_extension_installer(os.path.join(dir_extensions, dirname_extension))
  113. def prepare_enviroment():
  114. torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
  115. requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
  116. commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
  117. gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
  118. clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
  119. openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
  120. xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
  121. stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
  122. taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
  123. k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
  124. codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
  125. blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
  126. stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "47b6b607fdd31875c9279cd2f4f16b92e4ea958e")
  127. taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
  128. k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "5b3af030dd83e0297272d861c19477735d0317ec")
  129. codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
  130. blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
  131. sys.argv += shlex.split(commandline_args)
  132. parser = argparse.ArgumentParser()
  133. parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default='config.json')
  134. args, _ = parser.parse_known_args(sys.argv)
  135. sys.argv, _ = extract_arg(sys.argv, '-f')
  136. sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
  137. sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
  138. sys.argv, update_check = extract_arg(sys.argv, '--update-check')
  139. sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
  140. xformers = '--xformers' in sys.argv
  141. ngrok = '--ngrok' in sys.argv
  142. try:
  143. commit = run(f"{git} rev-parse HEAD").strip()
  144. except Exception:
  145. commit = "<none>"
  146. print(f"Python {sys.version}")
  147. print(f"Commit hash: {commit}")
  148. if not is_installed("torch") or not is_installed("torchvision"):
  149. run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
  150. if not skip_torch_cuda_test:
  151. run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
  152. if not is_installed("gfpgan"):
  153. run_pip(f"install {gfpgan_package}", "gfpgan")
  154. if not is_installed("clip"):
  155. run_pip(f"install {clip_package}", "clip")
  156. if not is_installed("open_clip"):
  157. run_pip(f"install {openclip_package}", "open_clip")
  158. if (not is_installed("xformers") or reinstall_xformers) and xformers:
  159. if platform.system() == "Windows":
  160. if platform.python_version().startswith("3.10"):
  161. run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers")
  162. else:
  163. print("Installation of xformers is not supported in this version of Python.")
  164. print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
  165. if not is_installed("xformers"):
  166. exit(0)
  167. elif platform.system() == "Linux":
  168. run_pip("install xformers", "xformers")
  169. if not is_installed("pyngrok") and ngrok:
  170. run_pip("install pyngrok", "ngrok")
  171. os.makedirs(dir_repos, exist_ok=True)
  172. git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
  173. git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
  174. git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
  175. git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
  176. git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
  177. if not is_installed("lpips"):
  178. run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
  179. run_pip(f"install -r {requirements_file}", "requirements for Web UI")
  180. run_extensions_installers(settings_file=args.ui_settings_file)
  181. if update_check:
  182. version_check(commit)
  183. if "--exit" in sys.argv:
  184. print("Exiting because of --exit argument")
  185. exit(0)
  186. if run_tests:
  187. exitcode = tests(test_dir)
  188. exit(exitcode)
  189. def tests(test_dir):
  190. if "--api" not in sys.argv:
  191. sys.argv.append("--api")
  192. if "--ckpt" not in sys.argv:
  193. sys.argv.append("--ckpt")
  194. sys.argv.append("./test/test_files/empty.pt")
  195. if "--skip-torch-cuda-test" not in sys.argv:
  196. sys.argv.append("--skip-torch-cuda-test")
  197. print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")
  198. with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
  199. proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr)
  200. import test.server_poll
  201. exitcode = test.server_poll.run_tests(proc, test_dir)
  202. print(f"Stopping Web UI process with id {proc.pid}")
  203. proc.kill()
  204. return exitcode
  205. def start():
  206. print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
  207. import webui
  208. if '--nowebui' in sys.argv:
  209. webui.api_only()
  210. else:
  211. webui.webui()
  212. if __name__ == "__main__":
  213. prepare_enviroment()
  214. start()