# patches.py -- helpers for replacing functions on modules/classes and restoring them.
  1. from collections import defaultdict
  def patch(key, obj, field: str, replacement):
      """Replace a function in a module or a class, remembering the original.

      The original is stored in this module's ``originals`` registry, from where
      it can be retrieved via original(key, obj, field) or restored via
      undo(key, obj, field).

      Arguments:
          key: identifying information for who is doing the replacement
              (must be hashable). You can use __name__.
          obj: the module or the class (must be hashable)
          field: name of the function as a string
          replacement: the new function

      Returns:
          the original function

      Raises:
          RuntimeError: if the function is already replaced by this caller (key)
              -- use undo() before patching again.
          AttributeError: if obj has no attribute named field.
      """
      patch_key = (obj, field)

      # Use .get() rather than indexing: ``originals`` is a defaultdict, and
      # indexing would insert an empty entry for ``key`` even when this call
      # ends up failing below.
      if patch_key in originals.get(key, ()):
          raise RuntimeError(f"patch for {field} is already applied")

      # NOTE: getattr() goes through the descriptor protocol, so for a
      # staticmethod/classmethod on a class this is the unwrapped/bound
      # callable, not the descriptor object itself.
      original_func = getattr(obj, field)

      # Record the original only after the lookup succeeded, then swap it out.
      originals.setdefault(key, {})[patch_key] = original_func
      setattr(obj, field, replacement)
      return original_func
  def undo(key, obj, field: str) -> None:
      """Undo the replacement made by patch().

      Restores the stored original function on obj and forgets the patch.

      Arguments:
          key: identifying information for who is doing the replacement.
              You can use __name__.
          obj: the module or the class
          field: name of the function as a string

      Returns:
          Always None

      Raises:
          RuntimeError: if the function is not replaced by this caller (key).
      """
      patch_key = (obj, field)

      # Use .get() rather than indexing: ``originals`` is a defaultdict, and
      # indexing would insert an empty entry for an unknown ``key``.
      saved = originals.get(key)
      if not saved or patch_key not in saved:
          raise RuntimeError(f"there is no patch for {field} to undo")

      original_func = saved.pop(patch_key)
      if not saved:
          # Last patch of this caller is gone; do not keep an empty bucket.
          originals.pop(key, None)

      setattr(obj, field, original_func)
  def original(key, obj, field: str):
      """Return the original function for the patch created by patch().

      Arguments:
          key: identifying information for who did the replacement
          obj: the module or the class
          field: name of the function as a string

      Returns:
          the stored original function, or None if this caller (key) has no
          patch recorded for (obj, field)
      """
      # Use .get() at both levels: ``originals`` is a defaultdict, and indexing
      # it here would insert an empty entry for every unknown ``key`` queried.
      return originals.get(key, {}).get((obj, field))
  # Registry of saved originals, shared by patch(), undo() and original():
  # {caller key: {(obj, field): original function}}.
  originals = defaultdict(dict)