infotext_utils.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546
  1. from __future__ import annotations
  2. import base64
  3. import io
  4. import json
  5. import os
  6. import re
  7. import sys
  8. import gradio as gr
  9. from modules.paths import data_path
  10. from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, images, prompt_parser, errors
  11. from PIL import Image
  12. sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name
  13. re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
  14. re_param = re.compile(re_param_code)
  15. re_imagesize = re.compile(r"^(\d+)x(\d+)$")
  16. re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
# Type of the object returned by gr.update(); connect_paste uses it to recognise
# pasted values that are already gradio updates and pass them through unchanged.
type_of_gr_update = type(gr.update())
  18. class ParamBinding:
  19. def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
  20. self.paste_button = paste_button
  21. self.tabname = tabname
  22. self.source_text_component = source_text_component
  23. self.source_image_component = source_image_component
  24. self.source_tabname = source_tabname
  25. self.override_settings_component = override_settings_component
  26. self.paste_field_names = paste_field_names or []
  27. class PasteField(tuple):
  28. def __new__(cls, component, target, *, api=None):
  29. return super().__new__(cls, (component, target))
  30. def __init__(self, component, target, *, api=None):
  31. super().__init__()
  32. self.api = api
  33. self.component = component
  34. self.label = target if isinstance(target, str) else None
  35. self.function = target if callable(target) else None
  36. paste_fields: dict[str, dict] = {}
  37. registered_param_bindings: list[ParamBinding] = []
  38. def reset():
  39. paste_fields.clear()
  40. registered_param_bindings.clear()
  41. def quote(text):
  42. if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text):
  43. return text
  44. return json.dumps(text, ensure_ascii=False)
  45. def unquote(text):
  46. if len(text) == 0 or text[0] != '"' or text[-1] != '"':
  47. return text
  48. try:
  49. return json.loads(text)
  50. except Exception:
  51. return text
  52. def image_from_url_text(filedata):
  53. if filedata is None:
  54. return None
  55. if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False):
  56. filedata = filedata[0]
  57. if type(filedata) == dict and filedata.get("is_file", False):
  58. filename = filedata["name"]
  59. is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
  60. assert is_in_right_dir, 'trying to open image file outside of allowed directories'
  61. filename = filename.rsplit('?', 1)[0]
  62. return images.read(filename)
  63. if type(filedata) == list:
  64. if len(filedata) == 0:
  65. return None
  66. filedata = filedata[0]
  67. if filedata.startswith("data:image/png;base64,"):
  68. filedata = filedata[len("data:image/png;base64,"):]
  69. filedata = base64.decodebytes(filedata.encode('utf-8'))
  70. image = images.read(io.BytesIO(filedata))
  71. return image
  72. def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
  73. if fields:
  74. for i in range(len(fields)):
  75. if not isinstance(fields[i], PasteField):
  76. fields[i] = PasteField(*fields[i])
  77. paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}
  78. # backwards compatibility for existing extensions
  79. import modules.ui
  80. if tabname == 'txt2img':
  81. modules.ui.txt2img_paste_fields = fields
  82. elif tabname == 'img2img':
  83. modules.ui.img2img_paste_fields = fields
  84. def create_buttons(tabs_list):
  85. buttons = {}
  86. for tab in tabs_list:
  87. buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
  88. return buttons
  89. def bind_buttons(buttons, send_image, send_generate_info):
  90. """old function for backwards compatibility; do not use this, use register_paste_params_button"""
  91. for tabname, button in buttons.items():
  92. source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
  93. source_tabname = send_generate_info if isinstance(send_generate_info, str) else None
  94. register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))
  95. def register_paste_params_button(binding: ParamBinding):
  96. registered_param_bindings.append(binding)
  97. def connect_paste_params_buttons():
  98. for binding in registered_param_bindings:
  99. destination_image_component = paste_fields[binding.tabname]["init_img"]
  100. fields = paste_fields[binding.tabname]["fields"]
  101. override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"]
  102. destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
  103. destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
  104. if binding.source_image_component and destination_image_component:
  105. need_send_dementions = destination_width_component and binding.tabname != 'inpaint'
  106. if isinstance(binding.source_image_component, gr.Gallery):
  107. func = send_image_and_dimensions if need_send_dementions else image_from_url_text
  108. jsfunc = "extract_image_from_gallery"
  109. else:
  110. func = send_image_and_dimensions if need_send_dementions else lambda x: x
  111. jsfunc = None
  112. binding.paste_button.click(
  113. fn=func,
  114. _js=jsfunc,
  115. inputs=[binding.source_image_component],
  116. outputs=[destination_image_component, destination_width_component, destination_height_component] if need_send_dementions else [destination_image_component],
  117. show_progress=False,
  118. )
  119. if binding.source_text_component is not None and fields is not None:
  120. connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)
  121. if binding.source_tabname is not None and fields is not None:
  122. paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
  123. binding.paste_button.click(
  124. fn=lambda *x: x,
  125. inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
  126. outputs=[field for field, name in fields if name in paste_field_names],
  127. show_progress=False,
  128. )
  129. binding.paste_button.click(
  130. fn=None,
  131. _js=f"switch_to_{binding.tabname}",
  132. inputs=None,
  133. outputs=None,
  134. show_progress=False,
  135. )
  136. def send_image_and_dimensions(x):
  137. if isinstance(x, Image.Image):
  138. img = x
  139. else:
  140. img = image_from_url_text(x)
  141. if shared.opts.send_size and isinstance(img, Image.Image):
  142. w = img.width
  143. h = img.height
  144. else:
  145. w = gr.update()
  146. h = gr.update()
  147. return img, w, h
  148. def restore_old_hires_fix_params(res):
  149. """for infotexts that specify old First pass size parameter, convert it into
  150. width, height, and hr scale"""
  151. firstpass_width = res.get('First pass size-1', None)
  152. firstpass_height = res.get('First pass size-2', None)
  153. if shared.opts.use_old_hires_fix_width_height:
  154. hires_width = int(res.get("Hires resize-1", 0))
  155. hires_height = int(res.get("Hires resize-2", 0))
  156. if hires_width and hires_height:
  157. res['Size-1'] = hires_width
  158. res['Size-2'] = hires_height
  159. return
  160. if firstpass_width is None or firstpass_height is None:
  161. return
  162. firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
  163. width = int(res.get("Size-1", 512))
  164. height = int(res.get("Size-2", 512))
  165. if firstpass_width == 0 or firstpass_height == 0:
  166. firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
  167. res['Size-1'] = firstpass_width
  168. res['Size-2'] = firstpass_height
  169. res['Hires resize-1'] = width
  170. res['Hires resize-2'] = height
