Parcourir la source

add images.read to automatically fix all jpeg/png weirdness

AUTOMATIC1111 il y a 1 an
Parent
commit
09b5ce68a9

+ 2 - 4
modules/api/api.py

@@ -85,8 +85,7 @@ def decode_base64_to_image(encoding):
         headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
         headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}
         response = requests.get(encoding, timeout=30, headers=headers)
         response = requests.get(encoding, timeout=30, headers=headers)
         try:
         try:
-            image = Image.open(BytesIO(response.content))
-            image = images.apply_exif_orientation(image)
+            image = images.read(BytesIO(response.content))
             return image
             return image
         except Exception as e:
         except Exception as e:
             raise HTTPException(status_code=500, detail="Invalid image url") from e
             raise HTTPException(status_code=500, detail="Invalid image url") from e
@@ -94,8 +93,7 @@ def decode_base64_to_image(encoding):
     if encoding.startswith("data:image/"):
     if encoding.startswith("data:image/"):
         encoding = encoding.split(";")[1].split(",")[1]
         encoding = encoding.split(";")[1].split(",")[1]
     try:
     try:
-        image = Image.open(BytesIO(base64.b64decode(encoding)))
-        image = images.apply_exif_orientation(image)
+        image = images.read(BytesIO(base64.b64decode(encoding)))
         return image
         return image
     except Exception as e:
     except Exception as e:
         raise HTTPException(status_code=500, detail="Invalid encoded image") from e
         raise HTTPException(status_code=500, detail="Invalid encoded image") from e

+ 18 - 46
modules/images.py

@@ -12,7 +12,7 @@ import re
 import numpy as np
 import numpy as np
 import piexif
 import piexif
 import piexif.helper
 import piexif.helper
-from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin
+from PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin, ImageOps
 import string
 import string
 import json
 import json
 import hashlib
 import hashlib
@@ -551,12 +551,6 @@ def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_p
         else:
         else:
             pnginfo_data = None
             pnginfo_data = None
 
 
-        # Error handling for unsupported transparency in RGB mode
-        if (image.mode == "RGB" and
-            "transparency" in image.info and
-            isinstance(image.info["transparency"], bytes)):
-            del image.info["transparency"]
-
         image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
         image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
 
 
     elif extension.lower() in (".jpg", ".jpeg", ".webp"):
     elif extension.lower() in (".jpg", ".jpeg", ".webp"):
@@ -779,7 +773,7 @@ def image_data(data):
     import gradio as gr
     import gradio as gr
 
 
     try:
     try:
-        image = Image.open(io.BytesIO(data))
+        image = read(io.BytesIO(data))
         textinfo, _ = read_info_from_image(image)
         textinfo, _ = read_info_from_image(image)
         return textinfo, None
         return textinfo, None
     except Exception:
     except Exception:
@@ -807,51 +801,29 @@ def flatten(img, bgcolor):
     return img.convert('RGB')
     return img.convert('RGB')
 
 
 
 
-# https://www.exiv2.org/tags.html
-_EXIF_ORIENT = 274  # exif 'Orientation' tag
-
-def apply_exif_orientation(image):
-    """
-    Applies the exif orientation correctly.
-
-    This code exists per the bug:
-      https://github.com/python-pillow/Pillow/issues/3973
-    with the function `ImageOps.exif_transpose`. The Pillow source raises errors with
-    various methods, especially `tobytes`
+def read(fp, **kwargs):
+    image = Image.open(fp, **kwargs)
+    image = fix_image(image)
 
 
-    Function based on:
-      https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59
-      https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527
+    return image
 
 
-    Args:
-        image (PIL.Image): a PIL image
 
 
-    Returns:
-        (PIL.Image): the PIL image with exif orientation applied, if applicable
-    """
-    if not hasattr(image, "getexif"):
-        return image
+def fix_image(image: Image.Image):
+    if image is None:
+        return None
 
 
     try:
     try:
