gfpgan_model.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from __future__ import annotations
  2. import logging
  3. import os
  4. import torch
  5. from modules import (
  6. devices,
  7. errors,
  8. face_restoration,
  9. face_restoration_utils,
  10. modelloader,
  11. shared,
  12. )
  13. logger = logging.getLogger(__name__)
  14. model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
  15. model_download_name = "GFPGANv1.4.pth"
  16. gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
  17. class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
  18. def name(self):
  19. return "GFPGAN"
  20. def get_device(self):
  21. return devices.device_gfpgan
  22. def load_net(self) -> torch.Module:
  23. for model_path in modelloader.load_models(
  24. model_path=self.model_path,
  25. model_url=model_url,
  26. command_path=self.model_path,
  27. download_name=model_download_name,
  28. ext_filter=['.pth'],
  29. ):
  30. if 'GFPGAN' in os.path.basename(model_path):
  31. model = modelloader.load_spandrel_model(
  32. model_path,
  33. device=self.get_device(),
  34. expected_architecture='GFPGAN',
  35. ).model
  36. model.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
  37. return model
  38. raise ValueError("No GFPGAN model found")
  39. def restore(self, np_image):
  40. def restore_face(cropped_face_t):
  41. assert self.net is not None
  42. return self.net(cropped_face_t, return_rgb=False)[0]
  43. return self.restore_with_helper(np_image, restore_face)
  44. def gfpgan_fix_faces(np_image):
  45. if gfpgan_face_restorer:
  46. return gfpgan_face_restorer.restore(np_image)
  47. logger.warning("GFPGAN face restorer not set up")
  48. return np_image
  49. def setup_model(dirname: str) -> None:
  50. global gfpgan_face_restorer
  51. try:
  52. face_restoration_utils.patch_facexlib(dirname)
  53. gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
  54. shared.face_restorers.append(gfpgan_face_restorer)
  55. except Exception:
  56. errors.report("Error setting up GFPGAN", exc_info=True)