|
@@ -12,9 +12,13 @@ def print_error_explanation(message):
|
|
print('=' * max_len, file=sys.stderr)
|
|
print('=' * max_len, file=sys.stderr)
|
|
|
|
|
|
|
|
|
|
-def display(e: Exception, task):
|
|
|
|
|
|
+def display(e: Exception, task, *, full_traceback=False):
|
|
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
|
|
print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
|
|
- print(traceback.format_exc(), file=sys.stderr)
|
|
|
|
|
|
+ te = traceback.TracebackException.from_exception(e)
|
|
|
|
+ if full_traceback:
|
|
|
|
+ # include frames leading up to the try-catch block
|
|
|
|
+ te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)
|
|
|
|
+ print(*te.format(), sep="", file=sys.stderr)
|
|
|
|
|
|
message = str(e)
|
|
message = str(e)
|
|
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
|
|
if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
|