-        exif = image.getexif()
-    except Exception:  # https://github.com/facebookresearch/detectron2/issues/1885
-        exif = None
+        image = ImageOps.exif_transpose(image)
+        image = fix_png_transparency(image)
+    except Exception:
+        pass
 
 
-    if exif is None:
-        return image
+    return image
 
 
-    orientation = exif.get(_EXIF_ORIENT)
 
 
-    method = {
-        2: Image.FLIP_LEFT_RIGHT,
-        3: Image.ROTATE_180,
-        4: Image.FLIP_TOP_BOTTOM,
-        5: Image.TRANSPOSE,
-        6: Image.ROTATE_270,
-        7: Image.TRANSVERSE,
-        8: Image.ROTATE_90,
-    }.get(orientation)
+def fix_png_transparency(image: Image.Image):
+    if image.mode not in ("RGB", "P") or not isinstance(image.info.get("transparency"), bytes):
+        return image
 
 
-    if method is not None:
-        return image.transpose(method)
+    image = image.convert("RGBA")
     return image
     return image

+ 12 - 13
modules/img2img.py

@@ -6,7 +6,7 @@ import numpy as np
 from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
 from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
 import gradio as gr
 import gradio as gr
 
 
-from modules import images as imgutil
+from modules import images
 from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
 from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state
 from modules.shared import opts, state
@@ -21,7 +21,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
     output_dir = output_dir.strip()
     output_dir = output_dir.strip()
     processing.fix_seed(p)
     processing.fix_seed(p)
 
 
-    images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
+    batch_images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
 
 
     is_inpaint_batch = False
     is_inpaint_batch = False
     if inpaint_mask_dir:
     if inpaint_mask_dir:
@@ -31,9 +31,9 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
         if is_inpaint_batch:
         if is_inpaint_batch:
             print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
             print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")
 
 
-    print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
+    print(f"Will process {len(batch_images)} images, creating {p.n_iter * p.batch_size} new images for each.")
 
 
-    state.job_count = len(images) * p.n_iter
+    state.job_count = len(batch_images) * p.n_iter
 
 
     # extract "default" params to use in case getting png info fails
     # extract "default" params to use in case getting png info fails
     prompt = p.prompt
     prompt = p.prompt
@@ -46,8 +46,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
     sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
     sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
     batch_results = None
     batch_results = None
     discard_further_results = False
     discard_further_results = False
-    for i, image in enumerate(images):
-        state.job = f"{i+1} out of {len(images)}"
+    for i, image in enumerate(batch_images):
+        state.job = f"{i+1} out of {len(batch_images)}"
         if state.skipped:
         if state.skipped:
             state.skipped = False
             state.skipped = False
 
 
@@ -55,7 +55,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
             break
             break
 
 
         try:
         try:
-            img = Image.open(image)
+            img = images.read(image)
         except UnidentifiedImageError as e:
         except UnidentifiedImageError as e:
             print(e)
             print(e)
             continue
             continue
@@ -86,7 +86,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
                 # otherwise user has many masks with the same name but different extensions
                 # otherwise user has many masks with the same name but different extensions
                 mask_image_path = masks_found[0]
                 mask_image_path = masks_found[0]
 
 
-            mask_image = Image.open(mask_image_path)
+            mask_image = images.read(mask_image_path)
             p.image_mask = mask_image
             p.image_mask = mask_image
 
 
         if use_png_info:
         if use_png_info:
@@ -94,8 +94,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
                 info_img = img
                 info_img = img
                 if png_info_dir:
                 if png_info_dir:
                     info_img_path = os.path.join(png_info_dir, os.path.basename(image))
                     info_img_path = os.path.join(png_info_dir, os.path.basename(image))
-                    info_img = Image.open(info_img_path)
-                geninfo, _ = imgutil.read_info_from_image(info_img)
+                    info_img = images.read(info_img_path)
+                geninfo, _ = images.read_info_from_image(info_img)
                 parsed_parameters = parse_generation_parameters(geninfo)
                 parsed_parameters = parse_generation_parameters(geninfo)
                 parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
                 parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
             except Exception:
             except Exception:
