gfpgan_model.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from __future__ import annotations
  2. import logging
  3. import os
  4. import torch
  5. from modules import (
  6. devices,
  7. errors,
  8. face_restoration,
  9. face_restoration_utils,
  10. modelloader,
  11. shared,
  12. )
  13. logger = logging.getLogger(__name__)
  14. model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
  15. model_download_name = "GFPGANv1.4.pth"
  16. gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
  17. class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
  18. def name(self):
  19. return "GFPGAN"
  20. def get_device(self):
  21. return devices.device_gfpgan
  22. def load_net(self) -> torch.Module:
  23. for model_path in modelloader.load_models(
  24. model_path=self.model_path,
  25. model_url=model_url,
  26. command_path=self.model_path,
  27. download_name=model_download_name,
  28. ext_filter=['.pth'],
  29. ):
  30. if 'GFPGAN' in os.path.basename(model_path):
  31. return modelloader.load_spandrel_model(
  32. model_path,
  33. device=self.get_device(),
  34. expected_architecture='GFPGAN',
  35. ).model
  36. raise ValueError("No GFPGAN model found")
  37. def restore(self, np_image):
  38. def restore_face(cropped_face_t):
  39. assert self.net is not None
  40. return self.net(cropped_face_t, return_rgb=False)[0]
  41. return self.restore_with_helper(np_image, restore_face)
  42. def gfpgan_fix_faces(np_image):
  43. if gfpgan_face_restorer:
  44. return gfpgan_face_restorer.restore(np_image)
  45. logger.warning("GFPGAN face restorer not set up")
  46. return np_image
  47. def setup_model(dirname: str) -> None:
  48. global gfpgan_face_restorer
  49. try:
  50. face_restoration_utils.patch_facexlib(dirname)
  51. gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
  52. shared.face_restorers.append(gfpgan_face_restorer)
  53. except Exception:
  54. errors.report("Error setting up GFPGAN", exc_info=True)