# shared_init.py 2.0 KB

  1. import os
  2. import torch
  3. from modules import shared
  4. from modules.shared import cmd_opts
  5. def initialize():
  6. """Initializes fields inside the shared module in a controlled manner.
  7. Should be called early because some other modules you can import mingt need these fields to be already set.
  8. """
  9. os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
  10. from modules import options, shared_options
  11. shared.options_templates = shared_options.options_templates
  12. shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts)
  13. shared.restricted_opts = shared_options.restricted_opts
  14. try:
  15. shared.opts.load(shared.config_filename)
  16. except FileNotFoundError:
  17. pass
  18. from modules import devices
  19. devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
  20. (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
  21. devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
  22. devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
  23. devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype
  24. shared.device = devices.device
  25. shared.weight_load_location = None if cmd_opts.lowram else "cpu"
  26. from modules import shared_state
  27. shared.state = shared_state.State()
  28. from modules import styles
  29. shared.prompt_styles = styles.StyleDatabase(shared.styles_filename)
  30. from modules import interrogate
  31. shared.interrogator = interrogate.InterrogateModels("interrogate")
  32. from modules import shared_total_tqdm
  33. shared.total_tqdm = shared_total_tqdm.TotalTQDM()
  34. from modules import memmon, devices
  35. shared.mem_mon = memmon.MemUsageMonitor("MemMon", devices.device, shared.opts)
  36. shared.mem_mon.start()