Эх сурвалжийг харах

Fix some deprecated types

a666 2 жил өмнө
parent
commit
b6c1a1bbbf

+ 13 - 13
modules/api/api.py

@@ -29,7 +29,7 @@ from modules.sd_models import unload_model_weights, reload_model_weights, checkp
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
-from typing import Dict, List, Any
+from typing import Any
 import piexif
 import piexif.helper
 from contextlib import closing
@@ -221,15 +221,15 @@ class Api:
         self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
         self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
         self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
-        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
-        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
-        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem])
-        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
-        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem])
-        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
-        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
-        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
-        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
+        self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=list[models.SamplerItem])
+        self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=list[models.UpscalerItem])
+        self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=list[models.LatentUpscalerModeItem])
+        self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=list[models.SDModelItem])
+        self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=list[models.SDVaeItem])
+        self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=list[models.HypernetworkItem])
+        self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=list[models.FaceRestorerItem])
+        self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=list[models.RealesrganItem])
+        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=list[models.PromptStyleItem])
         self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
         self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
@@ -242,8 +242,8 @@ class Api:
         self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
         self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
-        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
-        self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=List[models.ExtensionItem])
+        self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=list[models.ScriptInfo])
+        self.add_api_route("/sdapi/v1/extensions", self.get_extensions_list, methods=["GET"], response_model=list[models.ExtensionItem])
 
         if shared.cmd_opts.api_server_stop:
             self.add_api_route("/sdapi/v1/server-kill", self.kill_webui, methods=["POST"])
@@ -563,7 +563,7 @@ class Api:
 
         return options
 
-    def set_config(self, req: Dict[str, Any]):
+    def set_config(self, req: dict[str, Any]):
         checkpoint_name = req.get("sd_model_checkpoint", None)
         if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
             raise RuntimeError(f"model {checkpoint_name!r} not found")

+ 11 - 13
modules/api/models.py

@@ -1,12 +1,10 @@
 import inspect
 
 from pydantic import BaseModel, Field, create_model
-from typing import Any, Optional
-from typing_extensions import Literal
+from typing import Any, Optional, Literal
 from inflection import underscore
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
 from modules.shared import sd_upscalers, opts, parser
-from typing import Dict, List
 
 API_NOT_ALLOWED = [
     "self",
@@ -130,12 +128,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
 ).generate_model()
 
 class TextToImageResponse(BaseModel):
-    images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
     parameters: dict
     info: str
 
 class ImageToImageResponse(BaseModel):
-    images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
     parameters: dict
     info: str
 
@@ -168,10 +166,10 @@ class FileData(BaseModel):
     name: str = Field(title="File name")
 
 class ExtrasBatchImagesRequest(ExtrasBaseRequest):
-    imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
+    imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
 
 class ExtrasBatchImagesResponse(ExtraBaseResponse):
-    images: List[str] = Field(title="Images", description="The generated images in base64 format.")
+    images: list[str] = Field(title="Images", description="The generated images in base64 format.")
 
 class PNGInfoRequest(BaseModel):
     image: str = Field(title="Image", description="The base64 encoded PNG image")
@@ -233,8 +231,8 @@ FlagsModel = create_model("Flags", **flags)
 
 class SamplerItem(BaseModel):
     name: str = Field(title="Name")
-    aliases: List[str] = Field(title="Aliases")
-    options: Dict[str, str] = Field(title="Options")
+    aliases: list[str] = Field(title="Aliases")
+    options: dict[str, str] = Field(title="Options")
 
 class UpscalerItem(BaseModel):
     name: str = Field(title="Name")
@@ -285,8 +283,8 @@ class EmbeddingItem(BaseModel):
     vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")
 
 class EmbeddingsResponse(BaseModel):
-    loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
-    skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
+    loaded: dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
+    skipped: dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
 
 class MemoryResponse(BaseModel):
     ram: dict = Field(title="RAM", description="System memory stats")
