config_states.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. """
  2. Supports saving and restoring webui and extensions from a known working set of commits
  3. """
  4. import os
  5. import json
  6. import time
  7. import tqdm
  8. from datetime import datetime
  9. import git
  10. from modules import shared, extensions, errors
  11. from modules.paths_internal import script_path, config_states_dir
  12. all_config_states = {}
  13. def list_config_states():
  14. global all_config_states
  15. all_config_states.clear()
  16. os.makedirs(config_states_dir, exist_ok=True)
  17. config_states = []
  18. for filename in os.listdir(config_states_dir):
  19. if filename.endswith(".json"):
  20. path = os.path.join(config_states_dir, filename)
  21. try:
  22. with open(path, "r", encoding="utf-8") as f:
  23. j = json.load(f)
  24. assert "created_at" in j, '"created_at" does not exist'
  25. j["filepath"] = path
  26. config_states.append(j)
  27. except Exception as e:
  28. print(f'[ERROR]: Config states {path}, {e}')
  29. config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
  30. for cs in config_states:
  31. timestamp = time.asctime(time.gmtime(cs["created_at"]))
  32. name = cs.get("name", "Config")
  33. full_name = f"{name}: {timestamp}"
  34. all_config_states[full_name] = cs
  35. return all_config_states
  36. def get_webui_config():
  37. webui_repo = None
  38. try:
  39. if os.path.exists(os.path.join(script_path, ".git")):
  40. webui_repo = git.Repo(script_path)
  41. except Exception:
  42. errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
  43. webui_remote = None
  44. webui_commit_hash = None
  45. webui_commit_date = None
  46. webui_branch = None
  47. if webui_repo and not webui_repo.bare:
  48. try:
  49. webui_remote = next(webui_repo.remote().urls, None)
  50. head = webui_repo.head.commit
  51. webui_commit_date = webui_repo.head.commit.committed_date
  52. webui_commit_hash = head.hexsha
  53. webui_branch = webui_repo.active_branch.name
  54. except Exception:
  55. webui_remote = None
  56. return {
  57. "remote": webui_remote,
  58. "commit_hash": webui_commit_hash,
  59. "commit_date": webui_commit_date,
  60. "branch": webui_branch,
  61. }
  62. def get_extension_config():
  63. ext_config = {}
  64. for ext in extensions.extensions:
  65. ext.read_info_from_repo()
  66. entry = {
  67. "name": ext.name,
  68. "path": ext.path,
  69. "enabled": ext.enabled,
  70. "is_builtin": ext.is_builtin,
  71. "remote": ext.remote,
  72. "commit_hash": ext.commit_hash,
  73. "commit_date": ext.commit_date,
  74. "branch": ext.branch,
  75. "have_info_from_repo": ext.have_info_from_repo
  76. }
  77. ext_config[ext.name] = entry
  78. return ext_config
  79. def get_config():
  80. creation_time = datetime.now().timestamp()
  81. webui_config = get_webui_config()
  82. ext_config = get_extension_config()
  83. return {
  84. "created_at": creation_time,
  85. "webui": webui_config,
  86. "extensions": ext_config
  87. }
  88. def restore_webui_config(config):
  89. print("* Restoring webui state...")
  90. if "webui" not in config:
  91. print("Error: No webui data saved to config")
  92. return
  93. webui_config = config["webui"]
  94. if "commit_hash" not in webui_config:
  95. print("Error: No commit saved to webui config")
  96. return
  97. webui_commit_hash = webui_config.get("commit_hash", None)
  98. webui_repo = None
  99. try:
  100. if os.path.exists(os.path.join(script_path, ".git")):
  101. webui_repo = git.Repo(script_path)
  102. except Exception:
  103. errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
  104. return
  105. try:
  106. webui_repo.git.fetch(all=True)
  107. webui_repo.git.reset(webui_commit_hash, hard=True)
  108. print(f"* Restored webui to commit {webui_commit_hash}.")
  109. except Exception:
  110. errors.report(f"Error restoring webui to commit{webui_commit_hash}")
  111. def restore_extension_config(config):
  112. print("* Restoring extension state...")
  113. if "extensions" not in config:
  114. print("Error: No extension data saved to config")
  115. return
  116. ext_config = config["extensions"]
  117. results = []
  118. disabled = []
  119. for ext in tqdm.tqdm(extensions.extensions):
  120. if ext.is_builtin:
  121. continue
  122. ext.read_info_from_repo()
  123. current_commit = ext.commit_hash
  124. if ext.name not in ext_config:
  125. ext.disabled = True
  126. disabled.append(ext.name)
  127. results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
  128. continue
  129. entry = ext_config[ext.name]
  130. if "commit_hash" in entry and entry["commit_hash"]:
  131. try:
  132. ext.fetch_and_reset_hard(entry["commit_hash"])
  133. ext.read_info_from_repo()
  134. if current_commit != entry["commit_hash"]:
  135. results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
  136. except Exception as ex:
  137. results.append((ext, current_commit[:8], False, ex))
  138. else:
  139. results.append((ext, current_commit[:8], False, "No commit hash found in config"))
  140. if not entry.get("enabled", False):
  141. ext.disabled = True
  142. disabled.append(ext.name)
  143. else:
  144. ext.disabled = False
  145. shared.opts.disabled_extensions = disabled
  146. shared.opts.save(shared.config_filename)
  147. print("* Finished restoring extensions. Results:")
  148. for ext, prev_commit, success, result in results:
  149. if success:
  150. print(f" + {ext.name}: {prev_commit} -> {result}")
  151. else:
  152. print(f" ! {ext.name}: FAILURE ({result})")