api.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685
  1. import base64
  2. import io
  3. import time
  4. import datetime
  5. import uvicorn
  6. import gradio as gr
  7. from threading import Lock
  8. from io import BytesIO
  9. from fastapi import APIRouter, Depends, FastAPI, Request, Response
  10. from fastapi.security import HTTPBasic, HTTPBasicCredentials
  11. from fastapi.exceptions import HTTPException
  12. from fastapi.responses import JSONResponse
  13. from fastapi.encoders import jsonable_encoder
  14. from secrets import compare_digest
  15. import modules.shared as shared
  16. from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
  17. from modules.api.models import *
  18. from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
  19. from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
  20. from modules.textual_inversion.preprocess import preprocess
  21. from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
  22. from PIL import PngImagePlugin,Image
  23. from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
  24. from modules.sd_models_config import find_checkpoint_config_near_filename
  25. from modules.realesrgan_model import get_realesrgan_models
  26. from modules import devices
  27. from typing import List
  28. import piexif
  29. import piexif.helper
  def upscaler_to_index(name: str):
      """Return the position of the upscaler called *name* in ``shared.sd_upscalers``.

      The comparison is case-insensitive.

      Raises:
          HTTPException: 400 when no upscaler matches; the detail lists the valid names.
      """
      # Read the live list once so the lookup and the error message use the same data
      # (the bare ``sd_upscalers`` name used previously could be a stale import-time binding).
      upscalers = shared.sd_upscalers
      try:
          return [x.name.lower() for x in upscalers].index(name.lower())
      except (ValueError, AttributeError) as err:  # not found, or name is not a string
          raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in upscalers])}") from err
  def script_name_to_index(name, scripts):
      """Return the index in *scripts* of the script whose ``title()`` matches *name*.

      The comparison is case-insensitive.

      Raises:
          HTTPException: 422 when no script matches (or *name* is not a string).
      """
      try:
          return [script.title().lower() for script in scripts].index(name.lower())
      except (ValueError, AttributeError) as err:  # not found, or name is None/non-str
          raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from err
  def validate_sampler_name(name):
      """Return *name* unchanged when it names a registered sampler.

      Raises:
          HTTPException: 404 when ``sd_samplers.all_samplers_map`` has no entry for *name*.
      """
      if sd_samplers.all_samplers_map.get(name) is not None:
          return name
      raise HTTPException(status_code=404, detail="Sampler not found")
  def setUpscalers(req):
      """Build the keyword dict for ``postprocessing.run_extras`` from an extras request.

      All attributes of *req* are copied. ``upscaler_1``/``upscaler_2`` are renamed to
      ``extras_upscaler_1``/``extras_upscaler_2`` (``None`` when absent).

      *req* is any object with a ``__dict__`` (the extras request model). A shallow copy
      of ``vars(req)`` is used so popping keys does not mutate the request object itself.
      """
      reqDict = dict(vars(req))
      reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
      reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
      return reqDict
  def decode_base64_to_image(encoding):
      """Decode a base64 string into a PIL image.

      *encoding* may be a raw base64 payload or a ``data:image/<fmt>;base64,<payload>`` URL,
      in which case only ``<payload>`` is decoded.

      Raises:
          HTTPException: 500 "Invalid encoded image" when the string is malformed (including a
              truncated data URL) or PIL cannot open the decoded bytes.
      """
      try:
          if encoding.startswith("data:image/"):
              # "data:image/png;base64,<payload>" -> "<payload>"
              encoding = encoding.split(";")[1].split(",")[1]
          return Image.open(BytesIO(base64.b64decode(encoding)))
      except Exception as err:
          raise HTTPException(status_code=500, detail="Invalid encoded image") from err
  def encode_pil_to_base64(image):
      """Serialize *image* in ``opts.samples_format`` and return it base64-encoded (as bytes).

      * png: every str-key/str-value pair in ``image.info`` becomes a PNG text chunk
        (no pnginfo is passed when there are none).
      * jpg/jpeg/webp: ``image.info['parameters']`` (or "" when missing/empty) is stored
        in the EXIF UserComment tag.

      Raises:
          HTTPException: 500 for any other ``opts.samples_format``.
      """
      fmt = opts.samples_format.lower()
      with io.BytesIO() as buffer:
          if fmt == 'png':
              text_chunks = [
                  (key, value)
                  for key, value in image.info.items()
                  if isinstance(key, str) and isinstance(value, str)
              ]
              pnginfo = PngImagePlugin.PngInfo()
              for key, value in text_chunks:
                  pnginfo.add_text(key, value)
              image.save(buffer, format="PNG", pnginfo=pnginfo if text_chunks else None, quality=opts.jpeg_quality)
          elif fmt in ("jpg", "jpeg", "webp"):
              comment = piexif.helper.UserComment.dump(image.info.get('parameters', None) or "", encoding="unicode")
              exif_bytes = piexif.dump({"Exif": {piexif.ExifIFD.UserComment: comment}})
              pil_format = "JPEG" if fmt in ("jpg", "jpeg") else "WEBP"
              image.save(buffer, format=pil_format, exif=exif_bytes, quality=opts.jpeg_quality)
          else:
              raise HTTPException(status_code=500, detail="Invalid image format")
          return base64.b64encode(buffer.getvalue())
  def api_middleware(app: FastAPI):
      """Install request timing/logging middleware and uniform JSON error handling on *app*.

      Every response gets an ``X-Process-Time`` header (seconds, rounded to 4 places). With
      ``--api-log``, each ``/sdapi`` request is printed as a single line. Any exception that
      escapes a route becomes a JSON body ``{error, detail, body, errors}`` whose status is the
      exception's ``status_code`` attribute (500 when absent). Backtraces are printed for
      non-HTTPException errors, using ``rich`` when it is importable.
      """
      import traceback  # stdlib fallback printer; imported unconditionally so it is always bound

      rich_available = True
      try:
          import anyio  # importing just so it can be placed on silent list
          import starlette  # importing just so it can be placed on silent list
          from rich.console import Console
          console = Console()
      except Exception:  # best effort: any failure here only disables rich output
          rich_available = False

      @app.middleware("http")
      async def log_and_time(req: Request, call_next):
          ts = time.time()
          res: Response = await call_next(req)
          duration = str(round(time.time() - ts, 4))
          res.headers["X-Process-Time"] = duration
          endpoint = req.scope.get('path', 'err')
          if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
              # ASGI allows scope["client"] to be present but None, so fall back on falsy too
              client = req.scope.get('client') or ('0:0.0.0', 0)
              print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
                  t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                  code=res.status_code,
                  ver=req.scope.get('http_version', '0.0'),
                  cli=client[0],
                  prot=req.scope.get('scheme', 'err'),
                  method=req.scope.get('method', 'err'),
                  endpoint=endpoint,
                  duration=duration,
              ))
          return res

      def handle_exception(request: Request, e: Exception):
          err = {
              "error": type(e).__name__,
              "detail": vars(e).get('detail', ''),
              "body": vars(e).get('body', ''),
              "errors": str(e),
          }
          print(f"API error: {request.method}: {request.url} {err}")
          if not isinstance(e, HTTPException):  # do not print backtrace on known httpexceptions
              if rich_available:
                  console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
              else:
                  # print from the exception object itself rather than relying on sys.exc_info()
                  traceback.print_exception(type(e), e, e.__traceback__)
          return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))

      @app.middleware("http")
      async def exception_handling(request: Request, call_next):
          try:
              return await call_next(request)
          except Exception as e:
              return handle_exception(request, e)

      @app.exception_handler(Exception)
      async def fastapi_exception_handler(request: Request, e: Exception):
          return handle_exception(request, e)

      @app.exception_handler(HTTPException)
      async def http_exception_handler(request: Request, e: HTTPException):
          return handle_exception(request, e)
  class Api:
      """HTTP API of the web UI: registers every ``/sdapi/v1/*`` route on a FastAPI app.

      Handlers that use the model (generation, extras, interrogation) serialize on
      ``queue_lock``. When ``--api-auth`` is set, every route requires HTTP Basic auth.

      Route handlers are documented with ``#`` comments rather than docstrings because
      FastAPI publishes an endpoint's docstring as its OpenAPI description.
      """

      def __init__(self, app: FastAPI, queue_lock: Lock):
          """Parse ``--api-auth``, install middleware on *app* and register all routes.

          ``--api-auth`` is a comma-separated list of ``user:password`` pairs.
          """
          if shared.cmd_opts.api_auth:
              self.credentials = dict()
              for auth in shared.cmd_opts.api_auth.split(","):
                  # split on the first ':' only, so a password may itself contain ':'
                  user, password = auth.split(":", 1)
                  self.credentials[user] = password

          self.router = APIRouter()
          self.app = app
          self.queue_lock = queue_lock
          api_middleware(self.app)
          self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
          self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
          self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
          self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
          self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
          self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
          self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
          self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
          self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
          self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
          self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
          self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
          self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
          self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
          self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
          self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
          self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
          self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
          self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
          self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
          self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
          self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
          self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
          self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
          self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
          self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
          self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
          self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
          self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
          self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)

          # lazily filled with each runner's UI default values on the first generation request
          self.default_script_arg_txt2img = []
          self.default_script_arg_img2img = []

      def add_api_route(self, path: str, endpoint, **kwargs):
          """Register *endpoint* on the app, adding the Basic-auth dependency when ``--api-auth`` is set."""
          if shared.cmd_opts.api_auth:
              return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)
          return self.app.add_api_route(path, endpoint, **kwargs)

      def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
          """FastAPI dependency that checks HTTP Basic credentials against ``self.credentials``.

          Returns True when the username is known and the password matches.

          Raises:
              HTTPException: 401 with a ``WWW-Authenticate: Basic`` header otherwise.
          """
          expected = self.credentials.get(credentials.username)
          # compare UTF-8 bytes: compare_digest raises TypeError for non-ASCII str arguments
          if expected is not None and compare_digest(credentials.password.encode("utf-8"), expected.encode("utf-8")):
              return True
          raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})

      def get_selectable_script(self, script_name, script_runner):
          """Return ``(script, index)`` for *script_name* among the runner's selectable scripts.

          Returns ``(None, None)`` when *script_name* is None or empty. An unknown name raises
          HTTPException 422 via ``script_name_to_index``.
          """
          if script_name is None or script_name == "":
              return None, None

          script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
          script = script_runner.selectable_scripts[script_idx]
          return script, script_idx

      # GET /sdapi/v1/scripts: lower-cased titles of the txt2img and img2img scripts.
      def get_scripts_list(self):
          t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
          i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]

          return ScriptsList(txt2img=t2ilist, img2img=i2ilist)

      def get_script(self, script_name, script_runner):
          """Return the script in ``script_runner.scripts`` named *script_name*.

          Returns None when *script_name* is None or empty. An unknown name raises
          HTTPException 422 via ``script_name_to_index``.
          """
          if script_name is None or script_name == "":
              return None  # a single value, matching the non-empty path and the callers' None check

          script_idx = script_name_to_index(script_name, script_runner.scripts)
          return script_runner.scripts[script_idx]

      def init_default_script_args(self, script_runner):
          """Build the full positional script-args list filled with each script's UI defaults.

          Slot 0 holds the selectable-script selector (0 = none). Every script owns the slice
          ``[args_from:args_to]``. Slots not covered by any UI element stay None.
          """
          # the list must be long enough for the highest args_to of any script (at least 1)
          last_arg_index = max([1] + [script.args_to for script in script_runner.scripts])
          script_args = [None] * last_arg_index
          script_args[0] = 0

          with gr.Blocks():  # will throw errors calling ui function without this
              for script in script_runner.scripts:
                  ui_elements = script.ui(script.is_img2img)  # build the UI once per script
                  if ui_elements:
                      script_args[script.args_from:script.args_to] = [elem.value for elem in ui_elements]
          return script_args

      @staticmethod
      def _copy_script_args(script_args, script, values):
          """Write *values* into *script*'s slot range of *script_args*, truncated to the slot size.

          The list length never changes, so other scripts' slot offsets stay valid.
          """
          slot_size = script.args_to - script.args_from
          for offset, value in enumerate(list(values or [])[:slot_size]):
              script_args[script.args_from + offset] = value

      def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner):
          """Merge the request's script arguments into a copy of *default_script_args*.

          Raises:
              HTTPException: 422 for an unknown always-on script name, or when a selectable
                  script is listed under ``alwayson_scripts``.
          """
          script_args = default_script_args.copy()

          # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
          if selectable_scripts:
              self._copy_script_args(script_args, selectable_scripts, request.script_args)
              script_args[0] = selectable_idx + 1

          # Now check for always on scripts
          if request.alwayson_scripts:
              for alwayson_script_name, alwayson_request in request.alwayson_scripts.items():
                  alwayson_script = self.get_script(alwayson_script_name, script_runner)
                  if alwayson_script is None:
                      raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
                  # Selectable script in always on script param check
                  if not alwayson_script.alwayson:
                      raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
                  # always on script with no arg should always run so you don't really need to add them to the requests
                  if "args" in alwayson_request:
                      self._copy_script_args(script_args, alwayson_script, alwayson_request["args"])
          return script_args

      def _process(self, p, script_runner, selectable_scripts, script_args):
          """Run *p* through the selected script (if any) or ``process_images``.

          ``shared.state`` is begun here and always ended, even when processing raises.
          """
          shared.state.begin()
          try:
              if selectable_scripts is not None:
                  p.script_args = script_args
                  return script_runner.run(p, *p.script_args)  # Need to pass args as list here
              p.script_args = tuple(script_args)  # Need to pass args as tuple here
              return process_images(p)
          finally:
              shared.state.end()

      # POST /sdapi/v1/txt2img: generate images from a prompt.
      def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
          script_runner = scripts.scripts_txt2img
          if not script_runner.scripts:
              script_runner.initialize_scripts(False)
              ui.create_ui()
          if not self.default_script_arg_txt2img:
              self.default_script_arg_txt2img = self.init_default_script_args(script_runner)
          selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)

          populate = txt2imgreq.copy(update={  # Override __init__ params
              "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
              "do_not_save_samples": not txt2imgreq.save_images,
              "do_not_save_grid": not txt2imgreq.save_images,
          })
          if populate.sampler_name:
              populate.sampler_index = None  # prevent a warning later on

          args = vars(populate)
          args.pop('script_name', None)
          args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them
          args.pop('alwayson_scripts', None)

          script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)

          send_images = args.pop('send_images', True)
          args.pop('save_images', None)

          with self.queue_lock:
              p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
              p.scripts = script_runner
              p.outpath_grids = opts.outdir_txt2img_grids
              p.outpath_samples = opts.outdir_txt2img_samples
              processed = self._process(p, script_runner, selectable_scripts, script_args)

          b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

          return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

      # POST /sdapi/v1/img2img: generate images from init images (and an optional mask).
      def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
          init_images = img2imgreq.init_images
          if init_images is None:
              raise HTTPException(status_code=404, detail="Init image not found")

          mask = img2imgreq.mask
          if mask:
              mask = decode_base64_to_image(mask)

          script_runner = scripts.scripts_img2img
          if not script_runner.scripts:
              script_runner.initialize_scripts(True)
              ui.create_ui()
          if not self.default_script_arg_img2img:
              self.default_script_arg_img2img = self.init_default_script_args(script_runner)
          selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)

          populate = img2imgreq.copy(update={  # Override __init__ params
              "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
              "do_not_save_samples": not img2imgreq.save_images,
              "do_not_save_grid": not img2imgreq.save_images,
              "mask": mask,
          })
          if populate.sampler_name:
              populate.sampler_index = None  # prevent a warning later on

          args = vars(populate)
          args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
          args.pop('script_name', None)
          args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them
          args.pop('alwayson_scripts', None)

          script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)

          send_images = args.pop('send_images', True)
          args.pop('save_images', None)

          # decode before taking the queue lock so a bad image fails fast instead of after waiting
          decoded_init_images = [decode_base64_to_image(x) for x in init_images]

          with self.queue_lock:
              p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
              p.init_images = decoded_init_images
              p.scripts = script_runner
              p.outpath_grids = opts.outdir_img2img_grids
              p.outpath_samples = opts.outdir_img2img_samples
              processed = self._process(p, script_runner, selectable_scripts, script_args)

          b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

          if not img2imgreq.include_init_images:
              img2imgreq.init_images = None
              img2imgreq.mask = None

          return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())

      # POST /sdapi/v1/extra-single-image: run the extras (upscale/restore) pipeline on one image.
      def extras_single_image_api(self, req: ExtrasSingleImageRequest):
          reqDict = setUpscalers(req)

          reqDict['image'] = decode_base64_to_image(reqDict['image'])

          with self.queue_lock:
              result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)

          return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])

      # POST /sdapi/v1/extra-batch-images: run the extras pipeline on a list of images.
      def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
          reqDict = setUpscalers(req)

          image_list = reqDict.pop('imageList', [])
          image_folder = [decode_base64_to_image(x.data) for x in image_list]

          with self.queue_lock:
              result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)

          return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])

      # POST /sdapi/v1/png-info: read generation parameters and other metadata from an image.
      def pnginfoapi(self, req: PNGInfoRequest):
          if not req.image.strip():
              return PNGInfoResponse(info="")

          image = decode_base64_to_image(req.image.strip())
          if image is None:
              return PNGInfoResponse(info="")

          geninfo, items = images.read_info_from_image(image)
          if geninfo is None:
              geninfo = ""

          items = {'parameters': geninfo, **items}

          return PNGInfoResponse(info=geninfo, items=items)

      # GET /sdapi/v1/progress: progress/ETA of the current job (copy from check_progress_call of ui.py).
      def progressapi(self, req: ProgressRequest = Depends()):
          if shared.state.job_count == 0:
              return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)

          # start from a small non-zero value so the ETA division below never divides by zero
          progress = 0.01

          if shared.state.job_count > 0:
              progress += shared.state.job_no / shared.state.job_count
          if shared.state.sampling_steps > 0:
              progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

          time_since_start = time.time() - shared.state.time_start
          eta = time_since_start / progress
          eta_relative = eta - time_since_start

          progress = min(progress, 1)

          shared.state.set_current_image()

          current_image = None
          if shared.state.current_image and not req.skip_current_image:
              current_image = encode_pil_to_base64(shared.state.current_image)

          return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)

      # POST /sdapi/v1/interrogate: caption an image with CLIP or DeepDanbooru.
      def interrogateapi(self, interrogatereq: InterrogateRequest):
          image_b64 = interrogatereq.image
          if image_b64 is None:
              raise HTTPException(status_code=404, detail="Image not found")

          img = decode_base64_to_image(image_b64)
          img = img.convert('RGB')

          # Override object param
          with self.queue_lock:
              if interrogatereq.model == "clip":
                  processed = shared.interrogator.interrogate(img)
              elif interrogatereq.model == "deepdanbooru":
                  processed = deepbooru.model.tag(img)
              else:
                  raise HTTPException(status_code=404, detail="Model not found")

          return InterrogateResponse(caption=processed)

      # POST /sdapi/v1/interrupt
      def interruptapi(self):
          shared.state.interrupt()

          return {}

      # POST /sdapi/v1/unload-checkpoint
      def unloadapi(self):
          unload_model_weights()

          return {}

      # POST /sdapi/v1/reload-checkpoint
      def reloadapi(self):
          reload_model_weights()

          return {}

      # POST /sdapi/v1/skip
      def skip(self):
          shared.state.skip()

      # GET /sdapi/v1/options: a copy of the current option values.
      def get_config(self):
          # Every key comes from opts.data itself, so its stored value is always present; the
          # previous per-key fallback to data_labels[key].default could never be used.
          return dict(shared.opts.data)

      # POST /sdapi/v1/options: set each given option and persist the config file.
      def set_config(self, req: Dict[str, Any]):
          for k, v in req.items():
              shared.opts.set(k, v)

          shared.opts.save(shared.config_filename)
          return

      # GET /sdapi/v1/cmd-flags
      def get_cmd_flags(self):
          return vars(shared.cmd_opts)

      # GET /sdapi/v1/samplers
      def get_samplers(self):
          return [{"name": sampler[0], "aliases": sampler[2], "options": sampler[3]} for sampler in sd_samplers.all_samplers]

      # GET /sdapi/v1/upscalers
      def get_upscalers(self):
          return [
              {
                  "name": upscaler.name,
                  "model_name": upscaler.scaler.model_name,
                  "model_path": upscaler.data_path,
                  "model_url": None,
                  "scale": upscaler.scale,
              }
              for upscaler in shared.sd_upscalers
          ]

      # GET /sdapi/v1/sd-models
      def get_sd_models(self):
          return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()]

      # GET /sdapi/v1/hypernetworks
      def get_hypernetworks(self):
          return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]

      # GET /sdapi/v1/face-restorers
      def get_face_restorers(self):
          return [{"name": x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]

      # GET /sdapi/v1/realesrgan-models (the call below is the module-level helper, not this method)
      def get_realesrgan_models(self):
          return [{"name": x.name, "path": x.data_path, "scale": x.scale} for x in get_realesrgan_models(None)]

      # GET /sdapi/v1/prompt-styles
      def get_prompt_styles(self):
          return [
              {"name": style[0], "prompt": style[1], "negative_prompt": style[2]}
              for style in shared.prompt_styles.styles.values()
          ]

      # GET /sdapi/v1/embeddings: loaded and skipped textual-inversion embeddings.
      def get_embeddings(self):
          db = sd_hijack.model_hijack.embedding_db

          def convert_embedding(embedding):
              return {
                  "step": embedding.step,
                  "sd_checkpoint": embedding.sd_checkpoint,
                  "sd_checkpoint_name": embedding.sd_checkpoint_name,
                  "shape": embedding.shape,
                  "vectors": embedding.vectors,
              }

          def convert_embeddings(embeddings):
              return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}

          return {
              "loaded": convert_embeddings(db.word_embeddings),
              "skipped": convert_embeddings(db.skipped_embeddings),
          }

      # POST /sdapi/v1/refresh-checkpoints
      def refresh_checkpoints(self):
          shared.refresh_checkpoints()

      # POST /sdapi/v1/create/embedding: create an empty embedding and reload embeddings.
      # The module-level create_embedding is called below (method names are not in scope).
      def create_embedding(self, args: dict):
          shared.state.begin()
          try:
              filename = create_embedding(**args)  # create empty embedding
              sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()  # reload embeddings so new one can be immediately used
              return CreateResponse(info="create embedding filename: {filename}".format(filename=filename))
          except AssertionError as e:
              return CreateResponse(info="create embedding error: {error}".format(error=e))
          finally:
              shared.state.end()

      # POST /sdapi/v1/create/hypernetwork: create an empty hypernetwork.
      def create_hypernetwork(self, args: dict):
          shared.state.begin()
          try:
              filename = create_hypernetwork(**args)  # create empty hypernetwork
              return CreateResponse(info="create hypernetwork filename: {filename}".format(filename=filename))
          except AssertionError as e:
              return CreateResponse(info="create hypernetwork error: {error}".format(error=e))
          finally:
              shared.state.end()

      # POST /sdapi/v1/preprocess: run image preprocessing for training.
      def preprocess(self, args: dict):
          shared.state.begin()
          try:
              preprocess(**args)  # quick operation unless blip/booru interrogation is enabled
              return PreprocessResponse(info='preprocess complete')
          except KeyError as e:
              return PreprocessResponse(info="preprocess error: invalid token: {error}".format(error=e))
          except (AssertionError, FileNotFoundError) as e:
              return PreprocessResponse(info="preprocess error: {error}".format(error=e))
          finally:
              shared.state.end()

      # POST /sdapi/v1/train/embedding: train a textual-inversion embedding (long-running).
      # Training errors are reported in the response text rather than raised.
      def train_embedding(self, args: dict):
          shared.state.begin()
          try:
              apply_optimizations = shared.opts.training_xattention_optimizations
              error = None
              filename = ''
              if not apply_optimizations:
                  sd_hijack.undo_optimizations()
              try:
                  _, filename = train_embedding(**args)  # can take a long time to complete
              except Exception as e:
                  error = e
              finally:
                  if not apply_optimizations:
                      sd_hijack.apply_optimizations()
              return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
          except AssertionError as msg:
              return TrainResponse(info="train embedding error: {msg}".format(msg=msg))
          finally:
              shared.state.end()

      # POST /sdapi/v1/train/hypernetwork: train a hypernetwork (long-running).
      # Training errors are reported in the response text rather than raised.
      def train_hypernetwork(self, args: dict):
          shared.state.begin()
          try:
              shared.loaded_hypernetworks = []
              apply_optimizations = shared.opts.training_xattention_optimizations
              error = None
              filename = ''
              if not apply_optimizations:
                  sd_hijack.undo_optimizations()
              try:
                  _, filename = train_hypernetwork(**args)
              except Exception as e:
                  error = e
              finally:
                  # move these sub-models back to the compute device whether or not training succeeded
                  shared.sd_model.cond_stage_model.to(devices.device)
                  shared.sd_model.first_stage_model.to(devices.device)
                  if not apply_optimizations:
                      sd_hijack.apply_optimizations()
              return TrainResponse(info="train hypernetwork complete: filename: {filename} error: {error}".format(filename=filename, error=error))
          except AssertionError as msg:
              return TrainResponse(info="train hypernetwork error: {msg}".format(msg=msg))
          finally:
              shared.state.end()

      # GET /sdapi/v1/memory: process RAM and CUDA memory statistics; failures are reported per section.
      def get_memory(self):
          try:
              import os, psutil
              process = psutil.Process(os.getpid())
              res = process.memory_info()  # only rss is cross-platform guaranteed so we dont rely on other values
              ram_total = 100 * res.rss / process.memory_percent()  # and total memory is calculated as actual value is not cross-platform safe
              ram = {'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total}
          except Exception as err:
              ram = {'error': f'{err}'}
          try:
              import torch
              if torch.cuda.is_available():
                  s = torch.cuda.mem_get_info()
                  system = {'free': s[0], 'used': s[1] - s[0], 'total': s[1]}
                  s = dict(torch.cuda.memory_stats(shared.device))
                  allocated = {'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak']}
                  reserved = {'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak']}
                  active = {'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak']}
                  inactive = {'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak']}
                  warnings = {'retries': s['num_alloc_retries'], 'oom': s['num_ooms']}
                  cuda = {
                      'system': system,
                      'active': active,
                      'allocated': allocated,
                      'reserved': reserved,
                      'inactive': inactive,
                      'events': warnings,
                  }
              else:
                  cuda = {'error': 'unavailable'}
          except Exception as err:
              cuda = {'error': f'{err}'}
          return MemoryResponse(ram=ram, cuda=cuda)

      def launch(self, server_name, port):
          """Attach the router and serve the app with uvicorn (blocks until shutdown)."""
          self.app.include_router(self.router)
          uvicorn.run(self.app, host=server_name, port=port)