image_embedding.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. import base64
  2. import json
  3. import os.path
  4. import warnings
  5. import logging
  6. import numpy as np
  7. import zlib
  8. from PIL import Image, ImageDraw
  9. import torch
  10. logger = logging.getLogger(__name__)
  11. class EmbeddingEncoder(json.JSONEncoder):
  12. def default(self, obj):
  13. if isinstance(obj, torch.Tensor):
  14. return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
  15. return json.JSONEncoder.default(self, obj)
  16. class EmbeddingDecoder(json.JSONDecoder):
  17. def __init__(self, *args, **kwargs):
  18. json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)
  19. def object_hook(self, d):
  20. if 'TORCHTENSOR' in d:
  21. return torch.from_numpy(np.array(d['TORCHTENSOR']))
  22. return d
  23. def embedding_to_b64(data):
  24. d = json.dumps(data, cls=EmbeddingEncoder)
  25. return base64.b64encode(d.encode())
  26. def embedding_from_b64(data):
  27. d = base64.b64decode(data)
  28. return json.loads(d, cls=EmbeddingDecoder)
  29. def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
  30. while True:
  31. seed = (a * seed + c) % m
  32. yield seed % 255
  33. def xor_block(block):
  34. g = lcg()
  35. randblock = np.array([next(g) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape)
  36. return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
  37. def style_block(block, sequence):
  38. im = Image.new('RGB', (block.shape[1], block.shape[0]))
  39. draw = ImageDraw.Draw(im)
  40. i = 0
  41. for x in range(-6, im.size[0], 8):
  42. for yi, y in enumerate(range(-6, im.size[1], 8)):
  43. offset = 0
  44. if yi % 2 == 0:
  45. offset = 4
  46. shade = sequence[i % len(sequence)]
  47. i += 1
  48. draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
  49. fg = np.array(im).astype(np.uint8) & 0xF0
  50. return block ^ fg
  51. def insert_image_data_embed(image, data):
  52. d = 3
  53. data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
  54. data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
  55. data_np_high = data_np_ >> 4
  56. data_np_low = data_np_ & 0x0F
  57. h = image.size[1]
  58. next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
  59. next_size = next_size + ((h*d)-(next_size % (h*d)))
  60. data_np_low = np.resize(data_np_low, next_size)
  61. data_np_low = data_np_low.reshape((h, -1, d))
  62. data_np_high = np.resize(data_np_high, next_size)
  63. data_np_high = data_np_high.reshape((h, -1, d))
  64. edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
  65. edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
  66. data_np_low = style_block(data_np_low, sequence=edge_style)
  67. data_np_low = xor_block(data_np_low)
  68. data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
  69. data_np_high = xor_block(data_np_high)
  70. im_low = Image.fromarray(data_np_low, mode='RGB')
  71. im_high = Image.fromarray(data_np_high, mode='RGB')
  72. background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
  73. background.paste(im_low, (0, 0))
  74. background.paste(image, (im_low.size[0]+1, 0))
  75. background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
  76. return background
  77. def crop_black(img, tol=0):
  78. mask = (img > tol).all(2)
  79. mask0, mask1 = mask.any(0), mask.any(1)
  80. col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
  81. row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
  82. return img[row_start:row_end, col_start:col_end]
  83. def extract_image_data_embed(image):
  84. d = 3
  85. outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
  86. black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
  87. if black_cols[0].shape[0] < 2:
  88. logger.debug(f'{os.path.basename(getattr(image, "filename", "unknown image file"))}: no embedded information found.')
  89. return None
  90. data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
  91. data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
  92. data_block_lower = xor_block(data_block_lower)
  93. data_block_upper = xor_block(data_block_upper)
  94. data_block = (data_block_upper << 4) | (data_block_lower)
  95. data_block = data_block.flatten().tobytes()
  96. data = zlib.decompress(data_block)
  97. return json.loads(data, cls=EmbeddingDecoder)
  98. def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
  99. from modules.images import get_font
  100. if textfont:
  101. warnings.warn(
  102. 'passing in a textfont to caption_image_overlay is deprecated and does nothing',
  103. DeprecationWarning,
  104. stacklevel=2,
  105. )
  106. from math import cos
  107. image = srcimage.copy()
  108. fontsize = 32
  109. factor = 1.5
  110. gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
  111. for y in range(image.size[1]):
  112. mag = 1-cos(y/image.size[1]*factor)
  113. mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
  114. gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
  115. image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
  116. draw = ImageDraw.Draw(image)
  117. font = get_font(fontsize)
  118. padding = 10
  119. _, _, w, h = draw.textbbox((0, 0), title, font=font)
  120. fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
  121. font = get_font(fontsize)
  122. _, _, w, h = draw.textbbox((0, 0), title, font=font)
  123. draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
  124. _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
  125. fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
  126. _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
  127. fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
  128. _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
  129. fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
  130. font = get_font(min(fontsize_left, fontsize_mid, fontsize_right))
  131. draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
  132. draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
  133. draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
  134. return image
  135. if __name__ == '__main__':
  136. testEmbed = Image.open('test_embedding.png')
  137. data = extract_image_data_embed(testEmbed)
  138. assert data is not None
  139. data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
  140. assert data is not None
  141. image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
  142. cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')
  143. test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
  144. embedded_image = insert_image_data_embed(cap_image, test_embed)
  145. retrieved_embed = extract_image_data_embed(embedded_image)
  146. assert str(retrieved_embed) == str(test_embed)
  147. embedded_image2 = insert_image_data_embed(cap_image, retrieved_embed)
  148. assert embedded_image == embedded_image2
  149. g = lcg()
  150. shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
  151. reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
  152. 95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
  153. 160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
  154. 38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
  155. 30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
  156. 41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
  157. 66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
  158. 204, 86, 73, 222, 44, 198, 118, 240, 97]
  159. assert shared_random == reference_random
  160. hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
  161. assert 12731374 == hunna_kay_random_sum