Explorar o código

repair unload sd checkpoint button

AUTOMATIC1111 hai 1 ano
pai
achega
282903bb67
Modificáronse 3 ficheiros con 23 adicións e 25 borrados
  1. 5 6
      modules/api/api.py
  2. 1 12
      modules/sd_models.py
  3. 17 7
      modules/ui_settings.py

+ 5 - 6
modules/api/api.py

@@ -17,15 +17,14 @@ from fastapi.encoders import jsonable_encoder
 from secrets import compare_digest
 
 import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
 from modules.api import models
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
 from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
-from PIL import PngImagePlugin,Image
-from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
+from PIL import PngImagePlugin, Image
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
@@ -541,12 +540,12 @@ class Api:
         return {}
 
     def unloadapi(self):
-        unload_model_weights()
+        sd_models.unload_model_weights()
 
         return {}
 
     def reloadapi(self):
-        reload_model_weights()
+        sd_models.send_model_to_device(shared.sd_model)
 
         return {}
 
@@ -566,7 +565,7 @@ class Api:
 
     def set_config(self, req: dict[str, Any]):
         checkpoint_name = req.get("sd_model_checkpoint", None)
-        if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
+        if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
             raise RuntimeError(f"model {checkpoint_name!r} not found")
 
         for k, v in req.items():

+ 1 - 12
modules/sd_models.py

@@ -1,7 +1,6 @@
 import collections
 import os.path
 import sys
-import gc
 import threading
 
 import torch
@@ -798,17 +797,7 @@ def reload_model_weights(sd_model=None, info=None):
 
 
 def unload_model_weights(sd_model=None, info=None):
-    timer = Timer()
-
-    if model_data.sd_model:
-        model_data.sd_model.to(devices.cpu)
-        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
-        model_data.sd_model = None
-        sd_model = None
-        gc.collect()
-        devices.torch_gc()
-
-    print(f"Unloaded weights {timer.summary()}.")
+    send_model_to_cpu(sd_model or shared.sd_model)
 
     return sd_model
 

+ 17 - 7
modules/ui_settings.py

@@ -1,6 +1,6 @@
 import gradio as gr
 
-from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
+from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
 from modules.call_queue import wrap_gradio_call
 from modules.shared import opts
 from modules.ui_components import FormRow
@@ -177,8 +177,8 @@ class UiSettings:
                     download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
                     reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
                     with gr.Row():
-                        unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
-                        reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
+                        unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id="sett_unload_sd_model")
+                        reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id="sett_reload_sd_model")
                     with gr.Row():
                         calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
                         calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
@@ -194,16 +194,26 @@ class UiSettings:
 
                 self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
 
+            def call_func_and_return_text(func, text):
+                def handler():
+                    t = timer.Timer()
+                    func()
+                    t.record(text)
+
+                    return f'{text} in {t.total:.1f}s'
+
+                return handler
+
             unload_sd_model.click(
-                fn=sd_models.unload_model_weights,
+                fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'),
                 inputs=[],
-                outputs=[]
+                outputs=[self.result]
             )
 
             reload_sd_model.click(
-                fn=sd_models.reload_model_weights,
+                fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'),
                 inputs=[],
-                outputs=[]
+                outputs=[self.result]
             )
 
             request_notifications.click(