extensions.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import os
  2. import sys
  3. import threading
  4. import traceback
  5. import git
  6. from modules import shared
  7. from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
  8. extensions = []
  9. if not os.path.exists(extensions_dir):
  10. os.makedirs(extensions_dir)
  11. def active():
  12. if shared.opts.disable_all_extensions == "all":
  13. return []
  14. elif shared.opts.disable_all_extensions == "extra":
  15. return [x for x in extensions if x.enabled and x.is_builtin]
  16. else:
  17. return [x for x in extensions if x.enabled]
  18. class Extension:
  19. lock = threading.Lock()
  20. def __init__(self, name, path, enabled=True, is_builtin=False):
  21. self.name = name
  22. self.path = path
  23. self.enabled = enabled
  24. self.status = ''
  25. self.can_update = False
  26. self.is_builtin = is_builtin
  27. self.commit_hash = ''
  28. self.commit_date = None
  29. self.version = ''
  30. self.branch = None
  31. self.remote = None
  32. self.have_info_from_repo = False
  33. def read_info_from_repo(self):
  34. if self.is_builtin or self.have_info_from_repo:
  35. return
  36. with self.lock:
  37. if self.have_info_from_repo:
  38. return
  39. self.do_read_info_from_repo()
  40. def do_read_info_from_repo(self):
  41. repo = None
  42. try:
  43. if os.path.exists(os.path.join(self.path, ".git")):
  44. repo = git.Repo(self.path)
  45. except Exception:
  46. print(f"Error reading github repository info from {self.path}:", file=sys.stderr)
  47. print(traceback.format_exc(), file=sys.stderr)
  48. if repo is None or repo.bare:
  49. self.remote = None
  50. else:
  51. try:
  52. self.status = 'unknown'
  53. self.remote = next(repo.remote().urls, None)
  54. self.commit_date = repo.head.commit.committed_date
  55. if repo.active_branch:
  56. self.branch = repo.active_branch.name
  57. self.commit_hash = repo.head.commit.hexsha
  58. self.version = repo.git.describe("--always", "--tags") # compared to `self.commit_hash[:8]` this takes about 30% more time total but since we run it in parallel we don't care
  59. except Exception as ex:
  60. print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
  61. self.remote = None
  62. self.have_info_from_repo = True
  63. def list_files(self, subdir, extension):
  64. from modules import scripts
  65. dirpath = os.path.join(self.path, subdir)
  66. if not os.path.isdir(dirpath):
  67. return []
  68. res = []
  69. for filename in sorted(os.listdir(dirpath)):
  70. res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
  71. res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
  72. return res
  73. def check_updates(self):
  74. repo = git.Repo(self.path)
  75. for fetch in repo.remote().fetch(dry_run=True):
  76. if fetch.flags != fetch.HEAD_UPTODATE:
  77. self.can_update = True
  78. self.status = "new commits"
  79. return
  80. try:
  81. origin = repo.rev_parse('origin')
  82. if repo.head.commit != origin:
  83. self.can_update = True
  84. self.status = "behind HEAD"
  85. return
  86. except Exception:
  87. self.can_update = False
  88. self.status = "unknown (remote error)"
  89. return
  90. self.can_update = False
  91. self.status = "latest"
  92. def fetch_and_reset_hard(self, commit='origin'):
  93. repo = git.Repo(self.path)
  94. # Fix: `error: Your local changes to the following files would be overwritten by merge`,
  95. # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
  96. repo.git.fetch(all=True)
  97. repo.git.reset(commit, hard=True)
  98. self.have_info_from_repo = False
  99. def list_extensions():
  100. extensions.clear()
  101. if not os.path.isdir(extensions_dir):
  102. return
  103. if shared.opts.disable_all_extensions == "all":
  104. print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
  105. elif shared.opts.disable_all_extensions == "extra":
  106. print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
  107. extension_paths = []
  108. for dirname in [extensions_dir, extensions_builtin_dir]:
  109. if not os.path.isdir(dirname):
  110. return
  111. for extension_dirname in sorted(os.listdir(dirname)):
  112. path = os.path.join(dirname, extension_dirname)
  113. if not os.path.isdir(path):
  114. continue
  115. extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
  116. for dirname, path, is_builtin in extension_paths:
  117. extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
  118. extensions.append(extension)