# ui_tempdir.py (3.2 KB)

  1. import os
  2. import tempfile
  3. from collections import namedtuple
  4. from pathlib import Path
  5. import gradio.components
  6. from PIL import PngImagePlugin
  7. from modules import shared
  8. Savedfile = namedtuple("Savedfile", ["name"])
  9. def register_tmp_file(gradio, filename):
  10. if hasattr(gradio, 'temp_file_sets'): # gradio 3.15
  11. gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}
  12. if hasattr(gradio, 'temp_dirs'): # gradio 3.9
  13. gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}
  14. def check_tmp_file(gradio, filename):
  15. if hasattr(gradio, 'temp_file_sets'):
  16. return any(filename in fileset for fileset in gradio.temp_file_sets)
  17. if hasattr(gradio, 'temp_dirs'):
  18. return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)
  19. return False
  20. def save_pil_to_file(self, pil_image, dir=None, format="png"):
  21. already_saved_as = getattr(pil_image, 'already_saved_as', None)
  22. if already_saved_as and os.path.isfile(already_saved_as):
  23. register_tmp_file(shared.demo, already_saved_as)
  24. filename_with_mtime = f'{already_saved_as}?{os.path.getmtime(already_saved_as)}'
  25. register_tmp_file(shared.demo, filename_with_mtime)
  26. return filename_with_mtime
  27. if shared.opts.temp_dir != "":
  28. dir = shared.opts.temp_dir
  29. else:
  30. os.makedirs(dir, exist_ok=True)
  31. use_metadata = False
  32. metadata = PngImagePlugin.PngInfo()
  33. for key, value in pil_image.info.items():
  34. if isinstance(key, str) and isinstance(value, str):
  35. metadata.add_text(key, value)
  36. use_metadata = True
  37. file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
  38. pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
  39. return file_obj.name
  40. def install_ui_tempdir_override():
  41. """override save to file function so that it also writes PNG info"""
  42. gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
  43. def on_tmpdir_changed():
  44. if shared.opts.temp_dir == "" or shared.demo is None:
  45. return
  46. os.makedirs(shared.opts.temp_dir, exist_ok=True)
  47. register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x"))
  48. def cleanup_tmpdr():
  49. temp_dir = shared.opts.temp_dir
  50. if temp_dir == "" or not os.path.isdir(temp_dir):
  51. return
  52. for root, _, files in os.walk(temp_dir, topdown=False):
  53. for name in files:
  54. _, extension = os.path.splitext(name)
  55. if extension != ".png":
  56. continue
  57. filename = os.path.join(root, name)
  58. os.remove(filename)
  59. def is_gradio_temp_path(path):
  60. """
  61. Check if the path is a temp dir used by gradio
  62. """
  63. path = Path(path)
  64. if shared.opts.temp_dir and path.is_relative_to(shared.opts.temp_dir):
  65. return True
  66. if gradio_temp_dir := os.environ.get("GRADIO_TEMP_DIR"):
  67. if path.is_relative_to(gradio_temp_dir):
  68. return True
  69. if path.is_relative_to(Path(tempfile.gettempdir()) / "gradio"):
  70. return True
  71. return False