|
@@ -139,23 +139,31 @@ def load_upscalers():
|
|
|
|
|
|
|
|
|
def load_spandrel_model(
|
|
|
- path: str,
|
|
|
+ path: str | os.PathLike,
|
|
|
*,
|
|
|
device: str | torch.device | None,
|
|
|
- half: bool = False,
|
|
|
+ prefer_half: bool = False,
|
|
|
dtype: str | torch.dtype | None = None,
|
|
|
expected_architecture: str | None = None,
|
|
|
) -> spandrel.ModelDescriptor:
|
|
|
import spandrel
|
|
|
- model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
|
|
|
+ model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))
|
|
|
if expected_architecture and model_descriptor.architecture != expected_architecture:
|
|
|
logger.warning(
|
|
|
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
|
|
|
)
|
|
|
- if half:
|
|
|
- model_descriptor.model.half()
|
|
|
+ half = False
|
|
|
+ if prefer_half:
|
|
|
+ if model_descriptor.supports_half:
|
|
|
+ model_descriptor.model.half()
|
|
|
+ half = True
|
|
|
+ else:
|
|
|
+ logger.info("Model %s does not support half precision, ignoring --half", path)
|
|
|
if dtype:
|
|
|
model_descriptor.model.to(dtype=dtype)
|
|
|
model_descriptor.model.eval()
|
|
|
- logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
|
|
|
+ logger.debug(
|
|
|
+ "Loaded %s from %s (device=%s, half=%s, dtype=%s)",
|
|
|
+ model_descriptor, path, device, half, dtype,
|
|
|
+ )
|
|
|
return model_descriptor
|