def parse_generation_parameters(x: str, skip_fields: list[str] | None = None) -> dict:
    """parses generation parameters string, the one you see in text field under the picture in UI:
    ```
    girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
    Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
    Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
    ```

    returns a dict with field values

    Args:
        x: the infotext. All lines but the last are prompt text. The last line
            is parsed as comma-separated "key: value" pairs only when re_param
            finds at least 3 of them there; otherwise it is prompt text too.
            Lines from the first one starting with "Negative prompt:" onward
            form the negative prompt.
        skip_fields: keys removed from the result at the end; defaults to
            shared.opts.infotext_skip_pasting.

    Returns:
        dict of infotext label -> value. Values parsed from the text are
        strings, and a "WxH" value of key K is stored as "K-1" and "K-2".
        "Prompt" and "Negative prompt" are always set, and many keys missing
        from the text are filled with defaults.
    """
    if skip_fields is None:
        skip_fields = shared.opts.infotext_skip_pasting

    res = {}

    prompt = ""
    negative_prompt = ""

    done_with_prompt = False

    # The parameter line, if present, is the last line; fewer than 3 "key: value"
    # matches means the last line is ordinary prompt text.
    *lines, lastline = x.strip().split("\n")
    if len(re_param.findall(lastline)) < 3:
        lines.append(lastline)
        lastline = ''

    for line in lines:
        line = line.strip()
        if line.startswith("Negative prompt:"):
            done_with_prompt = True
            line = line[16:].strip()  # 16 == len("Negative prompt:")
        if done_with_prompt:
            negative_prompt += ("" if negative_prompt == "" else "\n") + line
        else:
            prompt += ("" if prompt == "" else "\n") + line

    for k, v in re_param.findall(lastline):
        try:
            # Quoted values (see quote()) may contain commas/colons; decode them.
            if v[0] == '"' and v[-1] == '"':
                v = unquote(v)

            # "WxH" values are split into two keys, e.g. Size -> Size-1 / Size-2.
            m = re_imagesize.match(v)
            if m is not None:
                res[f"{k}-1"] = m.group(1)
                res[f"{k}-2"] = m.group(2)
            else:
                res[k] = v
        except Exception:
            # e.g. an empty value makes v[0] raise IndexError; the pair is skipped.
            print(f"Error parsing \"{k}: {v}\"")

    # Extract styles from prompt
    if shared.opts.infotext_styles != "Ignore":
        found_styles, prompt_no_styles, negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)

        same_hr_styles = True
        # Hires prompts are compared only when parse_version() gives a falsy result
        # or the version is newer than infotext_versions.v180_hr_styles. Styles are
        # stripped only if the hires prompts contain the same styles. A hires prompt
        # equal to the stripped main prompt is stored as ''.
        if ("Hires prompt" in res or "Hires negative prompt" in res) and (infotext_ver > infotext_versions.v180_hr_styles if (infotext_ver := infotext_versions.parse_version(res.get("Version"))) else True):
            hr_prompt, hr_negative_prompt = res.get("Hires prompt", prompt), res.get("Hires negative prompt", negative_prompt)
            hr_found_styles, hr_prompt_no_styles, hr_negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(hr_prompt, hr_negative_prompt)
            if same_hr_styles := found_styles == hr_found_styles:
                res["Hires prompt"] = '' if hr_prompt_no_styles == prompt_no_styles else hr_prompt_no_styles
                res['Hires negative prompt'] = '' if hr_negative_prompt_no_styles == negative_prompt_no_styles else hr_negative_prompt_no_styles

        if same_hr_styles:
            prompt, negative_prompt = prompt_no_styles, negative_prompt_no_styles
            if (shared.opts.infotext_styles == "Apply if any" and found_styles) or shared.opts.infotext_styles == "Apply":
                res['Styles array'] = found_styles

    res["Prompt"] = prompt
    res["Negative prompt"] = negative_prompt

    # Defaults below fill keys that are absent from the infotext. connect_paste
    # leaves a control unchanged when its key is missing, so a default makes
    # pasting set the control explicitly instead.

    # Missing CLIP skip means it was set to 1 (the default)
    if "Clip skip" not in res:
        res["Clip skip"] = "1"

    # A "Hypernet" parameter is appended to the prompt as a <hypernet:name:strength> tag.
    hypernet = res.get("Hypernet", None)
    if hypernet is not None:
        res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""

    if "Hires resize-1" not in res:
        res["Hires resize-1"] = 0
        res["Hires resize-2"] = 0

    if "Hires sampler" not in res:
        res["Hires sampler"] = "Use same sampler"

    if "Hires schedule type" not in res:
        res["Hires schedule type"] = "Use same scheduler"

    if "Hires checkpoint" not in res:
        res["Hires checkpoint"] = "Use same checkpoint"

    if "Hires prompt" not in res:
        res["Hires prompt"] = ""

    if "Hires negative prompt" not in res:
        res["Hires negative prompt"] = ""

    if "Mask mode" not in res:
        res["Mask mode"] = "Inpaint masked"

    if "Masked content" not in res:
        res["Masked content"] = 'original'

    if "Inpaint area" not in res:
        res["Inpaint area"] = "Whole picture"

    if "Masked area padding" not in res:
        res["Masked area padding"] = 32

    # Must run after the "Hires resize-*" defaults above, which it reads.
    restore_old_hires_fix_params(res)

    # Missing RNG means the default was set, which is GPU RNG
    if "RNG" not in res:
        res["RNG"] = "GPU"

    if "Schedule type" not in res:
        res["Schedule type"] = "Automatic"

    if "Schedule max sigma" not in res:
        res["Schedule max sigma"] = 0

    if "Schedule min sigma" not in res:
        res["Schedule min sigma"] = 0

    if "Schedule rho" not in res:
        res["Schedule rho"] = 0

    if "VAE Encoder" not in res:
        res["VAE Encoder"] = "Full"

    if "VAE Decoder" not in res:
        res["VAE Decoder"] = "Full"

    if "FP8 weight" not in res:
        res["FP8 weight"] = "Disable"

    if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
        res["Cache FP16 weight for LoRA"] = False

    # If either prompt has an attention weight other than 1.0 (BREAK entries
    # excluded) and no "Emphasis" key, assume "Original". This presumably matches
    # behaviour before the option was recorded in infotext — TODO confirm.
    prompt_attention = prompt_parser.parse_prompt_attention(prompt)
    prompt_attention += prompt_parser.parse_prompt_attention(negative_prompt)
    prompt_uses_emphasis = len(prompt_attention) != len([p for p in prompt_attention if p[1] == 1.0 or p[0] == 'BREAK'])
    if "Emphasis" not in res and prompt_uses_emphasis:
        res["Emphasis"] = "Original"

    if "Refiner switch by sampling steps" not in res:
        res["Refiner switch by sampling steps"] = False

    # Version-dependent adjustments, implemented in infotext_versions.
    infotext_versions.backcompat(res)

    for key in skip_fields:
        res.pop(key, None)

    return res
  285. infotext_to_setting_name_mapping = [
  286. ]
  287. """Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead.
  288. Example content:
  289. infotext_to_setting_name_mapping = [
  290. ('Conditional mask weight', 'inpainting_mask_weight'),
  291. ('Model hash', 'sd_model_checkpoint'),
  292. ('ENSD', 'eta_noise_seed_delta'),
  293. ('Schedule type', 'k_sched_type'),
  294. ]
  295. """
  296. def create_override_settings_dict(text_pairs):
  297. """creates processing's override_settings parameters from gradio's multiselect
  298. Example input:
  299. ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']
  300. Example output:
  301. {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
  302. """
  303. res = {}
  304. params = {}
  305. for pair in text_pairs:
  306. k, v = pair.split(":", maxsplit=1)
  307. params[k] = v.strip()
  308. mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
  309. for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
  310. value = params.get(param_name, None)
  311. if value is None:
  312. continue
  313. res[setting_name] = shared.opts.cast_value(setting_name, value)
  314. return res
  315. def get_override_settings(params, *, skip_fields=None):
  316. """Returns a list of settings overrides from the infotext parameters dictionary.
  317. This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns
  318. a list of tuples containing the parameter name, setting name, and new value cast to correct type.
  319. It checks for conditions before adding an override:
  320. - ignores settings that match the current value
  321. - ignores parameter keys present in skip_fields argument.
  322. Example input:
  323. {"Clip skip": "2"}
  324. Example output:
  325. [("Clip skip", "CLIP_stop_at_last_layers", 2)]
  326. """
  327. res = []
  328. mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
  329. for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
  330. if param_name in (skip_fields or {}):
  331. continue
  332. v = params.get(param_name, None)
  333. if v is None:
  334. continue
  335. if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
  336. continue
  337. v = shared.opts.cast_value(setting_name, v)
  338. current_value = getattr(shared.opts, setting_name, None)
  339. if v == current_value:
  340. continue
  341. res.append((param_name, setting_name, v))
  342. return res
  343. def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
  344. def paste_func(prompt):
  345. if not prompt and not shared.cmd_opts.hide_ui_dir_config and not shared.cmd_opts.no_prompt_history:
  346. filename = os.path.join(data_path, "params.txt")
  347. try:
  348. with open(filename, "r", encoding="utf8") as file:
  349. prompt = file.read()
  350. except OSError:
  351. pass
  352. params = parse_generation_parameters(prompt)
  353. script_callbacks.infotext_pasted_callback(prompt, params)
  354. res = []
  355. for output, key in paste_fields:
  356. if callable(key):
  357. try:
  358. v = key(params)
  359. except Exception:
  360. errors.report(f"Error executing {key}", exc_info=True)
  361. v = None
  362. else:
  363. v = params.get(key, None)
  364. if v is None:
  365. res.append(gr.update())
  366. elif isinstance(v, type_of_gr_update):
  367. res.append(v)
  368. else:
  369. try:
  370. valtype = type(output.value)
  371. if valtype == bool and v == "False":
  372. val = False
  373. elif valtype == int:
  374. val = float(v)
  375. else:
  376. val = valtype(v)
  377. res.append(gr.update(value=val))
  378. except Exception:
  379. res.append(gr.update())
  380. return res
  381. if override_settings_component is not None:
  382. already_handled_fields = {key: 1 for _, key in paste_fields}
  383. def paste_settings(params):
  384. vals = get_override_settings(params, skip_fields=already_handled_fields)
  385. vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]
  386. return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))
  387. paste_fields = paste_fields + [(override_settings_component, paste_settings)]
  388. button.click(
  389. fn=paste_func,
  390. inputs=[input_comp],
  391. outputs=[x[0] for x in paste_fields],
  392. show_progress=False,
  393. )
  394. button.click(
  395. fn=None,
  396. _js=f"recalculate_prompts_{tabname}",
  397. inputs=[],
  398. outputs=[],
  399. show_progress=False,
  400. )