# ui_checkpoint_merger.py (7.7 KB)

import gradio as gr

from modules import sd_models, sd_vae, errors, extras, call_queue
from modules.ui_components import FormRow
from modules.ui_common import create_refresh_button


  5. def update_interp_description(value):
  6. interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
  7. interp_descriptions = {
  8. "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
  9. "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
  10. "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
  11. }
  12. return interp_descriptions[value]
  13. def modelmerger(*args):
  14. try:
  15. results = extras.run_modelmerger(*args)
  16. except Exception as e:
  17. errors.report("Error loading/saving model file", exc_info=True)
  18. sd_models.list_models() # to remove the potentially missing models from the list
  19. return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
  20. return results
  21. class UiCheckpointMerger:
  22. def __init__(self):
  23. with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
  24. with gr.Row(equal_height=False):
  25. with gr.Column(variant='compact'):
  26. self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
  27. with FormRow(elem_id="modelmerger_models"):
  28. self.primary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
  29. create_refresh_button(self.primary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
  30. self.secondary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
  31. create_refresh_button(self.secondary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
  32. self.tertiary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
  33. create_refresh_button(self.tertiary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
  34. self.custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
  35. self.interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
  36. self.interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
  37. self.interp_method.change(fn=update_interp_description, inputs=[self.interp_method], outputs=[self.interp_description])
  38. with FormRow():
  39. self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
  40. self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
  41. with FormRow():
  42. with gr.Column():
  43. self.config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
  44. with gr.Column():
  45. with FormRow():
  46. self.bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
  47. create_refresh_button(self.bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
  48. with FormRow():
  49. self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
  50. with gr.Accordion("Metadata", open=False) as metadata_editor:
  51. with FormRow():
  52. self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata")
  53. self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe")
  54. self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata")
  55. self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format")
  56. self.read_metadata = gr.Button("Read metadata from selected checkpoints")
  57. with FormRow():
  58. self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
  59. with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
  60. with gr.Group(elem_id="modelmerger_results_panel"):
  61. self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
  62. self.metadata_editor = metadata_editor
  63. self.blocks = modelmerger_interface
  64. def setup_ui(self, dummy_component, sd_model_checkpoint_component):
  65. self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False)
  66. self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json])
  67. self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])
  68. self.modelmerger_merge.click(
  69. fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
  70. _js='modelmerger',
  71. inputs=[
  72. dummy_component,
  73. self.primary_model_name,
  74. self.secondary_model_name,
  75. self.tertiary_model_name,
  76. self.interp_method,
  77. self.interp_amount,
  78. self.save_as_half,
  79. self.custom_name,
  80. self.checkpoint_format,
  81. self.config_source,
  82. self.bake_in_vae,
  83. self.discard_weights,
  84. self.save_metadata,
  85. self.add_merge_recipe,
  86. self.copy_metadata_fields,
  87. self.metadata_json,
  88. ],
  89. outputs=[
  90. self.primary_model_name,
  91. self.secondary_model_name,
  92. self.tertiary_model_name,
  93. sd_model_checkpoint_component,
  94. self.modelmerger_result,
  95. ]
  96. )
  97. # Required as a workaround for change() event not triggering when loading values from ui-config.json
  98. self.interp_description.value = update_interp_description(self.interp_method.value)