|
@@ -1,20 +1,11 @@
|
|
|
import re
|
|
|
from collections import namedtuple
|
|
|
import torch
|
|
|
+from lark import Lark, Transformer, Visitor
|
|
|
+import functools
|
|
|
|
|
|
import modules.shared as shared
|
|
|
|
|
|
-re_prompt = re.compile(r'''
|
|
|
-(.*?)
|
|
|
-\[
|
|
|
- ([^]:]+):
|
|
|
- (?:([^]:]*):)?
|
|
|
- ([0-9]*\.?[0-9]+)
|
|
|
-]
|
|
|
-|
|
|
|
-(.+)
|
|
|
-''', re.X)
|
|
|
-
|
|
|
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
|
|
|
# will be represented with prompt_schedule like this (assuming steps=100):
|
|
|
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
|
|
@@ -25,61 +16,57 @@ re_prompt = re.compile(r'''
|
|
|
|
|
|
|
|
|
def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|
|
- res = []
|
|
|
- cache = {}
|
|
|
-
|
|
|
- for prompt in prompts:
|
|
|
- prompt_schedule: list[list[str | int]] = [[steps, ""]]
|
|
|
-
|
|
|
- cached = cache.get(prompt, None)
|
|
|
- if cached is not None:
|
|
|
- res.append(cached)
|
|
|
- continue
|
|
|
-
|
|
|
- for m in re_prompt.finditer(prompt):
|
|
|
- plaintext = m.group(1) if m.group(5) is None else m.group(5)
|
|
|
- concept_from = m.group(2)
|
|
|
- concept_to = m.group(3)
|
|
|
- if concept_to is None:
|
|
|
- concept_to = concept_from
|
|
|
- concept_from = ""
|
|
|
- swap_position = float(m.group(4)) if m.group(4) is not None else None
|
|
|
-
|
|
|
- if swap_position is not None:
|
|
|
- if swap_position < 1:
|
|
|
- swap_position = swap_position * steps
|
|
|
- swap_position = int(min(swap_position, steps))
|
|
|
-
|
|
|
- swap_index = None
|
|
|
- found_exact_index = False
|
|
|
- for i in range(len(prompt_schedule)):
|
|
|
- end_step = prompt_schedule[i][0]
|
|
|
- prompt_schedule[i][1] += plaintext
|
|
|
-
|
|
|
- if swap_position is not None and swap_index is None:
|
|
|
- if swap_position == end_step:
|
|
|
- swap_index = i
|
|
|
- found_exact_index = True
|
|
|
-
|
|
|
- if swap_position < end_step:
|
|
|
- swap_index = i
|
|
|
-
|
|
|
- if swap_index is not None:
|
|
|
- if not found_exact_index:
|
|
|
- prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
|
|
|
-
|
|
|
- for i in range(len(prompt_schedule)):
|
|
|
- end_step = prompt_schedule[i][0]
|
|
|
- must_replace = swap_position < end_step
|
|
|
-
|
|
|
- prompt_schedule[i][1] += concept_to if must_replace else concept_from
|
|
|
-
|
|
|
- res.append(prompt_schedule)
|
|
|
- cache[prompt] = prompt_schedule
|
|
|
- #for t in prompt_schedule:
|
|
|
- # print(t)
|
|
|
-
|
|
|
- return res
|
|
|
+ grammar = r"""
|
|
|
+ start: prompt
|
|
|
+ prompt: (emphasized | scheduled | weighted | plain)*
|
|
|
+ !emphasized: "(" prompt ")"
|
|
|
+ | "(" prompt ":" prompt ")"
|
|
|
+ | "[" prompt "]"
|
|
|
+ scheduled: "[" (prompt ":")? prompt ":" NUMBER "]"
|
|
|
+ !weighted: "{" weighted_item ("|" weighted_item)* "}"
|
|
|
+ !weighted_item: prompt (":" prompt)?
|
|
|
+ plain: /([^\\\[\](){}:|]|\\.)+/
|
|
|
+ %import common.SIGNED_NUMBER -> NUMBER
|
|
|
+ """
|
|
|
+ parser = Lark(grammar, parser='lalr')
|
|
|
+ def collect_steps(steps, tree):
|
|
|
+ l = [steps]
|
|
|
+ class CollectSteps(Visitor):
|
|
|
+ def scheduled(self, tree):
|
|
|
+ tree.children[-1] = float(tree.children[-1])
|
|
|
+ if tree.children[-1] < 1:
|
|
|
+ tree.children[-1] *= steps
|
|
|
+ tree.children[-1] = min(steps, int(tree.children[-1]))
|
|
|
+ l.append(tree.children[-1])
|
|
|
+ CollectSteps().visit(tree)
|
|
|
+ return sorted(set(l))
|
|
|
+ def at_step(step, tree):
|
|
|
+ class AtStep(Transformer):
|
|
|
+ def scheduled(self, args):
|
|
|
+ if len(args) == 2:
|
|
|
+ before, after, when = (), *args
|
|
|
+ else:
|
|
|
+ before, after, when = args
|
|
|
+ yield before if step <= when else after
|
|
|
+ def start(self, args):
|
|
|
+ def flatten(x):
|
|
|
+ if type(x) == str:
|
|
|
+ yield x
|
|
|
+ else:
|
|
|
+ for gen in x:
|
|
|
+ yield from flatten(gen)
|
|
|
+ return ''.join(flatten(args[0]))
|
|
|
+ def plain(self, args):
|
|
|
+ yield args[0].value
|
|
|
+ def __default__(self, data, children, meta):
|
|
|
+ for child in children:
|
|
|
+ yield from child
|
|
|
+ return AtStep().transform(tree)
|
|
|
+ @functools.cache
|
|
|
+ def get_schedule(prompt):
|
|
|
+ tree = parser.parse(prompt)
|
|
|
+ return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
|
|
|
+ return [get_schedule(prompt) for prompt in prompts]
|
|
|
|
|
|
|
|
|
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
|