import argparse
import datetime
import json
import os
import sys

import gradio as gr
import tqdm

import modules.artists
import modules.interrogate
import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
from modules import sd_samplers
from modules.paths import models_path, script_path, sd_path

sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file

parser = argparse.ArgumentParser()
  19. parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
  20. parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
  21. parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
  22. parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
  23. parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
  24. parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
  25. parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
  26. parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
  27. parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
  28. parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
  29. parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
  30. parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
  31. parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
  32. parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
  33. parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
  34. parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
  35. parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
  36. parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
  37. parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN'))
  38. parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
  39. parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
  40. parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
  41. parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
  42. parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
  43. parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
  44. parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
  45. parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
  46. parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
  47. parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[])
  48. parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
  49. parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
  50. parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
  51. parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
  52. parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
  53. parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
  54. parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
  55. parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
  56. parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor")
  57. parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
  58. parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
  59. parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
  60. parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
  61. parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
  62. parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
  63. cmd_opts = parser.parse_args()
devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
    (devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'])
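# The generator above picks a device per module: CPU for anything listed in
# --use-cpu, otherwise whatever devices.get_optimal_device() returns (typically
# CUDA when available). For illustration, a launch such as
#   python webui.py --use-cpu GFPGAN CodeFormer
# would keep GFPGAN and CodeFormer on the CPU while Stable Diffusion and the
# upscalers stay on the optimal device.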
device = devices.device

batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram

config_filename = cmd_opts.ui_settings_file
def reload_hypernetworks():
    from modules.hypernetwork import hypernetwork

    hypernetworks.clear()
    hypernetworks.update(hypernetwork.load_hypernetworks(cmd_opts.hypernetwork_dir))


hypernetworks = {}
hypernetwork = None
class State:
    interrupted = False
    job = ""
    job_no = 0
    job_count = 0
    job_timestamp = '0'
    sampling_step = 0
    sampling_steps = 0
    current_latent = None
    current_image = None
    current_image_sampling_step = 0
    textinfo = None

    def interrupt(self):
        self.interrupted = True

    def nextjob(self):
        self.job_no += 1
        self.sampling_step = 0
        self.current_image_sampling_step = 0

    def get_job_timestamp(self):
        return datetime.datetime.now().strftime("%Y%m%d%H%M%S")  # shouldn't this return job_timestamp?


state = State()
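# Usage sketch: the callers below are assumed to live elsewhere in the webui
# (the sampling loop and UI code), not in this file, and n_iter is hypothetical.
#   state.job_count = n_iter       # set before a run so progress can be reported
#   state.nextjob()                # advance to the next batch/job
#   if state.interrupted: ...      # long-running loops poll this flag to stop early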
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))

styles_filename = cmd_opts.styles_file
prompt_styles = modules.styles.StyleDatabase(styles_filename)

interrogator = modules.interrogate.InterrogateModels("interrogate")

face_restorers = []
# This was moved to webui.py with the other model "setup" calls.
# modules.sd_models.list_models()


def realesrgan_models_names():
    import modules.realesrgan_model
    return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
class OptionInfo:
    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None):
        self.default = default
        self.label = label
        self.component = component
        self.component_args = component_args
        self.onchange = onchange
        self.section = None


def options_section(section_identifer, options_dict):
    for k, v in options_dict.items():
        v.section = section_identifer

    return options_dict
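# Sketch of how a setting is declared ("my_new_setting" and the 'example' section
# are hypothetical, for illustration only; the real templates follow below):
#   options_templates.update(options_section(('example', "Example section"), {
#       "my_new_setting": OptionInfo(10, "A number shown as a slider", gr.Slider,
#                                    {"minimum": 0, "maximum": 100, "step": 1}),
#   }))
# options_section() only tags each OptionInfo with its (id, title) tuple so the UI
# can group settings; component_args may also be a lambda, which the UI code is
# assumed to evaluate when the component is built.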
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}

options_templates = {}
options_templates.update(options_section(('saving-images', "Saving images/grids"), {
    "samples_save": OptionInfo(True, "Always save all generated images"),
    "samples_format": OptionInfo('png', 'File format for images'),
    "samples_filename_pattern": OptionInfo("", "Images filename pattern"),
    "grid_save": OptionInfo(True, "Always save all generated image grids"),
    "grid_format": OptionInfo('png', 'File format for grids'),
    "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
    "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
    "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
    "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
    "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
    "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
    "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
    "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
    "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
    "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
}))
options_templates.update(options_section(('saving-paths', "Paths for saving"), {
    "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
    "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
    "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
    "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
    "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
    "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
    "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
    "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
}))
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
    "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
    "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"),
    "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"),
    "directories_filename_pattern": OptionInfo("", "Directory name pattern"),
    "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
}))
options_templates.update(options_section(('upscaling', "Upscaling"), {
    "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
    "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
    "realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
    "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
    "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
    "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
    "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
}))
options_templates.update(options_section(('face-restoration', "Face restoration"), {
    "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
    "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
    "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
}))
options_templates.update(options_section(('system', "System"), {
    "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
    "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
    "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
    "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
    "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
    "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
    "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
    "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
    "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
    "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
    "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
}))
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
    "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
    "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
    "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
    "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
    "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
    "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
}))
options_templates.update(options_section(('ui', "User interface"), {
    "show_progressbar": OptionInfo(True, "Show progressbar"),
    "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
    "return_grid": OptionInfo(True, "Show grid in results for web"),
    "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
    "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
    "font": OptionInfo("", "Font for image grids that have text"),
    "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
    "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
    "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
}))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
    "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}),
    "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
    "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
    "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
    's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
    's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
    's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
}))
class Options:
    data = None
    data_labels = options_templates
    typemap = {int: float}

    def __init__(self):
        self.data = {k: v.default for k, v in self.data_labels.items()}

    def __setattr__(self, key, value):
        # route assignments to known option keys into the data dict; anything else
        # behaves like a normal attribute
        if self.data is not None:
            if key in self.data:
                self.data[key] = value

        return super(Options, self).__setattr__(key, value)

    def __getattr__(self, item):
        # look up saved values first, then fall back to the template default
        if self.data is not None:
            if item in self.data:
                return self.data[item]

        if item in self.data_labels:
            return self.data_labels[item].default

        return super(Options, self).__getattribute__(item)

    def save(self, filename):
        with open(filename, "w", encoding="utf8") as file:
            json.dump(self.data, file)

    def same_type(self, x, y):
        if x is None or y is None:
            return True

        # typemap treats int and float as interchangeable when validating settings
        type_x = self.typemap.get(type(x), type(x))
        type_y = self.typemap.get(type(y), type(y))

        return type_x == type_y

    def load(self, filename):
        with open(filename, "r", encoding="utf8") as file:
            self.data = json.load(file)

        bad_settings = 0
        for k, v in self.data.items():
            info = self.data_labels.get(k, None)
            if info is not None and not self.same_type(info.default, v):
                print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
                bad_settings += 1

        if bad_settings > 0:
            print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)

    def onchange(self, key, func):
        item = self.data_labels.get(key)
        item.onchange = func

    def dumpjson(self):
        d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
        return json.dumps(d)
opts = Options()
if os.path.exists(config_filename):
    opts.load(config_filename)
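# Usage sketch for the global opts object (callers are assumed to live in the UI
# and processing code elsewhere in the webui; some_callback is hypothetical):
#   opts.samples_format                 # attribute reads come from opts.data, falling
#                                       # back to the OptionInfo default ('png')
#   opts.jpeg_quality = 95              # writes to known keys go into opts.data
#   opts.save(config_filename)          # persist settings as JSON
#   opts.onchange("sd_model_checkpoint", some_callback)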
sd_upscalers = []

sd_model = None

progress_print_out = sys.stdout
class TotalTQDM:
    def __init__(self):
        self._tqdm = None

    def reset(self):
        self._tqdm = tqdm.tqdm(
            desc="Total progress",
            total=state.job_count * state.sampling_steps,
            position=1,
            file=progress_print_out
        )

    def update(self):
        if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
            return
        if self._tqdm is None:
            self.reset()
        self._tqdm.update()

    def updateTotal(self, new_total):
        if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
            return
        if self._tqdm is None:
            self.reset()
        self._tqdm.total = new_total

    def clear(self):
        if self._tqdm is not None:
            self._tqdm.close()
            self._tqdm = None
total_tqdm = TotalTQDM()
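# Usage sketch for the job-level progress bar (the calling code is assumed to be
# the image-processing loop elsewhere in the webui, not this file):
#   total_tqdm.updateTotal(state.job_count * state.sampling_steps)  # resize the bar
#   total_tqdm.update()   # advance one step; skipped when multiple_tqdm is off or
#                         # console progress bars are disabled
#   total_tqdm.clear()    # close the bar once the whole job finishes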
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start()