@@ -175,9 +175,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
         image = None
         image = None
         mask = None
         mask = None
 
 
-    # Use the EXIF orientation of photos taken by smartphones.
-    if image is not None:
-        image = ImageOps.exif_transpose(image)
+    image = images.fix_image(image)
+    mask = images.fix_image(mask)
 
 
     if selected_scale_tab == 1 and not is_batch:
     if selected_scale_tab == 1 and not is_batch:
         assert image, "Can't scale by because no image is selected"
         assert image, "Can't scale by because no image is selected"

+ 3 - 3
modules/infotext_utils.py

@@ -8,7 +8,7 @@ import sys
 
 
 import gradio as gr
 import gradio as gr
 from modules.paths import data_path
 from modules.paths import data_path
-from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions
+from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, images
 from PIL import Image
 from PIL import Image
 
 
 sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__]  # alias for old name
 sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__]  # alias for old name
@@ -83,7 +83,7 @@ def image_from_url_text(filedata):
         assert is_in_right_dir, 'trying to open image file outside of allowed directories'
         assert is_in_right_dir, 'trying to open image file outside of allowed directories'
 
 
         filename = filename.rsplit('?', 1)[0]
         filename = filename.rsplit('?', 1)[0]
-        return Image.open(filename)
+        return images.read(filename)
 
 
     if type(filedata) == list:
     if type(filedata) == list:
         if len(filedata) == 0:
         if len(filedata) == 0:
@@ -95,7 +95,7 @@ def image_from_url_text(filedata):
         filedata = filedata[len("data:image/png;base64,"):]
         filedata = filedata[len("data:image/png;base64,"):]
 
 
     filedata = base64.decodebytes(filedata.encode('utf-8'))
     filedata = base64.decodebytes(filedata.encode('utf-8'))
-    image = Image.open(io.BytesIO(filedata))
+    image = images.read(io.BytesIO(filedata))
     return image
     return image
 
 
 
 

+ 3 - 3
modules/postprocessing.py

@@ -17,10 +17,10 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
         if extras_mode == 1:
         if extras_mode == 1:
             for img in image_folder:
             for img in image_folder:
                 if isinstance(img, Image.Image):
                 if isinstance(img, Image.Image):
-                    image = img
+                    image = images.fix_image(img)
                     fn = ''
                     fn = ''
                 else:
                 else:
-                    image = Image.open(os.path.abspath(img.name))
+                    image = images.read(os.path.abspath(img.name))
                     fn = os.path.splitext(img.orig_name)[0]
                     fn = os.path.splitext(img.orig_name)[0]
                 yield image, fn
                 yield image, fn
         elif extras_mode == 2:
         elif extras_mode == 2:
@@ -56,7 +56,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
 
 
         if isinstance(image_placeholder, str):
         if isinstance(image_placeholder, str):
             try:
             try:
-                image_data = Image.open(image_placeholder)
+                image_data = images.read(image_placeholder)
             except Exception:
             except Exception:
                 continue
                 continue
         else:
         else:

+ 2 - 2
modules/textual_inversion/dataset.py

@@ -10,7 +10,7 @@ from random import shuffle, choices
 
 
 import random
 import random
 import tqdm
 import tqdm
-from modules import devices, shared
+from modules import devices, shared, images
 import re
 import re
 
 
 from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
 from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
@@ -61,7 +61,7 @@ class PersonalizedBase(Dataset):
             if shared.state.interrupted:
             if shared.state.interrupted:
                 raise Exception("interrupted")
                 raise Exception("interrupted")
             try:
             try:
-                image = Image.open(path)
+                image = images.read(path)
                 #Currently does not work for single color transparency
                 #Currently does not work for single color transparency
                 #We would need to read image.info['transparency'] for that
                 #We would need to read image.info['transparency'] for that
                 if use_weight and 'A' in image.getbands():
                 if use_weight and 'A' in image.getbands():