codeformer_model.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from __future__ import annotations
  2. import logging
  3. import torch
  4. from modules import (
  5. devices,
  6. errors,
  7. face_restoration,
  8. face_restoration_utils,
  9. modelloader,
  10. shared,
  11. )
  12. logger = logging.getLogger(__name__)
  13. model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
  14. model_download_name = 'codeformer-v0.1.0.pth'
  15. # used by e.g. postprocessing_codeformer.py
  16. codeformer: face_restoration.FaceRestoration | None = None
  17. class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
  18. def name(self):
  19. return "CodeFormer"
  20. def load_net(self) -> torch.Module:
  21. for model_path in modelloader.load_models(
  22. model_path=self.model_path,
  23. model_url=model_url,
  24. command_path=self.model_path,
  25. download_name=model_download_name,
  26. ext_filter=['.pth'],
  27. ):
  28. return modelloader.load_spandrel_model(
  29. model_path,
  30. device=devices.device_codeformer,
  31. expected_architecture='CodeFormer',
  32. ).model
  33. raise ValueError("No codeformer model found")
  34. def get_device(self):
  35. return devices.device_codeformer
  36. def restore(self, np_image, w: float | None = None):
  37. if w is None:
  38. w = getattr(shared.opts, "code_former_weight", 0.5)
  39. def restore_face(cropped_face_t):
  40. assert self.net is not None
  41. return self.net(cropped_face_t, weight=w, adain=True)[0]
  42. return self.restore_with_helper(np_image, restore_face)
  43. def setup_model(dirname: str) -> None:
  44. global codeformer
  45. try:
  46. codeformer = FaceRestorerCodeFormer(dirname)
  47. shared.face_restorers.append(codeformer)
  48. except Exception:
  49. errors.report("Error setting up CodeFormer", exc_info=True)