|
@@ -6,7 +6,6 @@ import uvicorn
|
|
|
import gradio as gr
|
|
|
from threading import Lock
|
|
|
from io import BytesIO
|
|
|
-from gradio.processing_utils import decode_base64_to_file
|
|
|
from fastapi import APIRouter, Depends, FastAPI, Request, Response
|
|
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
|
|
from fastapi.exceptions import HTTPException
|
|
@@ -16,7 +15,8 @@ from secrets import compare_digest
|
|
|
|
|
|
import modules.shared as shared
|
|
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
|
|
-from modules.api.models import *
|
|
|
+from modules.api import models
|
|
|
+from modules.shared import opts
|
|
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
|
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
|
|
from modules.textual_inversion.preprocess import preprocess
|
|
@@ -26,21 +26,24 @@ from modules.sd_models import checkpoints_list, unload_model_weights, reload_mod
|
|
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
|
|
from modules.realesrgan_model import get_realesrgan_models
|
|
|
from modules import devices
|
|
|
-from typing import List
|
|
|
+from typing import Dict, List, Any
|
|
|
import piexif
|
|
|
import piexif.helper
|
|
|
|
|
|
+
|
|
|
def upscaler_to_index(name: str):
|
|
|
try:
|
|
|
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
|
|
- except:
|
|
|
- raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
|
|
|
+ except Exception as e:
|
|
|
+ raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
|
|
|
+
|
|
|
|
|
|
def script_name_to_index(name, scripts):
|
|
|
try:
|
|
|
return [script.title().lower() for script in scripts].index(name.lower())
|
|
|
- except:
|
|
|
- raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
|
|
|
+ except Exception as e:
|
|
|
+ raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
|
|
|
+
|
|
|
|
|
|
def validate_sampler_name(name):
|
|
|
config = sd_samplers.all_samplers_map.get(name, None)
|
|
@@ -49,20 +52,23 @@ def validate_sampler_name(name):
|
|
|
|
|
|
return name
|
|
|
|
|
|
+
|
|
|
def setUpscalers(req: dict):
|
|
|
reqDict = vars(req)
|
|
|
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
|
|
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
|
|
return reqDict
|
|
|
|
|
|
+
|
|
|
def decode_base64_to_image(encoding):
|
|
|
if encoding.startswith("data:image/"):
|
|
|
encoding = encoding.split(";")[1].split(",")[1]
|
|
|
try:
|
|
|
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
|
|
return image
|
|
|
- except Exception as err:
|
|
|
- raise HTTPException(status_code=500, detail="Invalid encoded image")
|
|
|
+ except Exception as e:
|
|
|
+ raise HTTPException(status_code=500, detail="Invalid encoded image") from e
|
|
|
+
|
|
|
|
|
|
def encode_pil_to_base64(image):
|
|
|
with io.BytesIO() as output_bytes:
|
|
@@ -93,6 +99,7 @@ def encode_pil_to_base64(image):
|
|
|
|
|
|
return base64.b64encode(bytes_data)
|
|
|
|
|
|
+
|
|
|
def api_middleware(app: FastAPI):
|
|
|
rich_available = True
|
|
|
try:
|
|
@@ -100,7 +107,7 @@ def api_middleware(app: FastAPI):
|
|
|
import starlette # importing just so it can be placed on silent list
|
|
|
from rich.console import Console
|
|
|
console = Console()
|
|
|
- except:
|
|
|
+ except Exception:
|
|
|
import traceback
|
|
|
rich_available = False
|
|
|
|
|
@@ -131,8 +138,8 @@ def api_middleware(app: FastAPI):
|
|
|
"body": vars(e).get('body', ''),
|
|
|
"errors": str(e),
|
|
|
}
|
|
|
- print(f"API error: {request.method}: {request.url} {err}")
|
|
|
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
|
|
+ print(f"API error: {request.method}: {request.url} {err}")
|
|
|
if rich_available:
|
|
|
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
|
|
else:
|
|
@@ -158,7 +165,7 @@ def api_middleware(app: FastAPI):
|
|
|
class Api:
|
|
|
def __init__(self, app: FastAPI, queue_lock: Lock):
|
|
|
if shared.cmd_opts.api_auth:
|
|
|
- self.credentials = dict()
|
|
|
+ self.credentials = {}
|
|
|
for auth in shared.cmd_opts.api_auth.split(","):
|
|
|
user, password = auth.split(":")
|
|
|
self.credentials[user] = password
|
|
@@ -167,36 +174,37 @@ class Api:
|
|
|
self.app = app
|
|
|
self.queue_lock = queue_lock
|
|
|
api_middleware(self.app)
|
|
|
- self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
|
|
- self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
|
|
- self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
|
|
- self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
|
|
- self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
|
|
- self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
|
|
+ self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
|
|
|
+ self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
|
|
|
+ self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
|
|
|
+ self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
|
|
|
+ self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
|
|
|
+ self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
|
|
|
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
|
|
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
|
|
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
|
|
- self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
|
|
|
+ self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
|
|
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
|
|
- self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
|
|
|
- self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
|
|
|
- self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
|
|
|
- self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
|
|
|
- self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
|
|
|
- self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
|
|
- self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
|
|
- self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
|
|
- self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
|
|
|
+ self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
|
|
+ self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
|
|
+ self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
|
|
+ self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
|
|
+ self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
|
|
+ self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
|
|
+ self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
|
|
+ self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
|
|
+ self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
|
|
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
|
|
- self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
|
|
- self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
|
|
|
- self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
|
|
|
- self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
|
|
- self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
|
|
- self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
|
|
|
+ self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
|
|
+ self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
|
|
+ self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
|
|
+ self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
|
|
+ self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
|
|
+ self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
|
|
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
|
|
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
|
|
- self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
|
|
|
+ self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
|
|
+ self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
|
|
|
|
|
self.default_script_arg_txt2img = []
|
|
|
self.default_script_arg_img2img = []
|
|
@@ -220,17 +228,25 @@ class Api:
|
|
|
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
|
|
|
script = script_runner.selectable_scripts[script_idx]
|
|
|
return script, script_idx
|
|
|
-
|
|
|
+
|
|
|
def get_scripts_list(self):
|
|
|
- t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
|
|
|
- i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
|
|
|
+ t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
|
|
|
+ i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
|
|
|
|
|
|
- return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
|
|
|
+ return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
|
|
|
+
|
|
|
+ def get_script_info(self):
|
|
|
+ res = []
|
|
|
+
|
|
|
+ for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:
|
|
|
+ res += [script.api_info for script in script_list if script.api_info is not None]
|
|
|
+
|
|
|
+ return res
|
|
|
|
|
|
def get_script(self, script_name, script_runner):
|
|
|
if script_name is None or script_name == "":
|
|
|
return None, None
|
|
|
-
|
|
|
+
|
|
|
script_idx = script_name_to_index(script_name, script_runner.scripts)
|
|
|
return script_runner.scripts[script_idx]
|
|
|
|
|
@@ -265,17 +281,19 @@ class Api:
|
|
|
if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
|
|
|
for alwayson_script_name in request.alwayson_scripts.keys():
|
|
|
alwayson_script = self.get_script(alwayson_script_name, script_runner)
|
|
|
- if alwayson_script == None:
|
|
|
+ if alwayson_script is None:
|
|
|
raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
|
|
|
# Selectable script in always on script param check
|
|
|
- if alwayson_script.alwayson == False:
|
|
|
- raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
|
|
|
+ if alwayson_script.alwayson is False:
|
|
|
+ raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
|
|
|
# always on script with no arg should always run so you don't really need to add them to the requests
|
|
|
if "args" in request.alwayson_scripts[alwayson_script_name]:
|
|
|
- script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
|
|
|
+ # min between arg length in scriptrunner and arg length in the request
|
|
|
+ for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
|
|
|
+ script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
|
|
return script_args
|
|
|
|
|
|
- def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
|
|
+ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
|
|
|
script_runner = scripts.scripts_txt2img
|
|
|
if not script_runner.scripts:
|
|
|
script_runner.initialize_scripts(False)
|
|
@@ -309,7 +327,7 @@ class Api:
|
|
|
p.outpath_samples = opts.outdir_txt2img_samples
|
|
|
|
|
|
shared.state.begin()
|
|
|
- if selectable_scripts != None:
|
|
|
+ if selectable_scripts is not None:
|
|
|
p.script_args = script_args
|
|
|
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
|
|
else:
|
|
@@ -319,9 +337,9 @@ class Api:
|
|
|
|
|
|
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
|
|
|
|
|
- return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
|
|
+ return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
|
|
|
|
|
- def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
|
|
+ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
|
|
|
init_images = img2imgreq.init_images
|
|
|
if init_images is None:
|
|
|
raise HTTPException(status_code=404, detail="Init image not found")
|
|
@@ -366,7 +384,7 @@ class Api:
|
|
|
p.outpath_samples = opts.outdir_img2img_samples
|
|
|
|
|
|
shared.state.begin()
|
|
|
- if selectable_scripts != None:
|
|
|
+ if selectable_scripts is not None:
|
|
|
p.script_args = script_args
|
|
|
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
|
|
else:
|
|
@@ -380,9 +398,9 @@ class Api:
|
|
|
img2imgreq.init_images = None
|
|
|
img2imgreq.mask = None
|
|
|
|
|
|
- return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
|
|
+ return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
|
|
|
|
|
- def extras_single_image_api(self, req: ExtrasSingleImageRequest):
|
|
|
+ def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
|
|
|
reqDict = setUpscalers(req)
|
|
|
|
|
|
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
|
@@ -390,31 +408,26 @@ class Api:
|
|
|
with self.queue_lock:
|
|
|
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
|
|
|
|
|
- return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
|
|
+ return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
|
|
|
|
|
- def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
|
|
|
+ def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
|
|
|
reqDict = setUpscalers(req)
|
|
|
|
|
|
- def prepareFiles(file):
|
|
|
- file = decode_base64_to_file(file.data, file_path=file.name)
|
|
|
- file.orig_name = file.name
|
|
|
- return file
|
|
|
-
|
|
|
- reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
|
|
|
- reqDict.pop('imageList')
|
|
|
+ image_list = reqDict.pop('imageList', [])
|
|
|
+ image_folder = [decode_base64_to_image(x.data) for x in image_list]
|
|
|
|
|
|
with self.queue_lock:
|
|
|
- result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
|
|
+ result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
|
|
|
|
|
- return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
|
|
+ return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
|
|
|
|
|
- def pnginfoapi(self, req: PNGInfoRequest):
|
|
|
+ def pnginfoapi(self, req: models.PNGInfoRequest):
|
|
|
if(not req.image.strip()):
|
|
|
- return PNGInfoResponse(info="")
|
|
|
+ return models.PNGInfoResponse(info="")
|
|
|
|
|
|
image = decode_base64_to_image(req.image.strip())
|
|
|
if image is None:
|
|
|
- return PNGInfoResponse(info="")
|
|
|
+ return models.PNGInfoResponse(info="")
|
|
|
|
|
|
geninfo, items = images.read_info_from_image(image)
|
|
|
if geninfo is None:
|
|
@@ -422,13 +435,13 @@ class Api:
|
|
|
|
|
|
items = {**{'parameters': geninfo}, **items}
|
|
|
|
|
|
- return PNGInfoResponse(info=geninfo, items=items)
|
|
|
+ return models.PNGInfoResponse(info=geninfo, items=items)
|
|
|
|
|
|
- def progressapi(self, req: ProgressRequest = Depends()):
|
|
|
+ def progressapi(self, req: models.ProgressRequest = Depends()):
|
|
|
# copy from check_progress_call of ui.py
|
|
|
|
|
|
if shared.state.job_count == 0:
|
|
|
- return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
|
|
+ return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
|
|
|
|
|
# avoid dividing zero
|
|
|
progress = 0.01
|
|
@@ -450,9 +463,9 @@ class Api:
|
|
|
if shared.state.current_image and not req.skip_current_image:
|
|
|
current_image = encode_pil_to_base64(shared.state.current_image)
|
|
|
|
|
|
- return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
|
|
+ return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
|
|
|
|
|
- def interrogateapi(self, interrogatereq: InterrogateRequest):
|
|
|
+ def interrogateapi(self, interrogatereq: models.InterrogateRequest):
|
|
|
image_b64 = interrogatereq.image
|
|
|
if image_b64 is None:
|
|
|
raise HTTPException(status_code=404, detail="Image not found")
|
|
@@ -469,7 +482,7 @@ class Api:
|
|
|
else:
|
|
|
raise HTTPException(status_code=404, detail="Model not found")
|
|
|
|
|
|
- return InterrogateResponse(caption=processed)
|
|
|
+ return models.InterrogateResponse(caption=processed)
|
|
|
|
|
|
def interruptapi(self):
|
|
|
shared.state.interrupt()
|
|
@@ -574,36 +587,36 @@ class Api:
|
|
|
filename = create_embedding(**args) # create empty embedding
|
|
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
|
|
shared.state.end()
|
|
|
- return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
|
|
|
+ return models.CreateResponse(info=f"create embedding filename: {filename}")
|
|
|
except AssertionError as e:
|
|
|
shared.state.end()
|
|
|
- return TrainResponse(info = "create embedding error: {error}".format(error = e))
|
|
|
+ return models.TrainResponse(info=f"create embedding error: {e}")
|
|
|
|
|
|
def create_hypernetwork(self, args: dict):
|
|
|
try:
|
|
|
shared.state.begin()
|
|
|
filename = create_hypernetwork(**args) # create empty embedding
|
|
|
shared.state.end()
|
|
|
- return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
|
|
|
+ return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
|
|
|
except AssertionError as e:
|
|
|
shared.state.end()
|
|
|
- return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
|
|
|
+ return models.TrainResponse(info=f"create hypernetwork error: {e}")
|
|
|
|
|
|
def preprocess(self, args: dict):
|
|
|
try:
|
|
|
shared.state.begin()
|
|
|
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
|
|
shared.state.end()
|
|
|
- return PreprocessResponse(info = 'preprocess complete')
|
|
|
+ return models.PreprocessResponse(info = 'preprocess complete')
|
|
|
except KeyError as e:
|
|
|
shared.state.end()
|
|
|
- return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
|
|
|
+ return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
|
|
except AssertionError as e:
|
|
|
shared.state.end()
|
|
|
- return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
|
|
|
+ return models.PreprocessResponse(info=f"preprocess error: {e}")
|
|
|
except FileNotFoundError as e:
|
|
|
shared.state.end()
|
|
|
- return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
|
|
|
+ return models.PreprocessResponse(info=f'preprocess error: {e}')
|
|
|
|
|
|
def train_embedding(self, args: dict):
|
|
|
try:
|
|
@@ -621,10 +634,10 @@ class Api:
|
|
|
if not apply_optimizations:
|
|
|
sd_hijack.apply_optimizations()
|
|
|
shared.state.end()
|
|
|
- return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
|
|
|
+ return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
|
|
except AssertionError as msg:
|
|
|
shared.state.end()
|
|
|
- return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
|
|
|
+ return models.TrainResponse(info=f"train embedding error: {msg}")
|
|
|
|
|
|
def train_hypernetwork(self, args: dict):
|
|
|
try:
|
|
@@ -645,14 +658,15 @@ class Api:
|
|
|
if not apply_optimizations:
|
|
|
sd_hijack.apply_optimizations()
|
|
|
shared.state.end()
|
|
|
- return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
|
|
|
- except AssertionError as msg:
|
|
|
+ return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
|
|
+ except AssertionError:
|
|
|
shared.state.end()
|
|
|
- return TrainResponse(info="train embedding error: {error}".format(error=error))
|
|
|
+ return models.TrainResponse(info=f"train embedding error: {error}")
|
|
|
|
|
|
def get_memory(self):
|
|
|
try:
|
|
|
- import os, psutil
|
|
|
+ import os
|
|
|
+ import psutil
|
|
|
process = psutil.Process(os.getpid())
|
|
|
res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
|
|
|
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
|
|
@@ -679,10 +693,10 @@ class Api:
|
|
|
'events': warnings,
|
|
|
}
|
|
|
else:
|
|
|
- cuda = { 'error': 'unavailable' }
|
|
|
+ cuda = {'error': 'unavailable'}
|
|
|
except Exception as err:
|
|
|
- cuda = { 'error': f'{err}' }
|
|
|
- return MemoryResponse(ram = ram, cuda = cuda)
|
|
|
+ cuda = {'error': f'{err}'}
|
|
|
+ return models.MemoryResponse(ram=ram, cuda=cuda)
|
|
|
|
|
|
def launch(self, server_name, port):
|
|
|
self.app.include_router(self.router)
|