config_states.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. """
  2. Supports saving and restoring webui and extensions from a known working set of commits
  3. """
  4. import os
  5. import json
  6. import tqdm
  7. from datetime import datetime
  8. import git
  9. from modules import shared, extensions, errors
  10. from modules.paths_internal import script_path, config_states_dir
  11. all_config_states = {}
  12. def list_config_states():
  13. global all_config_states
  14. all_config_states.clear()
  15. os.makedirs(config_states_dir, exist_ok=True)
  16. config_states = []
  17. for filename in os.listdir(config_states_dir):
  18. if filename.endswith(".json"):
  19. path = os.path.join(config_states_dir, filename)
  20. try:
  21. with open(path, "r", encoding="utf-8") as f:
  22. j = json.load(f)
  23. assert "created_at" in j, '"created_at" does not exist'
  24. j["filepath"] = path
  25. config_states.append(j)
  26. except Exception as e:
  27. print(f'[ERROR]: Config states {path}, {e}')
  28. config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
  29. for cs in config_states:
  30. timestamp = datetime.fromtimestamp(cs["created_at"]).strftime('%Y-%m-%d %H:%M:%S')
  31. name = cs.get("name", "Config")
  32. full_name = f"{name}: {timestamp}"
  33. all_config_states[full_name] = cs
  34. return all_config_states
  35. def get_webui_config():
  36. webui_repo = None
  37. try:
  38. if os.path.exists(os.path.join(script_path, ".git")):
  39. webui_repo = git.Repo(script_path)
  40. except Exception:
  41. errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
  42. webui_remote = None
  43. webui_commit_hash = None
  44. webui_commit_date = None
  45. webui_branch = None
  46. if webui_repo and not webui_repo.bare:
  47. try:
  48. webui_remote = next(webui_repo.remote().urls, None)
  49. head = webui_repo.head.commit
  50. webui_commit_date = webui_repo.head.commit.committed_date
  51. webui_commit_hash = head.hexsha
  52. webui_branch = webui_repo.active_branch.name
  53. except Exception:
  54. webui_remote = None
  55. return {
  56. "remote": webui_remote,
  57. "commit_hash": webui_commit_hash,
  58. "commit_date": webui_commit_date,
  59. "branch": webui_branch,
  60. }
  61. def get_extension_config():
  62. ext_config = {}
  63. for ext in extensions.extensions:
  64. ext.read_info_from_repo()
  65. entry = {
  66. "name": ext.name,
  67. "path": ext.path,
  68. "enabled": ext.enabled,
  69. "is_builtin": ext.is_builtin,
  70. "remote": ext.remote,
  71. "commit_hash": ext.commit_hash,
  72. "commit_date": ext.commit_date,
  73. "branch": ext.branch,
  74. "have_info_from_repo": ext.have_info_from_repo
  75. }
  76. ext_config[ext.name] = entry
  77. return ext_config
  78. def get_config():
  79. creation_time = datetime.now().timestamp()
  80. webui_config = get_webui_config()
  81. ext_config = get_extension_config()
  82. return {
  83. "created_at": creation_time,
  84. "webui": webui_config,
  85. "extensions": ext_config
  86. }
  87. def restore_webui_config(config):
  88. print("* Restoring webui state...")
  89. if "webui" not in config:
  90. print("Error: No webui data saved to config")
  91. return
  92. webui_config = config["webui"]
  93. if "commit_hash" not in webui_config:
  94. print("Error: No commit saved to webui config")
  95. return
  96. webui_commit_hash = webui_config.get("commit_hash", None)
  97. webui_repo = None
  98. try:
  99. if os.path.exists(os.path.join(script_path, ".git")):
  100. webui_repo = git.Repo(script_path)
  101. except Exception:
  102. errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
  103. return
  104. try:
  105. webui_repo.git.fetch(all=True)
  106. webui_repo.git.reset(webui_commit_hash, hard=True)
  107. print(f"* Restored webui to commit {webui_commit_hash}.")
  108. except Exception:
  109. errors.report(f"Error restoring webui to commit{webui_commit_hash}")
  110. def restore_extension_config(config):
  111. print("* Restoring extension state...")
  112. if "extensions" not in config:
  113. print("Error: No extension data saved to config")
  114. return
  115. ext_config = config["extensions"]
  116. results = []
  117. disabled = []
  118. for ext in tqdm.tqdm(extensions.extensions):
  119. if ext.is_builtin:
  120. continue
  121. ext.read_info_from_repo()
  122. current_commit = ext.commit_hash
  123. if ext.name not in ext_config:
  124. ext.disabled = True
  125. disabled.append(ext.name)
  126. results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
  127. continue
  128. entry = ext_config[ext.name]
  129. if "commit_hash" in entry and entry["commit_hash"]:
  130. try:
  131. ext.fetch_and_reset_hard(entry["commit_hash"])
  132. ext.read_info_from_repo()
  133. if current_commit != entry["commit_hash"]:
  134. results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
  135. except Exception as ex:
  136. results.append((ext, current_commit[:8], False, ex))
  137. else:
  138. results.append((ext, current_commit[:8], False, "No commit hash found in config"))
  139. if not entry.get("enabled", False):
  140. ext.disabled = True
  141. disabled.append(ext.name)
  142. else:
  143. ext.disabled = False
  144. shared.opts.disabled_extensions = disabled
  145. shared.opts.save(shared.config_filename)
  146. print("* Finished restoring extensions. Results:")
  147. for ext, prev_commit, success, result in results:
  148. if success:
  149. print(f" + {ext.name}: {prev_commit} -> {result}")
  150. else:
  151. print(f" ! {ext.name}: FAILURE ({result})")