|
@@ -17,15 +17,14 @@ from fastapi.encoders import jsonable_encoder
|
|
|
from secrets import compare_digest
|
|
|
|
|
|
import modules.shared as shared
|
|
|
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste
|
|
|
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
|
|
|
from modules.api import models
|
|
|
from modules.shared import opts
|
|
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
|
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
|
|
from modules.textual_inversion.preprocess import preprocess
|
|
|
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
|
|
|
-from PIL import PngImagePlugin,Image
|
|
|
-from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
|
|
|
+from PIL import PngImagePlugin, Image
|
|
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
|
|
from modules.realesrgan_model import get_realesrgan_models
|
|
|
from modules import devices
|
|
@@ -541,12 +540,12 @@ class Api:
|
|
|
return {}
|
|
|
|
|
|
def unloadapi(self):
|
|
|
- unload_model_weights()
|
|
|
+ sd_models.unload_model_weights()
|
|
|
|
|
|
return {}
|
|
|
|
|
|
def reloadapi(self):
|
|
|
- reload_model_weights()
|
|
|
+ sd_models.send_model_to_device(shared.sd_model)
|
|
|
|
|
|
return {}
|
|
|
|
|
@@ -566,7 +565,7 @@ class Api:
|
|
|
|
|
|
def set_config(self, req: dict[str, Any]):
|
|
|
checkpoint_name = req.get("sd_model_checkpoint", None)
|
|
|
- if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
|
|
|
+ if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
|
|
|
raise RuntimeError(f"model {checkpoint_name!r} not found")
|
|
|
|
|
|
for k, v in req.items():
|