hashes.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import hashlib
  2. import json
  3. import os.path
  4. import filelock
  5. from modules import shared
  6. from modules.paths import data_path
  # In-memory copy of the JSON cache file; stays None until cache() first loads it.
  cache_data = None
  # Path of the on-disk JSON file holding every cache() subsection.
  cache_filename = os.path.join(data_path, "cache.json")
  def dump_cache():
      """Write the whole in-memory cache (all subsections) to ``cache_filename``.

      The JSON is first written to a temporary sibling file and then moved over
      the real file with ``os.replace``, so an interrupted write (crash, Ctrl+C,
      full disk) cannot leave a truncated, unparseable cache file behind.
      The write happens while holding the ``<cache_filename>.lock`` file lock.
      """
      if cache_data is None:
          # Nothing has been loaded via cache() yet; dumping now would replace
          # the on-disk cache with ``null`` and discard every stored entry.
          return

      with filelock.FileLock(f"{cache_filename}.lock"):
          tmp_filename = f"{cache_filename}.tmp"
          with open(tmp_filename, "w", encoding="utf8") as file:
              json.dump(cache_data, file, indent=4)
          # os.replace overwrites the destination; on POSIX the rename is atomic.
          os.replace(tmp_filename, cache_filename)
  def cache(subsection):
      """Return the mutable dict stored under ``subsection`` in the shared cache.

      On first use the whole cache is loaded from ``cache_filename`` (under the
      ``<cache_filename>.lock`` file lock), or starts empty if the file does not
      exist. The returned dict is also stored back into the cache, so changes
      the caller makes to it are included in the next ``dump_cache()``.

      If the file cannot be decoded as a JSON object (e.g. it was truncated by
      an interrupted write), a message is printed and an empty cache is used
      instead of failing on every later call. The unreadable file is then
      overwritten by the next ``dump_cache()``.
      """
      global cache_data

      if cache_data is None:
          with filelock.FileLock(f"{cache_filename}.lock"):
              loaded = {}
              if os.path.isfile(cache_filename):
                  try:
                      with open(cache_filename, "r", encoding="utf8") as file:
                          loaded = json.load(file)
                  except ValueError as e:
                      # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
                      print(f"Could not read {cache_filename} ({e}); starting with an empty cache.")
                      loaded = {}

              if not isinstance(loaded, dict):
                  # Valid JSON of the wrong shape (e.g. ``null``) would break .get() below.
                  print(f"Unexpected content in {cache_filename}; starting with an empty cache.")
                  loaded = {}

              cache_data = loaded

      s = cache_data.get(subsection, {})
      cache_data[subsection] = s

      return s
  def calculate_sha256(filename):
      """Return the hex SHA-256 digest of the file at ``filename``.

      The file is read in 1 MiB chunks, so it is never held in memory whole.
      """
      chunk_size = 1024 * 1024
      digest = hashlib.sha256()
      with open(filename, "rb") as stream:
          while True:
              chunk = stream.read(chunk_size)
              if not chunk:
                  break
              digest.update(chunk)
      return digest.hexdigest()
  def sha256_from_cache(filename, title, use_addnet_hash=False):
      """Return the cached SHA-256 for ``title`` if it is still valid, else None.

      Args:
          filename: path of the file the hash belongs to; its current mtime is
              compared against the cached one.
          title: key of the entry in the cache subsection.
          use_addnet_hash: look in the "hashes-addnet" subsection instead of
              "hashes".

      Returns:
          The cached hex digest, or None when there is no entry, the entry has
          no "sha256" value, or the file's mtime differs from the recorded one.
      """
      hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
      ondisk_mtime = os.path.getmtime(filename)

      entry = hashes.get(title)
      if entry is None:
          return None

      cached_sha256 = entry.get("sha256", None)
      cached_mtime = entry.get("mtime", 0)

      # Any mtime change invalidates the entry, not only a newer one: a file
      # replaced by one with an *older* timestamp (restored backup, copy that
      # preserves timestamps) must not keep returning the previous hash.
      if cached_sha256 is None or ondisk_mtime != cached_mtime:
          return None

      return cached_sha256
  def sha256(filename, title, use_addnet_hash=False):
      """Return the SHA-256 of ``filename``, using and updating the hash cache.

      Args:
          filename: path of the file to hash.
          title: key under which the hash is cached.
          use_addnet_hash: if True, compute ``addnet_hash_safetensors`` (hash of
              the data after the header) and cache it in the "hashes-addnet"
              subsection instead of hashing the whole file into "hashes".

      Returns:
          The hex digest, or None when no valid cached value exists and
          ``shared.cmd_opts.no_hashing`` is set.
      """
      hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes")
      sha256_value = sha256_from_cache(filename, title, use_addnet_hash)
      if sha256_value is not None:
          return sha256_value

      if shared.cmd_opts.no_hashing:
          return None

      # Take the mtime *before* reading the file. If the file is modified while
      # it is being hashed, its final mtime will differ from the recorded one,
      # so the possibly inconsistent hash is recomputed on the next call instead
      # of being trusted indefinitely.
      mtime_before_hashing = os.path.getmtime(filename)

      # No newline yet; flush so the message is visible while a large file hashes.
      print(f"Calculating sha256 for {filename}: ", end='', flush=True)
      if use_addnet_hash:
          with open(filename, "rb") as file:
              sha256_value = addnet_hash_safetensors(file)
      else:
          sha256_value = calculate_sha256(filename)
      print(sha256_value)

      hashes[title] = {
          "mtime": mtime_before_hashing,
          "sha256": sha256_value,
      }
      dump_cache()

      return sha256_value
  def addnet_hash_safetensors(b):
      """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py

      ``b`` is a seekable binary stream. Its first 8 bytes are read as a
      little-endian length ``n``; everything from offset ``n + 8`` to EOF is
      hashed with SHA-256 and the hex digest is returned. The stream is rewound
      to offset 0 first, so its current position does not matter.
      """
      chunk_size = 1024 * 1024
      digest = hashlib.sha256()

      b.seek(0)
      header_length = int.from_bytes(b.read(8), "little")
      data_start = header_length + 8
      b.seek(data_start)

      while True:
          chunk = b.read(chunk_size)
          if not chunk:
              break
          digest.update(chunk)

      return digest.hexdigest()