|
@@ -24,6 +24,13 @@ def fix_torch_version():
|
|
|
torch.__long_version__ = torch.__version__
|
|
|
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
|
|
|
|
|
+def fix_pytorch_lightning():
|
|
|
+ # Checks if pytorch_lightning.utilities.distributed already exists in the sys.modules cache
|
|
|
+ if 'pytorch_lightning.utilities.distributed' not in sys.modules:
|
|
|
+ import pytorch_lightning
|
|
|
+ # Lets the user know that the library was not found and then will set it to pytorch_lightning.utilities.rank_zero
|
|
|
+ print(f"Pytorch_lightning.distributed not found, attempting pytorch_lightning.rank_zero")
|
|
|
+ sys.modules["pytorch_lightning.utilities.distributed"] = pytorch_lightning.utilities.rank_zero
|
|
|
|
|
|
def fix_asyncio_event_loop_policy():
|
|
|
"""
|