sd_hijack_utils.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536
  1. import importlib
  2. always_true_func = lambda *args, **kwargs: True
  3. class CondFunc:
  4. def __new__(cls, orig_func, sub_func, cond_func=always_true_func):
  5. self = super(CondFunc, cls).__new__(cls)
  6. if isinstance(orig_func, str):
  7. func_path = orig_func.split('.')
  8. for i in range(len(func_path)-1, -1, -1):
  9. try:
  10. resolved_obj = importlib.import_module('.'.join(func_path[:i]))
  11. break
  12. except ImportError:
  13. pass
  14. try:
  15. for attr_name in func_path[i:-1]:
  16. resolved_obj = getattr(resolved_obj, attr_name)
  17. orig_func = getattr(resolved_obj, func_path[-1])
  18. setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
  19. except AttributeError:
  20. print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack")
  21. pass
  22. self.__init__(orig_func, sub_func, cond_func)
  23. return lambda *args, **kwargs: self(*args, **kwargs)
  24. def __init__(self, orig_func, sub_func, cond_func):
  25. self.__orig_func = orig_func
  26. self.__sub_func = sub_func
  27. self.__cond_func = cond_func
  28. def __call__(self, *args, **kwargs):
  29. if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
  30. return self.__sub_func(self.__orig_func, *args, **kwargs)
  31. else:
  32. return self.__orig_func(*args, **kwargs)