# test_torch_utils.py (473 B)
import types

import pytest
import torch

from modules import torch_utils
  5. @pytest.mark.parametrize("wrapped", [True, False])
  6. def test_get_param(wrapped):
  7. mod = torch.nn.Linear(1, 1)
  8. cpu = torch.device("cpu")
  9. mod.to(dtype=torch.float16, device=cpu)
  10. if wrapped:
  11. # more or less how spandrel wraps a thing
  12. mod = types.SimpleNamespace(model=mod)
  13. p = torch_utils.get_param(mod)
  14. assert p.dtype == torch.float16
  15. assert p.device == cpu