# prompt_parser.py
import re
from collections import namedtuple
import functools

import torch
from lark import Lark, Transformer, Visitor

import modules.shared as shared
# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
# will be represented with prompt_schedule like this (assuming steps=100):
# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']
# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']
# [75, 'fantasy landscape with a lake and an oak in background masterful']
# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']
  14. def get_learned_conditioning_prompt_schedules(prompts, steps):
  15. grammar = r"""
  16. start: prompt
  17. prompt: (emphasized | scheduled | weighted | plain)*
  18. !emphasized: "(" prompt ")"
  19. | "(" prompt ":" prompt ")"
  20. | "[" prompt "]"
  21. scheduled: "[" (prompt ":")? prompt ":" NUMBER "]"
  22. !weighted: "{" weighted_item ("|" weighted_item)* "}"
  23. !weighted_item: prompt (":" prompt)?
  24. plain: /([^\\\[\](){}:|]|\\.)+/
  25. %import common.SIGNED_NUMBER -> NUMBER
  26. """
  27. parser = Lark(grammar, parser='lalr')
  28. def collect_steps(steps, tree):
  29. l = [steps]
  30. class CollectSteps(Visitor):
  31. def scheduled(self, tree):
  32. tree.children[-1] = float(tree.children[-1])
  33. if tree.children[-1] < 1:
  34. tree.children[-1] *= steps
  35. tree.children[-1] = min(steps, int(tree.children[-1]))
  36. l.append(tree.children[-1])
  37. CollectSteps().visit(tree)
  38. return sorted(set(l))
  39. def at_step(step, tree):
  40. class AtStep(Transformer):
  41. def scheduled(self, args):
  42. if len(args) == 2:
  43. before, after, when = (), *args
  44. else:
  45. before, after, when = args
  46. yield before if step <= when else after
  47. def start(self, args):
  48. def flatten(x):
  49. if type(x) == str:
  50. yield x
  51. else:
  52. for gen in x:
  53. yield from flatten(gen)
  54. return ''.join(flatten(args[0]))
  55. def plain(self, args):
  56. yield args[0].value
  57. def __default__(self, data, children, meta):
  58. for child in children:
  59. yield from child
  60. return AtStep().transform(tree)
  61. def get_schedule(prompt):
  62. tree = parser.parse(prompt)
  63. return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
  64. promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}
  65. return [promptdict[prompt] for prompt in prompts]
# One entry of a prompt's schedule: conditioning tensor `cond` is used for
# every sampling step up to and including `end_at_step`.
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
# A whole batch: `shape` is the tensor shape needed to reassemble the
# conditioning for a step, `schedules` is one list of
# ScheduledPromptConditioning per prompt (see reconstruct_cond_batch).
ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
  68. def get_learned_conditioning(prompts, steps):
  69. res = []
  70. prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
  71. cache = {}
  72. for prompt, prompt_schedule in zip(prompts, prompt_schedules):
  73. cached = cache.get(prompt, None)
  74. if cached is not None:
  75. res.append(cached)
  76. continue
  77. texts = [x[1] for x in prompt_schedule]
  78. conds = shared.sd_model.get_learned_conditioning(texts)
  79. cond_schedule = []
  80. for i, (end_at_step, text) in enumerate(prompt_schedule):
  81. cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
  82. cache[prompt] = cond_schedule
  83. res.append(cond_schedule)
  84. return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res)
  85. def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
  86. res = torch.zeros(c.shape, device=shared.device, dtype=next(shared.sd_model.parameters()).dtype)
  87. for i, cond_schedule in enumerate(c.schedules):
  88. target_index = 0
  89. for curret_index, (end_at, cond) in enumerate(cond_schedule):
  90. if current_step <= end_at:
  91. target_index = curret_index
  92. break
  93. res[i] = cond_schedule[target_index].cond
  94. return res
  95. re_attention = re.compile(r"""
  96. \\\(|
  97. \\\)|
  98. \\\[|
  99. \\]|
  100. \\\\|
  101. \\|
  102. \(|
  103. \[|
  104. :([+-]?[.\d]+)\)|
  105. \)|
  106. ]|
  107. [^\\()\[\]:]+|
  108. :
  109. """, re.X)
  110. def parse_prompt_attention(text):
  111. """
  112. Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
  113. Accepted tokens are:
  114. (abc) - increases attention to abc by a multiplier of 1.1
  115. (abc:3.12) - increases attention to abc by a multiplier of 3.12
  116. [abc] - decreases attention to abc by a multiplier of 1.1
  117. \( - literal character '('
  118. \[ - literal character '['
  119. \) - literal character ')'
  120. \] - literal character ']'
  121. \\ - literal character '\'
  122. anything else - just text
  123. Example:
  124. 'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
  125. produces:
  126. [
  127. ['a ', 1.0],
  128. ['house', 1.5730000000000004],
  129. [' ', 1.1],
  130. ['on', 1.0],
  131. [' a ', 1.1],
  132. ['hill', 0.55],
  133. [', sun, ', 1.1],
  134. ['sky', 1.4641000000000006],
  135. ['.', 1.1]
  136. ]
  137. """
  138. res = []
  139. round_brackets = []
  140. square_brackets = []
  141. round_bracket_multiplier = 1.1
  142. square_bracket_multiplier = 1 / 1.1
  143. def multiply_range(start_position, multiplier):
  144. for p in range(start_position, len(res)):
  145. res[p][1] *= multiplier
  146. for m in re_attention.finditer(text):
  147. text = m.group(0)
  148. weight = m.group(1)
  149. if text.startswith('\\'):
  150. res.append([text[1:], 1.0])
  151. elif text == '(':
  152. round_brackets.append(len(res))
  153. elif text == '[':
  154. square_brackets.append(len(res))
  155. elif weight is not None and len(round_brackets) > 0:
  156. multiply_range(round_brackets.pop(), float(weight))
  157. elif text == ')' and len(round_brackets) > 0:
  158. multiply_range(round_brackets.pop(), round_bracket_multiplier)
  159. elif text == ']' and len(square_brackets) > 0:
  160. multiply_range(square_brackets.pop(), square_bracket_multiplier)
  161. else:
  162. res.append([text, 1.0])
  163. for pos in round_brackets:
  164. multiply_range(pos, round_bracket_multiplier)
  165. for pos in square_brackets:
  166. multiply_range(pos, square_bracket_multiplier)
  167. if len(res) == 0:
  168. res = [["", 1.0]]
  169. return res