styles.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. from __future__ import annotations
  2. from pathlib import Path
  3. from modules import errors
  4. import csv
  5. import os
  6. import typing
  7. import shutil
  8. class PromptStyle(typing.NamedTuple):
  9. name: str
  10. prompt: str | None
  11. negative_prompt: str | None
  12. path: str | None = None
  13. def merge_prompts(style_prompt: str, prompt: str) -> str:
  14. if "{prompt}" in style_prompt:
  15. res = style_prompt.replace("{prompt}", prompt)
  16. else:
  17. parts = filter(None, (prompt.strip(), style_prompt.strip()))
  18. res = ", ".join(parts)
  19. return res
  20. def apply_styles_to_prompt(prompt, styles):
  21. for style in styles:
  22. prompt = merge_prompts(style, prompt)
  23. return prompt
  24. def extract_style_text_from_prompt(style_text, prompt):
  25. """This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. Otherwise, it returns False and the original prompt.
  26. extract_style_text_from_prompt("masterpiece", "1girl, art by greg, masterpiece") outputs (True, "1girl, art by greg")
  27. extract_style_text_from_prompt("masterpiece, {prompt}", "masterpiece, 1girl, art by greg") outputs (True, "1girl, art by greg")
  28. extract_style_text_from_prompt("masterpiece, {prompt}", "exquisite, 1girl, art by greg") outputs (False, "exquisite, 1girl, art by greg")
  29. """
  30. stripped_prompt = prompt.strip()
  31. stripped_style_text = style_text.strip()
  32. if "{prompt}" in stripped_style_text:
  33. left, _, right = stripped_style_text.partition("{prompt}")
  34. if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
  35. prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
  36. return True, prompt
  37. else:
  38. if stripped_prompt.endswith(stripped_style_text):
  39. prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
  40. if prompt.endswith(', '):
  41. prompt = prompt[:-2]
  42. return True, prompt
  43. return False, prompt
  44. def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
  45. """
  46. Takes a style and compares it to the prompt and negative prompt. If the style
  47. matches, returns True plus the prompt and negative prompt with the style text
  48. removed. Otherwise, returns False with the original prompt and negative prompt.
  49. """
  50. if not style.prompt and not style.negative_prompt:
  51. return False, prompt, negative_prompt
  52. match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
  53. if not match_positive:
  54. return False, prompt, negative_prompt
  55. match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
  56. if not match_negative:
  57. return False, prompt, negative_prompt
  58. return True, extracted_positive, extracted_negative
  59. class StyleDatabase:
  60. def __init__(self, paths: list[str | Path]):
  61. self.no_style = PromptStyle("None", "", "", None)
  62. self.styles = {}
  63. self.paths = paths
  64. self.all_styles_files: list[Path] = []
  65. folder, file = os.path.split(self.paths[0])
  66. if '*' in file or '?' in file:
  67. # if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
  68. self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
  69. self.paths.insert(0, self.default_path)
  70. else:
  71. self.default_path = Path(self.paths[0])
  72. self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
  73. self.reload()
  74. def reload(self):
  75. """
  76. Clears the style database and reloads the styles from the CSV file(s)
  77. matching the path used to initialize the database.
  78. """
  79. self.styles.clear()
  80. # scans for all styles files
  81. all_styles_files = []
  82. for pattern in self.paths:
  83. folder, file = os.path.split(pattern)
  84. if '*' in file or '?' in file:
  85. found_files = Path(folder).glob(file)
  86. [all_styles_files.append(file) for file in found_files]
  87. else:
  88. # if os.path.exists(pattern):
  89. all_styles_files.append(Path(pattern))
  90. # Remove any duplicate entries
  91. seen = set()
  92. self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]
  93. for styles_file in self.all_styles_files:
  94. if len(all_styles_files) > 1:
  95. # add divider when more than styles file
  96. # '---------------- STYLES ----------------'
  97. divider = f' {styles_file.stem.upper()} '.center(40, '-')
  98. self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
  99. if styles_file.is_file():
  100. self.load_from_csv(styles_file)
  101. def load_from_csv(self, path: str | Path):
  102. try:
  103. with open(path, "r", encoding="utf-8-sig", newline="") as file:
  104. reader = csv.DictReader(file, skipinitialspace=True)
  105. for row in reader:
  106. # Ignore empty rows or rows starting with a comment
  107. if not row or row["name"].startswith("#"):
  108. continue
  109. # Support loading old CSV format with "name, text"-columns
  110. prompt = row["prompt"] if "prompt" in row else row["text"]
  111. negative_prompt = row.get("negative_prompt", "")
  112. # Add style to database
  113. self.styles[row["name"]] = PromptStyle(
  114. row["name"], prompt, negative_prompt, str(path)
  115. )
  116. except Exception:
  117. errors.report(f'Error loading styles from {path}: ', exc_info=True)
  118. def get_style_paths(self) -> set:
  119. """Returns a set of all distinct paths of files that styles are loaded from."""
  120. # Update any styles without a path to the default path
  121. for style in list(self.styles.values()):
  122. if not style.path:
  123. self.styles[style.name] = style._replace(path=str(self.default_path))
  124. # Create a list of all distinct paths, including the default path
  125. style_paths = set()
  126. style_paths.add(str(self.default_path))
  127. for _, style in self.styles.items():
  128. if style.path:
  129. style_paths.add(style.path)
  130. # Remove any paths for styles that are just list dividers
  131. style_paths.discard("do_not_save")
  132. return style_paths
  133. def get_style_prompts(self, styles):
  134. return [self.styles.get(x, self.no_style).prompt for x in styles]
  135. def get_negative_style_prompts(self, styles):
  136. return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
  137. def apply_styles_to_prompt(self, prompt, styles):
  138. return apply_styles_to_prompt(
  139. prompt, [self.styles.get(x, self.no_style).prompt for x in styles]
  140. )
  141. def apply_negative_styles_to_prompt(self, prompt, styles):
  142. return apply_styles_to_prompt(
  143. prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]
  144. )
  145. def save_styles(self, path: str = None) -> None:
  146. # The path argument is deprecated, but kept for backwards compatibility
  147. style_paths = self.get_style_paths()
  148. csv_names = [os.path.split(path)[1].lower() for path in style_paths]
  149. for style_path in style_paths:
  150. # Always keep a backup file around
  151. if os.path.exists(style_path):
  152. shutil.copy(style_path, f"{style_path}.bak")
  153. # Write the styles to the CSV file
  154. with open(style_path, "w", encoding="utf-8-sig", newline="") as file:
  155. writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
  156. writer.writeheader()
  157. for style in (s for s in self.styles.values() if s.path == style_path):
  158. # Skip style list dividers, e.g. "STYLES.CSV"
  159. if style.name.lower().strip("# ") in csv_names:
  160. continue
  161. # Write style fields, ignoring the path field
  162. writer.writerow(
  163. {k: v for k, v in style._asdict().items() if k != "path"}
  164. )
  165. def extract_styles_from_prompt(self, prompt, negative_prompt):
  166. extracted = []
  167. applicable_styles = list(self.styles.values())
  168. while True:
  169. found_style = None
  170. for style in applicable_styles:
  171. is_match, new_prompt, new_neg_prompt = extract_original_prompts(
  172. style, prompt, negative_prompt
  173. )
  174. if is_match:
  175. found_style = style
  176. prompt = new_prompt
  177. negative_prompt = new_neg_prompt
  178. break
  179. if not found_style:
  180. break
  181. applicable_styles.remove(found_style)
  182. extracted.append(found_style.name)
  183. return list(reversed(extracted)), prompt, negative_prompt