|
@@ -178,12 +178,11 @@ def manual_cast(target_dtype):
|
|
try:
|
|
try:
|
|
yield None
|
|
yield None
|
|
finally:
|
|
finally:
|
|
- if not applied:
|
|
|
|
- return
|
|
|
|
- for module_type in patch_module_list:
|
|
|
|
- if hasattr(module_type, "org_forward"):
|
|
|
|
- module_type.forward = module_type.org_forward
|
|
|
|
- delattr(module_type, "org_forward")
|
|
|
|
|
|
+ if applied:
|
|
|
|
+ for module_type in patch_module_list:
|
|
|
|
+ if hasattr(module_type, "org_forward"):
|
|
|
|
+ module_type.forward = module_type.org_forward
|
|
|
|
+ delattr(module_type, "org_forward")
|
|
|
|
|
|
|
|
|
|
def autocast(disable=False):
|
|
def autocast(disable=False):
|