瀏覽代碼

Account when lines are mismatched

hentailord85ez 2 年之前
父節點
當前提交
80f3cf2bb2
共有 1 個文件被更改,包括 11 次插入1 次删除
  1. 11 1
      modules/sd_hijack.py

+ 11 - 1
modules/sd_hijack.py

@@ -321,7 +321,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                         fixes.append(fix[1])
                         fixes.append(fix[1])
                 self.hijack.fixes.append(fixes)
                 self.hijack.fixes.append(fixes)
             
             
-            z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers])
+            tokens = []
+            multipliers = []
+            for i in range(len(remade_batch_tokens)):
+                if len(remade_batch_tokens[i]) > 0:
+                    tokens.append(remade_batch_tokens[i][:75])
+                    multipliers.append(batch_multipliers[i][:75])
+                else:
+                    tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
+                    multipliers.append([1.0] * 75)
+
+            z1 = self.process_tokens(tokens, multipliers)
             z = z1 if z is None else torch.cat((z, z1), axis=-2)
             z = z1 if z is None else torch.cat((z, z1), axis=-2)
             
             
             remade_batch_tokens = rem_tokens
             remade_batch_tokens = rem_tokens