Преглед изворни кода

fix AND broken for long prompts

AUTOMATIC пре 2 година
родитељ
комит
7001bffe02
1 измењених фајлова са 9 додато и 0 уклоњено
  1. 9 0
      modules/prompt_parser.py

+ 9 - 0
modules/prompt_parser.py

@@ -239,6 +239,15 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
 
         conds_list.append(conds_for_batch)
 
+    # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
+    # and won't be able to torch.stack them. So this fixes that.
+    token_count = max([x.shape[0] for x in tensors])
+    for i in range(len(tensors)):
+        if tensors[i].shape[0] != token_count:
+            last_vector = tensors[i][-1:]
+            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
     return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)