소스 검색

Merge pull request #15577 from AUTOMATIC1111/api-get-schedulers

Add schedulers API endpoint
AUTOMATIC1111 1 년 전
부모
커밋
f8f5d6cea2
2개의 변경된 파일20개의 추가작업 그리고 1개의 파일을 삭제
  1. 13 1
      modules/api/api.py
  2. 7 0
      modules/api/models.py

+ 13 - 1
modules/api/api.py

@@ -17,7 +17,7 @@ from fastapi.encoders import jsonable_encoder
 from secrets import compare_digest
 from secrets import compare_digest
 
 
 import modules.shared as shared
 import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers
 from modules.api import models
 from modules.api import models
 from modules.shared import opts
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -221,6 +221,7 @@ class Api:
         self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
         self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
         self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
         self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
         self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
         self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
+        self.add_api_route("/sdapi/v1/schedulers", self.get_schedulers, methods=["GET"], response_model=list[models.SchedulerItem])
         self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
         self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
         self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
         self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
         self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
         self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
@@ -683,6 +684,17 @@ class Api:
     def get_samplers(self):
     def get_samplers(self):
         return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
         return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
 
 
+    def get_schedulers(self):
+        return [
+            {
+                "name": scheduler.name,
+                "label": scheduler.label,
+                "aliases": scheduler.aliases,
+                "default_rho": scheduler.default_rho,
+                "need_inner_model": scheduler.need_inner_model,
+            }
+            for scheduler in sd_schedulers.schedulers]
+
     def get_upscalers(self):
     def get_upscalers(self):
         return [
         return [
             {
             {

+ 7 - 0
modules/api/models.py

@@ -235,6 +235,13 @@ class SamplerItem(BaseModel):
     aliases: list[str] = Field(title="Aliases")
     aliases: list[str] = Field(title="Aliases")
     options: dict[str, str] = Field(title="Options")
     options: dict[str, str] = Field(title="Options")
 
 
+class SchedulerItem(BaseModel):
+    name: str = Field(title="Name")
+    label: str = Field(title="Label")
+    aliases: Optional[list[str]] = Field(title="Aliases")
+    default_rho: Optional[float] = Field(title="Default Rho")
+    need_inner_model: Optional[bool] = Field(title="Needs Inner Model")
+
 class UpscalerItem(BaseModel):
 class UpscalerItem(BaseModel):
     name: str = Field(title="Name")
     name: str = Field(title="Name")
     model_name: Optional[str] = Field(title="Model Name")
     model_name: Optional[str] = Field(title="Model Name")