|
@@ -106,8 +106,8 @@ if has_xpu:
|
|
|
try:
|
|
|
# torch.Generator supports "xpu" device since 2.1
|
|
|
torch.Generator("xpu")
|
|
|
- except:
|
|
|
- # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for IPEX < 2.1)
|
|
|
+ except RuntimeError:
|
|
|
+ # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
|
|
|
CondFunc('torch.Generator',
|
|
|
lambda orig_func, device=None: torch.xpu.Generator(device),
|
|
|
lambda orig_func, device=None: is_xpu_device(device))
|