浏览代码

Merge pull request #15319 from catboxanon/feat/ssmd_cover_images

Support cover images embedded in safetensors metadata
AUTOMATIC1111 1 年之前
父节点
当前提交
b0b90dc0d7
共有 3 个文件被更改,包括 43 次插入2 次删除
  1. 0 1
      extensions-builtin/Lora/network.py
  2. 1 1
      extensions-builtin/Lora/ui_extra_networks_lora.py
  3. 42 0
      modules/ui_extra_networks.py

+ 0 - 1
extensions-builtin/Lora/network.py

@@ -29,7 +29,6 @@ class NetworkOnDisk:
 
         def read_metadata():
             metadata = sd_models.read_metadata_from_safetensors(filename)
-            metadata.pop('ssmd_cover_images', None)  # those are cover images, and they are too big to display in UI as text
 
             return metadata
 

+ 1 - 1
extensions-builtin/Lora/ui_extra_networks_lora.py

@@ -31,7 +31,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
             "name": name,
             "filename": lora_on_disk.filename,
             "shorthash": lora_on_disk.shorthash,
-            "preview": self.find_preview(path),
+            "preview": self.find_preview(path) or self.find_embedded_preview(path, name, lora_on_disk.metadata),
             "description": self.find_description(path),
             "search_terms": search_terms,
             "local_preview": f"{path}.{shared.opts.samples_format}",

+ 42 - 0
modules/ui_extra_networks.py

@@ -1,6 +1,8 @@
 import functools
 import os.path
 import urllib.parse
+from base64 import b64decode
+from io import BytesIO
 from pathlib import Path
 from typing import Optional, Union
 from dataclasses import dataclass
@@ -11,6 +13,7 @@ import gradio as gr
 import json
 import html
 from fastapi.exceptions import HTTPException
+from PIL import Image
 
 from modules.infotext_utils import image_from_url_text
 
@@ -108,6 +111,31 @@ def fetch_file(filename: str = ""):
     return FileResponse(filename, headers={"Accept-Ranges": "bytes"})
 
 
+def fetch_cover_images(page: str = "", item: str = "", index: int = 0):
+    from starlette.responses import Response
+
+    page = next(iter([x for x in extra_pages if x.name == page]), None)
+    if page is None:
+        raise HTTPException(status_code=404, detail="File not found")
+
+    metadata = page.metadata.get(item)
+    if metadata is None:
+        raise HTTPException(status_code=404, detail="File not found")
+
+    cover_images = json.loads(metadata.get('ssmd_cover_images', {}))
+    image = cover_images[index] if index < len(cover_images) else None
+    if not image:
+        raise HTTPException(status_code=404, detail="File not found")
+
+    try:
+        image = Image.open(BytesIO(b64decode(image)))
+        buffer = BytesIO()
+        image.save(buffer, format=image.format)
+        return Response(content=buffer.getvalue(), media_type=image.get_format_mimetype())
+    except Exception as err:
+        raise ValueError(f"File cannot be fetched: {item}. Failed to load cover image.") from err
+
+
 def get_metadata(page: str = "", item: str = ""):
     from starlette.responses import JSONResponse
 
@@ -119,6 +147,8 @@ def get_metadata(page: str = "", item: str = ""):
     if metadata is None:
         return JSONResponse({})
 
+    metadata = {i:metadata[i] for i in metadata if i != 'ssmd_cover_images'}  # those are cover images, and they are too big to display in UI as text
+
     return JSONResponse({"metadata": json.dumps(metadata, indent=4, ensure_ascii=False)})
 
 
@@ -142,6 +172,7 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""):
 
 def add_pages_to_demo(app):
     app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])
+    app.add_api_route("/sd_extra_networks/cover-images", fetch_cover_images, methods=["GET"])
     app.add_api_route("/sd_extra_networks/metadata", get_metadata, methods=["GET"])
     app.add_api_route("/sd_extra_networks/get-single-card", get_single_card, methods=["GET"])
 
@@ -626,6 +657,17 @@ class ExtraNetworksPage:
 
         return None
 
+    def find_embedded_preview(self, path, name, metadata):
+        """
+        Find if embedded preview exists in safetensors metadata and return endpoint for it.
+        """
+
+        file = f"{path}.safetensors"
+        if self.lister.exists(file) and 'ssmd_cover_images' in metadata and len(list(filter(None, json.loads(metadata['ssmd_cover_images'])))) > 0:
+            return f"./sd_extra_networks/cover-images?page={self.extra_networks_tabname}&item={name}"
+
+        return None
+
     def find_description(self, path):
         """
         Find and read a description file for a given path (without extension).