@@ -304,14 +302,14 @@ class ScriptArg(BaseModel):
     minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
     maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
     step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
-    choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
+    choices: Optional[list[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
 
 
 class ScriptInfo(BaseModel):
     name: str = Field(default=None, title="Name", description="Script name")
     is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
     is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
-    args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
+    args: list[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
 
 class ExtensionItem(BaseModel):
     name: str = Field(title="Name", description="Extension name")

+ 1 - 1
modules/gitpython_hack.py

@@ -23,7 +23,7 @@ class Git(git.Git):
         )
         return self._parse_object_header(ret)
 
-    def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]:
+    def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]:
         # Not really streaming, per se; this buffers the entire object in memory.
         # Shouldn't be a problem for our use case, since we're only using this for
         # object headers (commit objects).

+ 3 - 4
modules/prompt_parser.py

@@ -2,7 +2,6 @@ from __future__ import annotations
 
 import re
 from collections import namedtuple
-from typing import List
 import lark
 
 # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
@@ -240,14 +239,14 @@ def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
 
 class ComposableScheduledPromptConditioning:
     def __init__(self, schedules, weight=1.0):
-        self.schedules: List[ScheduledPromptConditioning] = schedules
+        self.schedules: list[ScheduledPromptConditioning] = schedules
         self.weight: float = weight
 
 
 class MulticondLearnedConditioning:
     def __init__(self, shape, batch):
         self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
-        self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
+        self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
 
 
 def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
@@ -278,7 +277,7 @@ class DictWithShape(dict):
         return self["crossattn"].shape
 
 
-def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
+def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
     param = c[0][0].cond
     is_dict = isinstance(param, dict)
 

+ 3 - 3
modules/script_callbacks.py

@@ -1,7 +1,7 @@
 import inspect
 import os
 from collections import namedtuple
-from typing import Optional, Dict, Any
+from typing import Optional, Any
 
 from fastapi import FastAPI
 from gradio import Blocks
@@ -255,7 +255,7 @@ def image_grid_callback(params: ImageGridLoopParams):
             report_exception(c, 'image_grid')
 
 
-def infotext_pasted_callback(infotext: str, params: Dict[str, Any]):
+def infotext_pasted_callback(infotext: str, params: dict[str, Any]):
     for c in callback_map['callbacks_infotext_pasted']:
         try:
             c.callback(infotext, params)
@@ -446,7 +446,7 @@ def on_infotext_pasted(callback):
     """register a function to be called before applying an infotext.
     The callback is called with two arguments:
        - infotext: str - raw infotext.
-       - result: Dict[str, any] - parsed infotext parameters.
+       - result: dict[str, any] - parsed infotext parameters.
     """
     add_callback(callback_map['callbacks_infotext_pasted'], callback)
 

+ 2 - 2
modules/sub_quadratic_attention.py

@@ -15,7 +15,7 @@ import torch
 from torch import Tensor
 from torch.utils.checkpoint import checkpoint
 import math
-from typing import Optional, NamedTuple, List
+from typing import Optional, NamedTuple
 
 
 def narrow_trunc(
@@ -97,7 +97,7 @@ def _query_chunk_attention(
         )
         return summarize_chunk(query, key_chunk, value_chunk)
 
-    chunks: List[AttnChunk] = [
+    chunks: list[AttnChunk] = [
         chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
     ]
     acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))

+ 1 - 2
modules/ui.py

@@ -1338,7 +1338,6 @@ checkpoint: <a id="sd_checkpoint_hash">N/A</a>
 
 def setup_ui_api(app):
     from pydantic import BaseModel, Field
-    from typing import List
 
     class QuicksettingsHint(BaseModel):
         name: str = Field(title="Name of the quicksettings field")
@@ -1347,7 +1346,7 @@ def setup_ui_api(app):
     def quicksettings_hint():
         return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]
 
-    app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint])
+    app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=list[QuicksettingsHint])
 
     app.add_api_route("/internal/ping", lambda: {}, methods=["GET"])