torch_utils.py 459 B

1234567891011121314151617
  1. from __future__ import annotations
  2. import torch.nn
  3. def get_param(model) -> torch.nn.Parameter:
  4. """
  5. Find the first parameter in a model or module.
  6. """
  7. if hasattr(model, "model") and hasattr(model.model, "parameters"):
  8. # Unpeel a model descriptor to get at the actual Torch module.
  9. model = model.model
  10. for param in model.parameters():
  11. return param
  12. raise ValueError(f"No parameters found in model {model!r}")