generation_parameters_copypaste.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. import base64
  2. import io
  3. import math
  4. import os
  5. import re
  6. from pathlib import Path
  7. import gradio as gr
  8. from modules.shared import script_path
  9. from modules import shared, ui_tempdir, script_callbacks
  10. import tempfile
  11. from PIL import Image
  12. re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
  13. re_param = re.compile(re_param_code)
  14. re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
  15. re_imagesize = re.compile(r"^(\d+)x(\d+)$")
  16. re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$")
  17. type_of_gr_update = type(gr.update())
  18. paste_fields = {}
  19. bind_list = []
  20. def reset():
  21. paste_fields.clear()
  22. bind_list.clear()
  23. def quote(text):
  24. if ',' not in str(text):
  25. return text
  26. text = str(text)
  27. text = text.replace('\\', '\\\\')
  28. text = text.replace('"', '\\"')
  29. return f'"{text}"'
  30. def image_from_url_text(filedata):
  31. if filedata is None:
  32. return None
  33. if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False):
  34. filedata = filedata[0]
  35. if type(filedata) == dict and filedata.get("is_file", False):
  36. filename = filedata["name"]
  37. is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
  38. assert is_in_right_dir, 'trying to open image file outside of allowed directories'
  39. return Image.open(filename)
  40. if type(filedata) == list:
  41. if len(filedata) == 0:
  42. return None
  43. filedata = filedata[0]
  44. if filedata.startswith("data:image/png;base64,"):
  45. filedata = filedata[len("data:image/png;base64,"):]
  46. filedata = base64.decodebytes(filedata.encode('utf-8'))
  47. image = Image.open(io.BytesIO(filedata))
  48. return image
  49. def add_paste_fields(tabname, init_img, fields):
  50. paste_fields[tabname] = {"init_img": init_img, "fields": fields}
  51. # backwards compatibility for existing extensions
  52. import modules.ui
  53. if tabname == 'txt2img':
  54. modules.ui.txt2img_paste_fields = fields
  55. elif tabname == 'img2img':
  56. modules.ui.img2img_paste_fields = fields
  57. def integrate_settings_paste_fields(component_dict):
  58. from modules import ui
  59. settings_map = {
  60. 'CLIP_stop_at_last_layers': 'Clip skip',
  61. 'inpainting_mask_weight': 'Conditional mask weight',
  62. 'sd_model_checkpoint': 'Model hash',
  63. 'eta_noise_seed_delta': 'ENSD',
  64. 'initial_noise_multiplier': 'Noise multiplier',
  65. }
  66. settings_paste_fields = [
  67. (component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
  68. for k, v in settings_map.items()
  69. ]
  70. for tabname, info in paste_fields.items():
  71. if info["fields"] is not None:
  72. info["fields"] += settings_paste_fields
  73. def create_buttons(tabs_list):
  74. buttons = {}
  75. for tab in tabs_list:
  76. buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab")
  77. return buttons
  78. #if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
  79. def bind_buttons(buttons, send_image, send_generate_info):
  80. bind_list.append([buttons, send_image, send_generate_info])
  81. def send_image_and_dimensions(x):
  82. if isinstance(x, Image.Image):
  83. img = x
  84. else:
  85. img = image_from_url_text(x)
  86. if shared.opts.send_size and isinstance(img, Image.Image):
  87. w = img.width
  88. h = img.height
  89. else:
  90. w = gr.update()
  91. h = gr.update()
  92. return img, w, h
  93. def run_bind():
  94. for buttons, source_image_component, send_generate_info in bind_list:
  95. for tab in buttons:
  96. button = buttons[tab]
  97. destination_image_component = paste_fields[tab]["init_img"]
  98. fields = paste_fields[tab]["fields"]
  99. destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None)
  100. destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None)
  101. if source_image_component and destination_image_component:
  102. if isinstance(source_image_component, gr.Gallery):
  103. func = send_image_and_dimensions if destination_width_component else image_from_url_text
  104. jsfunc = "extract_image_from_gallery"
  105. else:
  106. func = send_image_and_dimensions if destination_width_component else lambda x: x
  107. jsfunc = None
  108. button.click(
  109. fn=func,
  110. _js=jsfunc,
  111. inputs=[source_image_component],
  112. outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
  113. )
  114. if send_generate_info and fields is not None:
  115. if send_generate_info in paste_fields:
  116. paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
  117. button.click(
  118. fn=lambda *x: x,
  119. inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
  120. outputs=[field for field, name in fields if name in paste_field_names],
  121. )
  122. else:
  123. connect_paste(button, fields, send_generate_info)
  124. button.click(
  125. fn=None,
  126. _js=f"switch_to_{tab}",
  127. inputs=None,
  128. outputs=None,
  129. )
  130. def find_hypernetwork_key(hypernet_name, hypernet_hash=None):
  131. """Determines the config parameter name to use for the hypernet based on the parameters in the infotext.
  132. Example: an infotext provides "Hypernet: ke-ta" and "Hypernet hash: 1234abcd". For the "Hypernet" config
  133. parameter this means there should be an entry that looks like "ke-ta-10000(1234abcd)" to set it to.
  134. If the infotext has no hash, then a hypernet with the same name will be selected instead.
  135. """
  136. hypernet_name = hypernet_name.lower()
  137. if hypernet_hash is not None:
  138. # Try to match the hash in the name
  139. for hypernet_key in shared.hypernetworks.keys():
  140. result = re_hypernet_hash.search(hypernet_key)
  141. if result is not None and result[1] == hypernet_hash:
  142. return hypernet_key
  143. else:
  144. # Fall back to a hypernet with the same name
  145. for hypernet_key in shared.hypernetworks.keys():
  146. if hypernet_key.lower().startswith(hypernet_name):
  147. return hypernet_key
  148. return None
  149. def restore_old_hires_fix_params(res):
  150. """for infotexts that specify old First pass size parameter, convert it into
  151. width, height, and hr scale"""
  152. firstpass_width = res.get('First pass size-1', None)
  153. firstpass_height = res.get('First pass size-2', None)
  154. if shared.opts.use_old_hires_fix_width_height:
  155. hires_width = int(res.get("Hires resize-1", 0))
  156. hires_height = int(res.get("Hires resize-2", 0))
  157. if hires_width and hires_height:
  158. res['Size-1'] = hires_width
  159. res['Size-2'] = hires_height
  160. return
  161. if firstpass_width is None or firstpass_height is None:
  162. return
  163. firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
  164. width = int(res.get("Size-1", 512))
  165. height = int(res.get("Size-2", 512))
  166. if firstpass_width == 0 or firstpass_height == 0:
  167. from modules import processing
  168. firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
  169. res['Size-1'] = firstpass_width
  170. res['Size-2'] = firstpass_height
  171. res['Hires resize-1'] = width
  172. res['Hires resize-2'] = height
  173. def parse_generation_parameters(x: str):
  174. """parses generation parameters string, the one you see in text field under the picture in UI:
  175. ```
  176. girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
  177. Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
  178. Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
  179. ```
  180. returns a dict with field values
  181. """
  182. res = {}
  183. prompt = ""
  184. negative_prompt = ""
  185. done_with_prompt = False
  186. *lines, lastline = x.strip().split("\n")
  187. if not re_params.match(lastline):
  188. lines.append(lastline)
  189. lastline = ''
  190. for i, line in enumerate(lines):
  191. line = line.strip()
  192. if line.startswith("Negative prompt:"):
  193. done_with_prompt = True
  194. line = line[16:].strip()
  195. if done_with_prompt:
  196. negative_prompt += ("" if negative_prompt == "" else "\n") + line
  197. else:
  198. prompt += ("" if prompt == "" else "\n") + line
  199. res["Prompt"] = prompt
  200. res["Negative prompt"] = negative_prompt
  201. for k, v in re_param.findall(lastline):
  202. m = re_imagesize.match(v)
  203. if m is not None:
  204. res[k+"-1"] = m.group(1)
  205. res[k+"-2"] = m.group(2)
  206. else:
  207. res[k] = v
  208. # Missing CLIP skip means it was set to 1 (the default)
  209. if "Clip skip" not in res:
  210. res["Clip skip"] = "1"
  211. hypernet = res.get("Hypernet", None)
  212. if hypernet is not None:
  213. res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
  214. if "Hires resize-1" not in res:
  215. res["Hires resize-1"] = 0
  216. res["Hires resize-2"] = 0
  217. restore_old_hires_fix_params(res)
  218. return res
  219. def connect_paste(button, paste_fields, input_comp, jsfunc=None):
  220. def paste_func(prompt):
  221. if not prompt and not shared.cmd_opts.hide_ui_dir_config:
  222. filename = os.path.join(script_path, "params.txt")
  223. if os.path.exists(filename):
  224. with open(filename, "r", encoding="utf8") as file:
  225. prompt = file.read()
  226. params = parse_generation_parameters(prompt)
  227. script_callbacks.infotext_pasted_callback(prompt, params)
  228. res = []
  229. for output, key in paste_fields:
  230. if callable(key):
  231. v = key(params)
  232. else:
  233. v = params.get(key, None)
  234. if v is None:
  235. res.append(gr.update())
  236. elif isinstance(v, type_of_gr_update):
  237. res.append(v)
  238. else:
  239. try:
  240. valtype = type(output.value)
  241. if valtype == bool and v == "False":
  242. val = False
  243. else:
  244. val = valtype(v)
  245. res.append(gr.update(value=val))
  246. except Exception:
  247. res.append(gr.update())
  248. return res
  249. button.click(
  250. fn=paste_func,
  251. _js=jsfunc,
  252. inputs=[input_comp],
  253. outputs=[x[0] for x in paste_fields],
  254. )