api.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920
  1. import base64
  2. import io
  3. import os
  4. import time
  5. import datetime
  6. import uvicorn
  7. import ipaddress
  8. import requests
  9. import gradio as gr
  10. from threading import Lock
  11. from io import BytesIO
  12. from fastapi import APIRouter, Depends, FastAPI, Request, Response
  13. from fastapi.security import HTTPBasic, HTTPBasicCredentials
  14. from fastapi.exceptions import HTTPException
  15. from fastapi.responses import JSONResponse
  16. from fastapi.encoders import jsonable_encoder
  17. from secrets import compare_digest
  18. import modules.shared as shared
  19. from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers
  20. from modules.api import models
  21. from modules.shared import opts
  22. from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
  23. from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
  24. from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
  25. from PIL import PngImagePlugin
  26. from modules.sd_models_config import find_checkpoint_config_near_filename
  27. from modules.realesrgan_model import get_realesrgan_models
  28. from modules import devices
  29. from typing import Any
  30. import piexif
  31. import piexif.helper
  32. from contextlib import closing
  33. from modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task
  def script_name_to_index(name, scripts):
      """Return the position in `scripts` of the script whose title equals `name`, ignoring case.

      Raises HTTPException(422) when no title matches (or the lookup fails for any other reason).
      """
      try:
          lowered_titles = [entry.title().lower() for entry in scripts]
          position = lowered_titles.index(name.lower())
      except Exception as e:
          raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
      return position
  def validate_sampler_name(name):
      """Return `name` unchanged when it is a key of ``sd_samplers.all_samplers_map``.

      Raises HTTPException(404) for unknown (or None) sampler names.
      """
      if sd_samplers.all_samplers_map.get(name, None) is not None:
          return name
      raise HTTPException(status_code=404, detail="Sampler not found")
  def setUpscalers(req):
      """Build ``postprocessing.run_extras`` keyword arguments from an extras request.

      `req` is normally a request model (any object ``vars()`` accepts); a plain dict is
      accepted too. Returns a NEW dict holding the request's fields, with ``upscaler_1`` /
      ``upscaler_2`` renamed to ``extras_upscaler_1`` / ``extras_upscaler_2`` (None when
      absent). The request object itself is not modified.
      """
      # vars() returns the object's live __dict__; copy it so popping/adding keys here
      # does not strip attributes from (or add attributes to) the caller's request.
      req_dict = dict(req) if isinstance(req, dict) else dict(vars(req))
      req_dict['extras_upscaler_1'] = req_dict.pop('upscaler_1', None)
      req_dict['extras_upscaler_2'] = req_dict.pop('upscaler_2', None)
      return req_dict
  def verify_url(url):
      """Return True only if every address the URL's host resolves to is globally routable.

      Used to refuse fetching loopback, private, link-local and other non-global targets.
      Any failure (no host in the URL, unresolvable name, unparsable address) is treated
      as "not global" and yields False.

      NOTE(review): addresses are checked at call time only; the HTTP client resolves the
      name again when connecting, so DNS rebinding is not prevented by this check alone.
      """
      import socket
      from urllib.parse import urlparse

      try:
          # .hostname strips "user:pass@", ":port" and IPv6 brackets; the raw netloc
          # kept them, so e.g. "host:8080" never resolved and valid URLs were refused.
          hostname = urlparse(url).hostname
          if not hostname:
              return False
          # getaddrinfo reports IPv4 and IPv6 results; gethostbyname_ex is IPv4-only and
          # would miss a non-global AAAA record the client might still connect to.
          addresses = {info[4][0] for info in socket.getaddrinfo(hostname, None)}
          if not addresses:
              return False
          for address in addresses:
              if not ipaddress.ip_address(address).is_global:
                  return False
      except Exception:
          return False
      return True
  def _fetch_url_checked(url, max_redirects=30):
      """GET `url`, following redirects by hand so every hop can be vetted.

      With ``opts.api_forbid_local_requests`` enabled, each URL in the redirect chain (not
      only the first) must pass ``verify_url``; otherwise a public URL could redirect the
      server to an internal address. Raises HTTPException(500) when a hop is refused or
      the chain is longer than `max_redirects`.
      """
      from urllib.parse import urljoin

      headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
      # A Session keeps cookies set by intermediate redirect responses, as requests.get's
      # own redirect handling does.
      with requests.Session() as session:
          for _ in range(max_redirects + 1):
              if opts.api_forbid_local_requests and not verify_url(url):
                  raise HTTPException(status_code=500, detail="Request to local resource not allowed")
              response = session.get(url, timeout=30, headers=headers, allow_redirects=False)
              if not response.is_redirect:
                  return response
              # Location may be relative; resolve it against the URL that redirected.
              url = urljoin(url, response.headers['location'])
      raise HTTPException(status_code=500, detail="Invalid image url")


  def decode_base64_to_image(encoding):
      """Turn an image argument of the API into an image via ``images.read``.

      `encoding` may be:
        * an ``http://`` / ``https://`` URL - fetched only when ``opts.api_enable_requests``
          is set and, with ``opts.api_forbid_local_requests``, only if every redirect hop
          points at a global address;
        * a data URL such as ``data:image/png;base64,<payload>``;
        * a bare base64 string.

      Raises HTTPException(500) when fetching is disabled or refused, or when the bytes
      cannot be read as an image.
      """
      if encoding.startswith("http://") or encoding.startswith("https://"):
          if not opts.api_enable_requests:
              raise HTTPException(status_code=500, detail="Requests not allowed")
          response = _fetch_url_checked(encoding)
          try:
              return images.read(BytesIO(response.content))
          except Exception as e:
              raise HTTPException(status_code=500, detail="Invalid image url") from e

      if encoding.startswith("data:image/"):
          # The payload is everything after the first comma. Media-type parameters (e.g.
          # ";charset=...") may precede ";base64", so splitting on ";" is not reliable.
          _, separator, payload = encoding.partition(",")
          if not separator:
              raise HTTPException(status_code=500, detail="Invalid encoded image")
          encoding = payload

      try:
          return images.read(BytesIO(base64.b64decode(encoding)))
      except Exception as e:
          raise HTTPException(status_code=500, detail="Invalid encoded image") from e
  def encode_pil_to_base64(image):
      """Serialize `image` in ``opts.samples_format`` and return it base64-encoded (bytes).

      A `str` argument is assumed to be already encoded and is returned as-is.
      - PNG: the image's string-to-string ``info`` entries are written as text chunks.
      - JPG/JPEG/WEBP: RGBA is flattened to RGB and ``info['parameters']`` is stored in
        the EXIF UserComment tag.
      Any other samples format raises HTTPException(500).
      """
      if isinstance(image, str):
          return image

      fmt = opts.samples_format.lower()
      with io.BytesIO() as buffer:
          if fmt == 'png':
              text_chunks = [
                  (k, v) for k, v in image.info.items()
                  if isinstance(k, str) and isinstance(v, str)
              ]
              pnginfo = None
              if text_chunks:
                  pnginfo = PngImagePlugin.PngInfo()
                  for k, v in text_chunks:
                      pnginfo.add_text(k, v)
              image.save(buffer, format="PNG", pnginfo=pnginfo, quality=opts.jpeg_quality)
          elif fmt in ("jpg", "jpeg", "webp"):
              if image.mode == "RGBA":
                  image = image.convert("RGB")
              user_comment = piexif.helper.UserComment.dump(image.info.get('parameters', None) or "", encoding="unicode")
              exif_bytes = piexif.dump({"Exif": {piexif.ExifIFD.UserComment: user_comment}})
              pil_format = "WEBP" if fmt == "webp" else "JPEG"
              image.save(buffer, format=pil_format, exif=exif_bytes, quality=opts.jpeg_quality)
          else:
              raise HTTPException(status_code=500, detail="Invalid image format")

          encoded = base64.b64encode(buffer.getvalue())
      return encoded
  def api_middleware(app: FastAPI):
      """Install timing/logging middleware and JSON error handling on `app`.

      - Every HTTP response gets an ``X-Process-Time`` header (seconds, rounded to 4 places).
      - With ``--api-log``, each request under ``/sdapi`` is printed on one line.
      - Exceptions escaping a route (via middleware or FastAPI exception handlers) become a
        JSON response with ``error``, ``detail``, ``body`` and ``errors`` keys and the
        exception's ``status_code`` (500 if it has none). Non-HTTPException errors are also
        reported with a traceback: through rich when WEBUI_RICH_EXCEPTIONS is set and rich
        imports cleanly, otherwise through ``errors.report``.
      """
      rich_available = False
      try:
          if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:
              import anyio  # importing just so it can be placed on silent list
              import starlette  # importing just so it can be placed on silent list
              from rich.console import Console
              console = Console()
              rich_available = True
      except Exception:
          pass

      @app.middleware("http")
      async def log_and_time(req: Request, call_next):
          ts = time.time()
          res: Response = await call_next(req)
          duration = str(round(time.time() - ts, 4))
          res.headers["X-Process-Time"] = duration
          endpoint = req.scope.get('path', 'err')
          if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):
              # ASGI permits scope["client"] to be present but None (e.g. a server bound to a
              # unix socket); .get()'s default only covers a missing key, so also fall back
              # on None instead of crashing (and turning the response into a 500).
              client = req.scope.get('client') or ('0:0.0.0', 0)
              print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(
                  t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"),
                  code=res.status_code,
                  ver=req.scope.get('http_version', '0.0'),
                  cli=client[0],
                  prot=req.scope.get('scheme', 'err'),
                  method=req.scope.get('method', 'err'),
                  endpoint=endpoint,
                  duration=duration,
              ))
          return res

      def handle_exception(request: Request, e: Exception):
          err = {
              "error": type(e).__name__,
              "detail": vars(e).get('detail', ''),
              "body": vars(e).get('body', ''),
              "errors": str(e),
          }
          if not isinstance(e, HTTPException):  # do not print backtrace on known httpexceptions
              message = f"API error: {request.method}: {request.url} {err}"
              if rich_available:
                  print(message)
                  console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
              else:
                  errors.report(message, exc_info=True)
          return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))

      @app.middleware("http")
      async def exception_handling(request: Request, call_next):
          try:
              return await call_next(request)
          except Exception as e:
              return handle_exception(request, e)

      @app.exception_handler(Exception)
      async def fastapi_exception_handler(request: Request, e: Exception):
          return handle_exception(request, e)

      @app.exception_handler(HTTPException)
      async def http_exception_handler(request: Request, e: HTTPException):
          return handle_exception(request, e)
  168. class Api:
  def __init__(self, app: FastAPI, queue_lock: Lock):
      """Register the /sdapi/v1 routes on `app` and precompute default script arguments.

      ``--api-auth`` is a comma-separated list of ``user:password`` pairs; when set, every
      route registered through ``add_api_route`` requires HTTP Basic auth. `queue_lock`
      is the lock the generation/postprocessing handlers hold while they run.
      """
      if shared.cmd_opts.api_auth:
          self.credentials = {}
          for auth in shared.cmd_opts.api_auth.split(","):
              # Split on the first ":" only, so passwords may themselves contain ":".
              user, password = auth.split(":", 1)
              self.credentials[user] = password

      self.router = APIRouter()
      self.app = app
      self.queue_lock = queue_lock
      api_middleware(self.app)
      self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
      self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
      self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
      self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
      self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
      self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
      self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
      self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
      self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
      self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
      self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
      self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
      self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
      self.add_api_route("/sdapi/v1/schedulers", self.get_schedulers, methods=["GET"], response_model=list[models.SchedulerItem])
      self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
      self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
      self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
      self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
      self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
      self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
      self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
      self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
      self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
      self.add_api_route("/sdapi/v1/refresh-embeddings", self.refresh_embeddings, methods=["POST"])
      self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
      self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
      self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
      self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
      self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
      self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
      self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
      self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
      self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
      self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
      self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
      self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])

      if shared.cmd_opts.api_server_stop:
          self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
          self.add_api_route("/sdapi/v1/server-restart", self.restart_webui, methods=["POST"])
          self.add_api_route("/sdapi/v1/server-stop", self.stop_webui, methods=["POST"])

      self.default_script_arg_txt2img = []
      self.default_script_arg_img2img = []

      txt2img_script_runner = scripts.scripts_txt2img
      img2img_script_runner = scripts.scripts_img2img

      # The script runners are populated by building the UI; build it if either is empty.
      if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:
          ui.create_ui()

      if not txt2img_script_runner.scripts:
          txt2img_script_runner.initialize_scripts(False)
      if not self.default_script_arg_txt2img:
          self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)

      if not img2img_script_runner.scripts:
          img2img_script_runner.initialize_scripts(True)
      if not self.default_script_arg_img2img:
          self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)
  def add_api_route(self, path: str, endpoint, **kwargs):
      """Register `endpoint` at `path`, attaching the Basic-auth dependency when --api-auth is set."""
      auth_kwargs = {}
      if shared.cmd_opts.api_auth:
          auth_kwargs["dependencies"] = [Depends(self.auth)]
      return self.app.add_api_route(path, endpoint, **auth_kwargs, **kwargs)
  def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):
      """FastAPI dependency that checks HTTP Basic credentials against --api-auth.

      Returns True when the username is known and its password matches; otherwise
      raises HTTPException(401) carrying ``WWW-Authenticate: Basic``.
      """
      expected_password = self.credentials.get(credentials.username)
      if expected_password is not None:
          # compare_digest only accepts ASCII when given str objects and raises TypeError
          # otherwise (surfacing as a 500); comparing UTF-8 bytes keeps the constant-time
          # comparison and works for any password.
          if compare_digest(credentials.password.encode("utf-8"), expected_password.encode("utf-8")):
              return True

      raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
  def get_selectable_script(self, script_name, script_runner):
      """Look up a selectable script by title in ``script_runner.selectable_scripts``.

      Returns ``(script, index)``; ``(None, None)`` when `script_name` is None or empty.
      Unknown titles raise HTTPException(422) via ``script_name_to_index``.
      """
      if script_name in (None, ""):
          return None, None

      selectable = script_runner.selectable_scripts
      position = script_name_to_index(script_name, selectable)
      return selectable[position], position
  def get_scripts_list(self):
      """Return the names of every named txt2img and img2img script."""
      def named_scripts(runner):
          return [entry.name for entry in runner.scripts if entry.name is not None]

      return models.ScriptsList(
          txt2img=named_scripts(scripts.scripts_txt2img),
          img2img=named_scripts(scripts.scripts_img2img),
      )
  def get_script_info(self):
      """Collect the non-None ``api_info`` of all txt2img scripts, followed by all img2img scripts."""
      return [
          entry.api_info
          for runner in (scripts.scripts_txt2img, scripts.scripts_img2img)
          for entry in runner.scripts
          if entry.api_info is not None
      ]
  def get_script(self, script_name, script_runner):
      """Look up any script (always-on or selectable) in ``script_runner.scripts`` by title.

      Returns the script object, or None when `script_name` is None or empty.
      Unknown titles raise HTTPException(422) via ``script_name_to_index``.
      """
      if script_name is None or script_name == "":
          # A single None, not a (None, None) tuple: callers check the result with
          # "is None", which a tuple would slip past.
          return None
      script_idx = script_name_to_index(script_name, script_runner.scripts)
      return script_runner.scripts[script_idx]
  def init_default_script_args(self, script_runner):
      """Build the default flat script_args list for `script_runner`.

      The list is as long as the largest ``args_to`` among the runner's scripts (at least 1).
      Slot 0 (the selectable-script index) is 0. Each script's ``[args_from:args_to]`` slice
      is filled with the ``.value`` of the components its ``ui()`` returns; other slots stay None.
      """
      # find max idx from the scripts in runner and generate a none array to init script_args
      last_arg_index = max([1, *(script.args_to for script in script_runner.scripts)])

      # None everywhere except position 0 to initialize script args
      script_args = [None] * last_arg_index
      script_args[0] = 0

      # get default values
      with gr.Blocks():  # will throw errors calling ui function without this
          for script in script_runner.scripts:
              # Call ui() once: the previous check-then-iterate form called it twice and so
              # built every script's gradio components twice.
              components = script.ui(script.is_img2img)
              if components:
                  script_args[script.args_from:script.args_to] = [elem.value for elem in components]

      return script_args
  def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):
      """Assemble the flat script_args list for one txt2img/img2img request.

      Starts from a copy of `default_script_args` and applies, in order:
        1. `input_script_args` ({index: value}, e.g. values read from infotext);
        2. ``request.script_args`` into the selected script's slots, and slot 0 set to
           ``selectable_idx + 1`` so the runner knows which selectable script to run;
        3. ``request.alwayson_scripts[name]["args"]`` into each named always-on script's slots.
      Values never spill outside a script's own ``[args_from, args_to)`` range.

      Raises HTTPException(422) for an unknown always-on script or a selectable script
      listed under ``alwayson_scripts``.
      """
      script_args = default_script_args.copy()

      def fill_script_slots(script, values):
          # Write values into the script's own slots only. A list that is shorter or longer
          # than the script's argument count must not resize script_args, which would
          # shift the arguments of every script after it (slice assignment did that).
          width = script.args_to - script.args_from
          for offset, value in enumerate(list(values or [])[:width]):
              script_args[script.args_from + offset] = value

      if input_script_args is not None:
          for index, value in input_script_args.items():
              script_args[index] = value

      # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()
      if selectable_scripts:
          fill_script_slots(selectable_scripts, request.script_args)
          script_args[0] = selectable_idx + 1

      # Now check for always on scripts
      if request.alwayson_scripts:
          for alwayson_script_name in request.alwayson_scripts.keys():
              alwayson_script = self.get_script(alwayson_script_name, script_runner)
              if alwayson_script is None:
                  raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
              # Selectable script in always on script param check
              if alwayson_script.alwayson is False:
                  raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
              # always on script with no arg should always run so you don't really need to add them to the requests
              if "args" in request.alwayson_scripts[alwayson_script_name]:
                  fill_script_slots(alwayson_script, request.alwayson_scripts[alwayson_script_name]["args"])
      return script_args
  def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):
      """Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext.

      If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.

      Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.

      Returns the parsed infotext parameters, or {} when the request has no infotext.
      Settings overrides found in the infotext are added to ``request.override_settings``
      unless that setting is already present there.
      """
      if not request.infotext:
          return {}

      possible_fields = infotext_utils.paste_fields[tabname]["fields"]
      # pydantic v2 names this model_dump(); v1 only has dict(). Probe for the v2 method
      # itself - the old probe looked for a "request" attribute, which does not exist, so
      # the dict() branch was always taken.
      set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "model_dump") else request.dict(exclude_unset=True)
      params = infotext_utils.parse_generation_parameters(request.infotext)

      def get_field_value(field, params):
          value = field.function(params) if field.function else params.get(field.label)
          if value is None:
              return None

          # NOTE(review): `.type_` exists on pydantic v1 field objects; pydantic v2 field info
          # exposes `.annotation` instead - confirm before relying on pydantic v2 here.
          if field.api in request.__fields__:
              target_type = request.__fields__[field.api].type_
          else:
              target_type = type(field.component.value)

          if target_type == type(None):
              return None

          if isinstance(value, dict) and value.get('__type__') == 'generic_update':  # this is a gradio.update rather than a value
              value = value.get('value')

          if value is not None and not isinstance(value, target_type):
              value = target_type(value)

          return value

      for field in possible_fields:
          if not field.api:
              continue

          if field.api in set_fields:
              continue

          value = get_field_value(field, params)
          if value is not None:
              setattr(request, field.api, value)

      if request.override_settings is None:
          request.override_settings = {}

      overridden_settings = infotext_utils.get_override_settings(params)
      for _, setting_name, value in overridden_settings:
          if setting_name not in request.override_settings:
              request.override_settings[setting_name] = value

      if script_runner is not None and mentioned_script_args is not None:
          indexes = {v: i for i, v in enumerate(script_runner.inputs)}
          script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)

          for field, index in script_fields:
              value = get_field_value(field, params)

              if value is None:
                  continue

              mentioned_script_args[index] = value

      return params
  def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
      """Handle POST /sdapi/v1/txt2img.

      Fills unset request fields from ``infotext``, resolves the selectable and always-on
      scripts, runs generation while holding ``self.queue_lock``, and returns the images
      (base64-encoded, or none when ``send_images`` is false) together with the request
      parameters and the processing info JSON.
      """
      task_id = txt2imgreq.force_task_id or create_task_id("txt2img")

      script_runner = scripts.scripts_txt2img

      infotext_script_args = {}
      self.apply_infotext(txt2imgreq, "txt2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)

      selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)

      populate = txt2imgreq.copy(update={  # Override __init__ params
          "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
          "do_not_save_samples": not txt2imgreq.save_images,
          "do_not_save_grid": not txt2imgreq.save_images,
      })
      if populate.sampler_name:
          populate.sampler_index = None  # prevent a warning later on

      args = vars(populate)
      args.pop('script_name', None)
      args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them
      args.pop('alwayson_scripts', None)
      args.pop('infotext', None)

      script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)

      send_images = args.pop('send_images', True)
      args.pop('save_images', None)

      add_task_to_queue(task_id)

      with self.queue_lock:
          with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
              p.is_api = True
              p.scripts = script_runner
              p.outpath_grids = opts.outdir_txt2img_grids
              p.outpath_samples = opts.outdir_txt2img_samples

              try:
                  shared.state.begin(job="scripts_txt2img")
                  start_task(task_id)
                  if selectable_scripts is not None:
                      p.script_args = script_args
                      processed = scripts.scripts_txt2img.run(p, *p.script_args)  # Need to pass args as list here
                  else:
                      p.script_args = tuple(script_args)  # Need to pass args as tuple here
                      processed = process_images(p)
              finally:
                  # Finish the task even when generation raises; previously finish_task ran only
                  # on success, so a failed task was never marked finished (presumably leaving
                  # it reported as the current task - confirm in modules.progress).
                  finish_task(task_id)
                  shared.state.end()
                  shared.total_tqdm.clear()

      b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

      return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
  def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
      """Handle POST /sdapi/v1/img2img.

      Decodes the init images and optional mask, fills unset request fields from
      ``infotext``, resolves scripts, runs generation while holding ``self.queue_lock``
      and returns the images (base64, or none when ``send_images`` is false) with the
      request parameters and the processing info JSON. Unless ``include_init_images`` is
      set, the init images and mask are blanked out of the echoed parameters.

      Raises HTTPException(404) when ``init_images`` is missing.
      """
      task_id = img2imgreq.force_task_id or create_task_id("img2img")

      init_images = img2imgreq.init_images
      if init_images is None:
          raise HTTPException(status_code=404, detail="Init image not found")

      mask = img2imgreq.mask
      if mask:
          mask = decode_base64_to_image(mask)

      # Decode (possibly downloading, with a 30 s timeout per request) the init images
      # before taking the queue lock, like the mask above: doing it inside the lock would
      # hold up every other request waiting on that lock. Failing here also means no task
      # is queued for a request that cannot run.
      decoded_init_images = [decode_base64_to_image(x) for x in init_images]

      script_runner = scripts.scripts_img2img

      infotext_script_args = {}
      self.apply_infotext(img2imgreq, "img2img", script_runner=script_runner, mentioned_script_args=infotext_script_args)

      selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)

      populate = img2imgreq.copy(update={  # Override __init__ params
          "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
          "do_not_save_samples": not img2imgreq.save_images,
          "do_not_save_grid": not img2imgreq.save_images,
          "mask": mask,
      })
      if populate.sampler_name:
          populate.sampler_index = None  # prevent a warning later on

      args = vars(populate)
      args.pop('include_init_images', None)  # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
      args.pop('script_name', None)
      args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them
      args.pop('alwayson_scripts', None)
      args.pop('infotext', None)

      script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)

      send_images = args.pop('send_images', True)
      args.pop('save_images', None)

      add_task_to_queue(task_id)

      with self.queue_lock:
          with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
              p.init_images = decoded_init_images
              p.is_api = True
              p.scripts = script_runner
              p.outpath_grids = opts.outdir_img2img_grids
              p.outpath_samples = opts.outdir_img2img_samples

              try:
                  shared.state.begin(job="scripts_img2img")
                  start_task(task_id)
                  if selectable_scripts is not None:
                      p.script_args = script_args
                      processed = scripts.scripts_img2img.run(p, *p.script_args)  # Need to pass args as list here
                  else:
                      p.script_args = tuple(script_args)  # Need to pass args as tuple here
                      processed = process_images(p)
              finally:
                  # Finish the task even when generation raises; previously finish_task ran only
                  # on success, so a failed task was never marked finished (presumably leaving
                  # it reported as the current task - confirm in modules.progress).
                  finish_task(task_id)
                  shared.state.end()
                  shared.total_tqdm.clear()

      b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []

      if not img2imgreq.include_init_images:
          img2imgreq.init_images = None
          img2imgreq.mask = None

      return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
  def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
      """Run the extras (postprocessing) pipeline on one image and return it re-encoded with its HTML info."""
      run_kwargs = setUpscalers(req)
      run_kwargs['image'] = decode_base64_to_image(run_kwargs['image'])

      with self.queue_lock:
          result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **run_kwargs)

      output_images, html_info = result[0], result[1]
      return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(output_images[0]), html_info=html_info)
  def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
      """Run the extras (postprocessing) pipeline on a batch of images and return them re-encoded."""
      run_kwargs = setUpscalers(req)
      encoded_entries = run_kwargs.pop('imageList', [])
      decoded_images = [decode_base64_to_image(entry.data) for entry in encoded_entries]

      with self.queue_lock:
          result = postprocessing.run_extras(extras_mode=1, image_folder=decoded_images, image="", input_dir="", output_dir="", save_output=False, **run_kwargs)

      encoded_images = [encode_pil_to_base64(picture) for picture in result[0]]
      return models.ExtrasBatchImagesResponse(images=encoded_images, html_info=result[1])
  def pnginfoapi(self, req: models.PNGInfoRequest):
      """Handle POST /sdapi/v1/png-info: extract generation info embedded in an image.

      Returns the raw info string, the other items read from the image, and the parsed
      generation parameters (after ``infotext_pasted_callback`` has seen them). An empty
      image payload yields an empty info string.
      """
      image_b64 = req.image.strip()
      # Short-circuit empty input BEFORE decoding: an empty payload is not an image, so
      # decoding it would be expected to fail with an error instead of returning empty info.
      if not image_b64:
          return models.PNGInfoResponse(info="")

      image = decode_base64_to_image(image_b64)
      if image is None:
          return models.PNGInfoResponse(info="")

      geninfo, items = images.read_info_from_image(image)
      if geninfo is None:
          geninfo = ""

      params = infotext_utils.parse_generation_parameters(geninfo)
      script_callbacks.infotext_pasted_callback(geninfo, params)

      return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)
  def progressapi(self, req: models.ProgressRequest = Depends()):
      """Handle GET /sdapi/v1/progress: report progress, ETA, state and optionally the live preview.

      Progress is 0.01 plus the fraction of finished jobs plus the current job's
      sampling fraction, capped at 1 (the 0.01 floor keeps the ETA division away from 0).
      With no jobs, progress and ETA are reported as 0.
      """
      # Read current_task through the module on every call. A name bound with
      # "from modules.progress import current_task" is copied once at import time and does
      # not see later reassignments of that module global (presumably made by
      # start_task/finish_task), so it would always report the import-time value.
      from modules import progress as progress_module

      # copy from check_progress_call of ui.py

      if shared.state.job_count == 0:
          return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)

      # avoid dividing zero
      progress = 0.01

      if shared.state.job_count > 0:
          progress += shared.state.job_no / shared.state.job_count
      if shared.state.sampling_steps > 0:
          progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

      time_since_start = time.time() - shared.state.time_start
      eta = (time_since_start / progress)
      eta_relative = eta - time_since_start

      progress = min(progress, 1)

      shared.state.set_current_image()

      current_image = None
      if shared.state.current_image and not req.skip_current_image:
          current_image = encode_pil_to_base64(shared.state.current_image)

      return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=progress_module.current_task)
  def interrogateapi(self, interrogatereq: models.InterrogateRequest):
      """Caption an image with the "clip" interrogator or tag it with "deepdanbooru".

      Raises HTTPException(404) when no image is given or the model name is unknown.
      """
      if interrogatereq.image is None:
          raise HTTPException(status_code=404, detail="Image not found")

      rgb_image = decode_base64_to_image(interrogatereq.image).convert('RGB')

      # Override object param
      with self.queue_lock:
          requested_model = interrogatereq.model
          if requested_model == "clip":
              caption = shared.interrogator.interrogate(rgb_image)
          elif requested_model == "deepdanbooru":
              caption = deepbooru.model.tag(rgb_image)
          else:
              raise HTTPException(status_code=404, detail="Model not found")

      return models.InterrogateResponse(caption=caption)
  def interruptapi(self):
      """Request interruption of the current job; responds with an empty JSON object."""
      state = shared.state
      state.interrupt()
      return dict()
  510. def unloadapi(self):
  511. sd_models.unload_model_weights()
  512. return {}
  513. def reloadapi(self):
  514. sd_models.send_model_to_device(shared.sd_model)
  515. return {}
  516. def skip(self):
  517. shared.state.skip()
  518. def get_config(self):
  519. options = {}
  520. for key in shared.opts.data.keys():
  521. metadata = shared.opts.data_labels.get(key)
  522. if(metadata is not None):
  523. options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
  524. else:
  525. options.update({key: shared.opts.data.get(key, None)})
  526. return options
  527. def set_config(self, req: dict[str, Any]):
  528. checkpoint_name = req.get("sd_model_checkpoint", None)
  529. if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
  530. raise RuntimeError(f"model {checkpoint_name!r} not found")
  531. for k, v in req.items():
  532. shared.opts.set(k, v, is_api=True)
  533. shared.opts.save(shared.config_filename)
  534. return
  535. def get_cmd_flags(self):
  536. return vars(shared.cmd_opts)
  537. def get_samplers(self):
  538. return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
  539. def get_schedulers(self):
  540. return [
  541. {
  542. "name": scheduler.name,
  543. "label": scheduler.label,
  544. "aliases": scheduler.aliases,
  545. "default_rho": scheduler.default_rho,
  546. "need_inner_model": scheduler.need_inner_model,
  547. }
  548. for scheduler in sd_schedulers.schedulers]
  549. def get_upscalers(self):
  550. return [
  551. {
  552. "name": upscaler.name,
  553. "model_name": upscaler.scaler.model_name,
  554. "model_path": upscaler.data_path,
  555. "model_url": None,
  556. "scale": upscaler.scale,
  557. }
  558. for upscaler in shared.sd_upscalers
  559. ]
  560. def get_latent_upscale_modes(self):
  561. return [
  562. {
  563. "name": upscale_mode,
  564. }
  565. for upscale_mode in [*(shared.latent_upscale_modes or {})]
  566. ]
  567. def get_sd_models(self):
  568. import modules.sd_models as sd_models
  569. return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]
  570. def get_sd_vaes(self):
  571. import modules.sd_vae as sd_vae
  572. return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()]
  573. def get_hypernetworks(self):
  574. return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
  575. def get_face_restorers(self):
  576. return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
  577. def get_realesrgan_models(self):
  578. return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
  579. def get_prompt_styles(self):
  580. styleList = []
  581. for k in shared.prompt_styles.styles:
  582. style = shared.prompt_styles.styles[k]
  583. styleList.append({"name":style[0], "prompt": style[1], "negative_prompt": style[2]})
  584. return styleList
  585. def get_embeddings(self):
  586. db = sd_hijack.model_hijack.embedding_db
  587. def convert_embedding(embedding):
  588. return {
  589. "step": embedding.step,
  590. "sd_checkpoint": embedding.sd_checkpoint,
  591. "sd_checkpoint_name": embedding.sd_checkpoint_name,
  592. "shape": embedding.shape,
  593. "vectors": embedding.vectors,
  594. }
  595. def convert_embeddings(embeddings):
  596. return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}
  597. return {
  598. "loaded": convert_embeddings(db.word_embeddings),
  599. "skipped": convert_embeddings(db.skipped_embeddings),
  600. }
  601. def refresh_embeddings(self):
  602. with self.queue_lock:
  603. sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
  604. def refresh_checkpoints(self):
  605. with self.queue_lock:
  606. shared.refresh_checkpoints()
  607. def refresh_vae(self):
  608. with self.queue_lock:
  609. shared_items.refresh_vae_list()
  610. def create_embedding(self, args: dict):
  611. try:
  612. shared.state.begin(job="create_embedding")
  613. filename = create_embedding(**args) # create empty embedding
  614. sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
  615. return models.CreateResponse(info=f"create embedding filename: {filename}")
  616. except AssertionError as e:
  617. return models.TrainResponse(info=f"create embedding error: {e}")
  618. finally:
  619. shared.state.end()
  620. def create_hypernetwork(self, args: dict):
  621. try:
  622. shared.state.begin(job="create_hypernetwork")
  623. filename = create_hypernetwork(**args) # create empty embedding
  624. return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
  625. except AssertionError as e:
  626. return models.TrainResponse(info=f"create hypernetwork error: {e}")
  627. finally:
  628. shared.state.end()
  629. def train_embedding(self, args: dict):
  630. try:
  631. shared.state.begin(job="train_embedding")
  632. apply_optimizations = shared.opts.training_xattention_optimizations
  633. error = None
  634. filename = ''
  635. if not apply_optimizations:
  636. sd_hijack.undo_optimizations()
  637. try:
  638. embedding, filename = train_embedding(**args) # can take a long time to complete
  639. except Exception as e:
  640. error = e
  641. finally:
  642. if not apply_optimizations:
  643. sd_hijack.apply_optimizations()
  644. return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
  645. except Exception as msg:
  646. return models.TrainResponse(info=f"train embedding error: {msg}")
  647. finally:
  648. shared.state.end()
  649. def train_hypernetwork(self, args: dict):
  650. try:
  651. shared.state.begin(job="train_hypernetwork")
  652. shared.loaded_hypernetworks = []
  653. apply_optimizations = shared.opts.training_xattention_optimizations
  654. error = None
  655. filename = ''
  656. if not apply_optimizations:
  657. sd_hijack.undo_optimizations()
  658. try:
  659. hypernetwork, filename = train_hypernetwork(**args)
  660. except Exception as e:
  661. error = e
  662. finally:
  663. shared.sd_model.cond_stage_model.to(devices.device)
  664. shared.sd_model.first_stage_model.to(devices.device)
  665. if not apply_optimizations:
  666. sd_hijack.apply_optimizations()
  667. shared.state.end()
  668. return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
  669. except Exception as exc:
  670. return models.TrainResponse(info=f"train embedding error: {exc}")
  671. finally:
  672. shared.state.end()
  673. def get_memory(self):
  674. try:
  675. import os
  676. import psutil
  677. process = psutil.Process(os.getpid())
  678. res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
  679. ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
  680. ram = { 'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total }
  681. except Exception as err:
  682. ram = { 'error': f'{err}' }
  683. try:
  684. import torch
  685. if torch.cuda.is_available():
  686. s = torch.cuda.mem_get_info()
  687. system = { 'free': s[0], 'used': s[1] - s[0], 'total': s[1] }
  688. s = dict(torch.cuda.memory_stats(shared.device))
  689. allocated = { 'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak'] }
  690. reserved = { 'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak'] }
  691. active = { 'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak'] }
  692. inactive = { 'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak'] }
  693. warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] }
  694. cuda = {
  695. 'system': system,
  696. 'active': active,
  697. 'allocated': allocated,
  698. 'reserved': reserved,
  699. 'inactive': inactive,
  700. 'events': warnings,
  701. }
  702. else:
  703. cuda = {'error': 'unavailable'}
  704. except Exception as err:
  705. cuda = {'error': f'{err}'}
  706. return models.MemoryResponse(ram=ram, cuda=cuda)
  707. def get_extensions_list(self):
  708. from modules import extensions
  709. extensions.list_extensions()
  710. ext_list = []
  711. for ext in extensions.extensions:
  712. ext: extensions.Extension
  713. ext.read_info_from_repo()
  714. if ext.remote is not None:
  715. ext_list.append({
  716. "name": ext.name,
  717. "remote": ext.remote,
  718. "branch": ext.branch,
  719. "commit_hash":ext.commit_hash,
  720. "commit_date":ext.commit_date,
  721. "version":ext.version,
  722. "enabled":ext.enabled
  723. })
  724. return ext_list
  725. def launch(self, server_name, port, root_path):
  726. self.app.include_router(self.router)
  727. uvicorn.run(
  728. self.app,
  729. host=server_name,
  730. port=port,
  731. timeout_keep_alive=shared.cmd_opts.timeout_keep_alive,
  732. root_path=root_path,
  733. ssl_keyfile=shared.cmd_opts.tls_keyfile,
  734. ssl_certfile=shared.cmd_opts.tls_certfile
  735. )
  736. def kill_webui(self):
  737. restart.stop_program()
  738. def restart_webui(self):
  739. if restart.is_restartable():
  740. restart.restart_program()
  741. return Response(status_code=501)
  742. def stop_webui(request):
  743. shared.state.server_command = "stop"
  744. return Response("Stopping.")