sd_hijack_utils.py 1.4 KB

1234567891011121314151617181920212223242526272829303132
  1. import importlib
  2. class CondFunc:
  3. def __new__(cls, orig_func, sub_func, cond_func):
  4. self = super(CondFunc, cls).__new__(cls)
  5. if isinstance(orig_func, str):
  6. func_path = orig_func.split('.')
  7. for i in range(len(func_path)-1, -1, -1):
  8. try:
  9. resolved_obj = importlib.import_module('.'.join(func_path[:i]))
  10. break
  11. except ImportError:
  12. pass
  13. try:
  14. for attr_name in func_path[i:-1]:
  15. resolved_obj = getattr(resolved_obj, attr_name)
  16. orig_func = getattr(resolved_obj, func_path[-1])
  17. setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
  18. except AttributeError:
  19. print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack")
  20. pass
  21. self.__init__(orig_func, sub_func, cond_func)
  22. return lambda *args, **kwargs: self(*args, **kwargs)
  23. def __init__(self, orig_func, sub_func, cond_func):
  24. self.__orig_func = orig_func
  25. self.__sub_func = sub_func
  26. self.__cond_func = cond_func
  27. def __call__(self, *args, **kwargs):
  28. if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
  29. return self.__sub_func(self.__orig_func, *args, **kwargs)
  30. else:
  31. return self.__orig_func(*args, **kwargs)