infotext_utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. from __future__ import annotations
  2. import base64
  3. import io
  4. import json
  5. import os
  6. import re
  7. import sys
  8. import gradio as gr
  9. from modules.paths import data_path
  10. from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, images, prompt_parser
  11. from PIL import Image
  12. sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name
  13. re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
  14. re_param = re.compile(re_param_code)
  15. re_imagesize = re.compile(r"^(\d+)x(\d+)$")
  16. re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
  17. type_of_gr_update = type(gr.update())
  18. class ParamBinding:
  19. def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
  20. self.paste_button = paste_button
  21. self.tabname = tabname
  22. self.source_text_component = source_text_component
  23. self.source_image_component = source_image_component
  24. self.source_tabname = source_tabname
  25. self.override_settings_component = override_settings_component
  26. self.paste_field_names = paste_field_names or []
  27. class PasteField(tuple):
  28. def __new__(cls, component, target, *, api=None):
  29. return super().__new__(cls, (component, target))
  30. def __init__(self, component, target, *, api=None):
  31. super().__init__()
  32. self.api = api
  33. self.component = component
  34. self.label = target if isinstance(target, str) else None
  35. self.function = target if callable(target) else None
  36. paste_fields: dict[str, dict] = {}
  37. registered_param_bindings: list[ParamBinding] = []
  38. def reset():
  39. paste_fields.clear()
  40. registered_param_bindings.clear()
  41. def quote(text):
  42. if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
  43. return text
  44. return json.dumps(text, ensure_ascii=False)
  45. def unquote(text):
  46. if len(text) == 0 or text[0] != '"' or text[-1] != '"':
  47. return text
  48. try:
  49. return json.loads(text)
  50. except Exception:
  51. return text
  52. def image_from_url_text(filedata):
  53. if filedata is None:
  54. return None
  55. if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False):
  56. filedata = filedata[0]
  57. if type(filedata) == dict and filedata.get("is_file", False):
  58. filename = filedata["name"]
  59. is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
  60. assert is_in_right_dir, 'trying to open image file outside of allowed directories'
  61. filename = filename.rsplit('?', 1)[0]
  62. return images.read(filename)
  63. if type(filedata) == list:
  64. if len(filedata) == 0:
  65. return None
  66. filedata = filedata[0]
  67. if filedata.startswith("data:image/png;base64,"):
  68. filedata = filedata[len("data:image/png;base64,"):]
  69. filedata = base64.decodebytes(filedata.encode('utf-8'))
  70. image = images.read(io.BytesIO(filedata))
  71. return image
  72. def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
  73. if fields:
  74. for i in range(len(fields)):
  75. if not isinstance(fields[i], PasteField):
  76. fields[i] = PasteField(*fields[i])
  77. paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
  78. # backwards compatibility for existing extensions
  79. import modules.ui
  80. if tabname == 'txt2img':
  81. modules.ui.txt2img_paste_fields = fields
  82. elif tabname == 'img2img':
  83. modules.ui.img2img_paste_fields = fields
  84. def create_buttons(tabs_list):
  85. buttons = {}
  86. for tab in tabs_list:
  87. buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
  88. return buttons
  89. def bind_buttons(buttons, send_image, send_generate_info):
  90. """old function for backwards compatibility; do not use this, use register_paste_params_button"""
  91. for tabname, button in buttons.items():
  92. source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
  93. source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
  94. register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
  95. def register_paste_params_button(binding: ParamBinding):
  96. registered_param_bindings.append(binding)
  97. def connect_paste_params_buttons():
  98. for binding in registered_param_bindings:
  99. destination_image_component = paste_fields[binding.tabname]["init_img"]
  100. fields = paste_fields[binding.tabname]["fields"]
  101. override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
  102. destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
  103. destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
  104. if binding.source_image_component and destination_image_component:
  105. if isinstance(binding.source_image_component, gr.Gallery):
  106. func = send_image_and_dimensions if destination_width_component else image_from_url_text
  107. jsfunc = "extract_image_from_gallery"
  108. else:
  109. func = send_image_and_dimensions if destination_width_component else lambda x: x
  110. jsfunc = None
  111. binding.paste_button.click(
  112. fn=func,
  113. _js=jsfunc,
  114. inputs=[binding.source_image_component],
  115. outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
  116. show_progress=False,
  117. )
  118. if binding.source_text_component is not None and fields is not None:
  119. connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
  120. if binding.source_tabname is not None and fields is not None:
  121. paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
  122. binding.paste_button.click(
  123. fn=lambda *x: x,
  124. inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
  125. outputs=[field for field, name in fields if name in paste_field_names],
  126. show_progress=False,
  127. )
  128. binding.paste_button.click(
  129. fn=None,
  130. _js=f"switch_to_{binding.tabname}",
  131. inputs=None,
  132. outputs=None,
  133. show_progress=False,
  134. )
  135. def send_image_and_dimensions(x):
  136. if isinstance(x, Image.Image):
  137. img = x
  138. else:
  139. img = image_from_url_text(x)
  140. if shared.opts.send_size and isinstance(img, Image.Image):
  141. w = img.width
  142. h = img.height
  143. else:
  144. w = gr.update()
  145. h = gr.update()
  146. return img, w, h
  147. def restore_old_hires_fix_params(res):
  148. """for infotexts that specify old First pass size parameter, convert it into
  149. width, height, and hr scale"""
  150. firstpass_width = res.get('First pass size-1', None)
  151. firstpass_height = res.get('First pass size-2', None)
  152. if shared.opts.use_old_hires_fix_width_height:
  153. hires_width = int(res.get("Hires resize-1", 0))
  154. hires_height = int(res.get("Hires resize-2", 0))
  155. if hires_width and hires_height:
  156. res['Size-1'] = hires_width
  157. res['Size-2'] = hires_height
  158. return
  159. if firstpass_width is None or firstpass_height is None:
  160. return
  161. firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
  162. width = int(res.get("Size-1", 512))
  163. height = int(res.get("Size-2", 512))
  164. if firstpass_width == 0 or firstpass_height == 0:
  165. firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
  166. res['Size-1'] = firstpass_width
  167. res['Size-2'] = firstpass_height
  168. res['Hires resize-1'] = width
  169. res['Hires resize-2'] = height
  170. def parse_generation_parameters(x: str, skip_fields: list[str] | None = None):
  171. """parses generation parameters string, the one you see in text field under the picture in UI:
  172. ```
  173. girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
  174. Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
  175. Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
  176. ```
  177. returns a dict with field values
  178. """
  179. if skip_fields is None:
  180. skip_fields = shared.opts.infotext_skip_pasting
  181. res = {}
  182. prompt = ""
  183. negative_prompt = ""
  184. done_with_prompt = False
  185. *lines, lastline = x.strip().split("\n")
  186. if len(re_param.findall(lastline)) < 3:
  187. lines.append(lastline)
  188. lastline = ''
  189. for line in lines:
  190. line = line.strip()
  191. if line.startswith("Negative prompt:"):
  192. done_with_prompt = True
  193. line = line[16:].strip()
  194. if done_with_prompt:
  195. negative_prompt += ("" if negative_prompt == "" else "\n") + line
  196. else:
  197. prompt += ("" if prompt == "" else "\n") + line
  198. if shared.opts.infotext_styles != "Ignore":
  199. found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)
  200. if shared.opts.infotext_styles == "Apply":
  201. res["Styles array"] = found_styles
  202. elif shared.opts.infotext_styles == "Apply if any" and found_styles:
  203. res["Styles array"] = found_styles
  204. res["Prompt"] = prompt
  205. res["Negative prompt"] = negative_prompt
  206. for k, v in re_param.findall(lastline):
  207. try:
  208. if v[0] == '"' and v[-1] == '"':
  209. v = unquote(v)
  210. m = re_imagesize.match(v)
  211. if m is not None:
  212. res[f"{k}-1"] = m.group(1)
  213. res[f"{k}-2"] = m.group(2)
  214. else:
  215. res[k] = v
  216. except Exception:
  217. print(f"Error parsing \"{k}: {v}\"")
  218. # Missing CLIP skip means it was set to 1 (the default)
  219. if "Clip skip" not in res:
  220. res["Clip skip"] = "1"
  221. hypernet = res.get("Hypernet", None)
  222. if hypernet is not None:
  223. res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
  224. if "Hires resize-1" not in res:
  225. res["Hires resize-1"] = 0
  226. res["Hires resize-2"] = 0
  227. if "Hires sampler" not in res:
  228. res["Hires sampler"] = "Use same sampler"
  229. if "Hires checkpoint" not in res:
  230. res["Hires checkpoint"] = "Use same checkpoint"
  231. if "Hires prompt" not in res:
  232. res["Hires prompt"] = ""
  233. if "Hires negative prompt" not in res:
  234. res["Hires negative prompt"] = ""
  235. if "Mask mode" not in res:
  236. res["Mask mode"] = "Inpaint masked"
  237. if "Masked content" not in res:
  238. res["Masked content"] = 'original'
  239. if "Inpaint area" not in res:
  240. res["Inpaint area"] = "Whole picture"
  241. if "Masked area padding" not in res:
  242. res["Masked area padding"] = 32
  243. restore_old_hires_fix_params(res)
  244. # Missing RNG means the default was set, which is GPU RNG
  245. if "RNG" not in res:
  246. res["RNG"] = "GPU"
  247. if "Schedule type" not in res:
  248. res["Schedule type"] = "Automatic"
  249. if "Schedule max sigma" not in res:
  250. res["Schedule max sigma"] = 0
  251. if "Schedule min sigma" not in res:
  252. res["Schedule min sigma"] = 0
  253. if "Schedule rho" not in res:
  254. res["Schedule rho"] = 0
  255. if "VAE Encoder" not in res:
  256. res["VAE Encoder"] = "Full"
  257. if "VAE Decoder" not in res:
  258. res["VAE Decoder"] = "Full"
  259. if "FP8 weight" not in res:
  260. res["FP8 weight"] = "Disable"
  261. if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
  262. res["Cache FP16 weight for LoRA"] = False
  263. prompt_attention = prompt_parser.parse_prompt_attention(prompt)
  264. prompt_attention += prompt_parser.parse_prompt_attention(negative_prompt)
  265. prompt_uses_emphasis = len(prompt_attention) != len([p for p in prompt_attention if p[1] == 1.0 or p[0] == 'BREAK'])
  266. if "Emphasis" not in res and prompt_uses_emphasis:
  267. res["Emphasis"] = "Original"
  268. if "Refiner switch by sampling steps" not in res:
  269. res["Refiner switch by sampling steps"] = False
  270. infotext_versions.backcompat(res)
  271. for key in skip_fields:
  272. res.pop(key, None)
  273. return res
  274. infotext_to_setting_name_mapping = [
  275. ]
  276. """Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead.
  277. Example content:
  278. infotext_to_setting_name_mapping = [
  279. ('Conditional mask weight', 'inpainting_mask_weight'),
  280. ('Model hash', 'sd_model_checkpoint'),
  281. ('ENSD', 'eta_noise_seed_delta'),
  282. ('Schedule type', 'k_sched_type'),
  283. ]
  284. """
  285. def create_override_settings_dict(text_pairs):
  286. """creates processing's override_settings parameters from gradio's multiselect
  287. Example input:
  288. ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
  289. Example output:
  290. {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
  291. """
  292. res = {}
  293. params = {}
  294. for pair in text_pairs:
  295. k, v = pair.split(":", maxsplit=1)
  296. params[k] = v.strip()
  297. mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
  298. for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
  299. value = params.get(param_name, None)
  300. if value is None:
  301. continue
  302. res[setting_name] = shared.opts.cast_value(setting_name, value)
  303. return res
  304. def get_override_settings(params, *, skip_fields=None):
  305. """Returns a list of settings overrides from the infotext parameters dictionary.
  306. This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
  307. a list of tuples containing the parameter name, setting name, and new value cast to correct type.
  308. It checks for conditions before adding an override:
  309. - ignores settings that match the current value
  310. - ignores parameter keys present in skip_fields argument.
  311. Example input:
  312. {"Clip skip": "2"}
  313. Example output:
  314. [("Clip skip", "CLIP_stop_at_last_layers", 2)]
  315. """
  316. res = []
  317. mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
  318. for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
  319. if param_name in (skip_fields or {}):
  320. continue
  321. v = params.get(param_name, None)
  322. if v is None:
  323. continue
  324. if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
  325. continue
  326. v = shared.opts.cast_value(setting_name, v)
  327. current_value = getattr(shared.opts, setting_name, None)
  328. if v == current_value:
  329. continue
  330. res.append((param_name, setting_name, v))
  331. return res
  332. def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
  333. def paste_func(prompt):
  334. if not prompt and not shared.cmd_opts.hide_ui_dir_config:
  335. filename = os.path.join(data_path, "params.txt")
  336. try:
  337. with open(filename, "r", encoding="utf8") as file:
  338. prompt = file.read()
  339. except OSError:
  340. pass
  341. params = parse_generation_parameters(prompt)
  342. script_callbacks.infotext_pasted_callback(prompt, params)
  343. res = []
  344. for output, key in paste_fields:
  345. if callable(key):
  346. v = key(params)
  347. else:
  348. v = params.get(key, None)
  349. if v is None:
  350. res.append(gr.update())
  351. elif isinstance(v, type_of_gr_update):
  352. res.append(v)
  353. else:
  354. try:
  355. valtype = type(output.value)
  356. if valtype == bool and v == "False":
  357. val = False
  358. elif valtype == int:
  359. val = float(v)
  360. else:
  361. val = valtype(v)
  362. res.append(gr.update(value=val))
  363. except Exception:
  364. res.append(gr.update())
  365. return res
  366. if override_settings_component is not None:
  367. already_handled_fields = {key: 1 for _, key in paste_fields}
  368. def paste_settings(params):
  369. vals = get_override_settings(params, skip_fields=already_handled_fields)
  370. vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]
  371. return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
  372. paste_fields = paste_fields + [(override_settings_component, paste_settings)]
  373. button.click(
  374. fn=paste_func,
  375. inputs=[input_comp],
  376. outputs=[x[0] for x in paste_fields],
  377. show_progress=False,
  378. )
  379. button.click(
  380. fn=None,
  381. _js=f"recalculate_prompts_{tabname}",
  382. inputs=[],
  383. outputs=[],
  384. show_progress=False,
  385. )