# image_embedding.py

import base64
import json
import numpy as np
import zlib
from PIL import Image, ImageDraw, ImageFont
from fonts.ttf import Roboto
import torch
from modules.shared import opts


class EmbeddingEncoder(json.JSONEncoder):
    """JSON encoder that writes torch tensors as {'TORCHTENSOR': nested list} entries."""
    def default(self, obj):
        if isinstance(obj, torch.Tensor):
            return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
        return json.JSONEncoder.default(self, obj)


class EmbeddingDecoder(json.JSONDecoder):
    """JSON decoder that restores tensors written by EmbeddingEncoder."""
    def __init__(self, *args, **kwargs):
        json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, d):
        if 'TORCHTENSOR' in d:
            return torch.from_numpy(np.array(d['TORCHTENSOR']))
        return d


def embedding_to_b64(data):
    d = json.dumps(data, cls=EmbeddingEncoder)
    return base64.b64encode(d.encode())


def embedding_from_b64(data):
    d = base64.b64decode(data)
    return json.loads(d, cls=EmbeddingDecoder)
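

# Usage sketch (not part of the original module): round-trips a small tensor payload through the
# base64 helpers above. The key name 'weight' is only an example.
def _b64_roundtrip_example():
    payload = {'weight': torch.zeros(2, 3, dtype=torch.float64)}
    encoded = embedding_to_b64(payload)    # bytes, suitable for e.g. a PNG text chunk
    decoded = embedding_from_b64(encoded)  # tensor restored through EmbeddingDecoder
    assert decoded['weight'].tolist() == payload['weight'].tolist()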


def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
    # Linear congruential generator with a fixed seed: the keystream is deterministic,
    # so xor_block() below can both obfuscate and recover data.
    while True:
        seed = (a * seed + c) % m
        yield seed % 255


def xor_block(block):
    # XOR the low nibble of the block with a deterministic pseudo-random pattern.
    g = lcg()
    randblock = np.array([next(g) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape)
    return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
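

# Sketch (not part of the original module): because lcg() always starts from the same seed,
# applying xor_block twice restores the original low-nibble data.
def _xor_block_roundtrip_example():
    block = (np.arange(27, dtype=np.uint8) & 0x0F).reshape(3, 3, 3)
    assert np.array_equal(xor_block(xor_block(block)), block)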


def style_block(block, sequence):
    # Draw a halftone-style dot pattern whose shades come from `sequence`, then XOR it into the
    # high nibble (0xF0) of the block so the data strip looks decorative rather than like noise.
    im = Image.new('RGB', (block.shape[1], block.shape[0]))
    draw = ImageDraw.Draw(im)
    i = 0
    for x in range(-6, im.size[0], 8):
        for yi, y in enumerate(range(-6, im.size[1], 8)):
            offset = 0
            if yi % 2 == 0:
                offset = 4
            shade = sequence[i % len(sequence)]
            i += 1
            draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))

    fg = np.array(im).astype(np.uint8) & 0xF0

    return block ^ fg
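

# Overview of the scheme below: the embedding dict is JSON-encoded, zlib-compressed, split into
# 4-bit halves, styled via style_block(), obfuscated via xor_block(), and stored as two narrow
# image strips placed either side of the preview image. extract_image_data_embed() reverses the
# process by locating the strips via the 1 px black separator columns.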


def insert_image_data_embed(image, data):
    d = 3
    data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
    data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
    # Split each byte into its high and low 4-bit halves.
    data_np_high = data_np_ >> 4
    data_np_low = data_np_ & 0x0F

    # Pad so each nibble plane reshapes into whole columns of height h with d channels.
    h = image.size[1]
    next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
    next_size = next_size + ((h*d)-(next_size % (h*d)))

    data_np_low = np.resize(data_np_low, next_size)
    data_np_low = data_np_low.reshape((h, -1, d))

    data_np_high = np.resize(data_np_high, next_size)
    data_np_high = data_np_high.reshape((h, -1, d))

    # Derive the dot shades for style_block from the embedding weights themselves.
    edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
    edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)

    data_np_low = style_block(data_np_low, sequence=edge_style)
    data_np_low = xor_block(data_np_low)
    data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
    data_np_high = xor_block(data_np_high)

    im_low = Image.fromarray(data_np_low, mode='RGB')
    im_high = Image.fromarray(data_np_high, mode='RGB')

    # Layout: low-nibble strip | 1 px black | preview image | 1 px black | high-nibble strip.
    background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
    background.paste(im_low, (0, 0))
    background.paste(image, (im_low.size[0]+1, 0))
    background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))

    return background
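

# Usage sketch (not part of the original module): writing both the pixel-embedded copy and the
# base64 'sd-ti-embedding' text chunk that the __main__ checks below read back; 'preview.png'
# is a hypothetical path.
def _save_embed_example(image, data, path='preview.png'):
    from PIL import PngImagePlugin
    stamped = insert_image_data_embed(image, data)
    info = PngImagePlugin.PngInfo()
    info.add_text('sd-ti-embedding', embedding_to_b64(data).decode())
    stamped.save(path, pnginfo=info)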


def crop_black(img, tol=0):
    # Trim the black border rows and columns surrounding the content.
    mask = (img > tol).all(2)
    mask0, mask1 = mask.any(0), mask.any(1)
    col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
    row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
    return img[row_start:row_end, col_start:col_end]


def extract_image_data_embed(image):
    d = 3
    outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
    black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
    if black_cols[0].shape[0] < 2:
        print('No Image data blocks found.')
        return None

    # The all-black separator columns flank the preview image; the data strips sit outside them.
    data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
    data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)

    data_block_lower = xor_block(data_block_lower)
    data_block_upper = xor_block(data_block_upper)

    data_block = (data_block_upper << 4) | (data_block_lower)
    data_block = data_block.flatten().tobytes()

    data = zlib.decompress(data_block)
    return json.loads(data, cls=EmbeddingDecoder)
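

# Usage sketch (not part of the original module): recovering an embedding from a saved preview,
# falling back to the base64 text chunk when no pixel data blocks are found. The path is hypothetical.
def _load_embed_example(path='preview.png'):
    with Image.open(path) as im:
        data = extract_image_data_embed(im)
        if data is None and 'sd-ti-embedding' in getattr(im, 'text', {}):
            data = embedding_from_b64(im.text['sd-ti-embedding'])
    return data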


def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
    from math import cos

    image = srcimage.copy()
    fontsize = 32
    if textfont is None:
        try:
            # Load once to verify the configured font is usable, then keep the path; fall back to Roboto.
            textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
            textfont = opts.font or Roboto
        except Exception:
            textfont = Roboto

    # Darken the top and bottom of the image so the overlaid white text stays readable.
    factor = 1.5
    gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
    for y in range(image.size[1]):
        mag = 1-cos(y/image.size[1]*factor)
        mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
        gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
    image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))

    draw = ImageDraw.Draw(image)
    font = ImageFont.truetype(textfont, fontsize)
    padding = 10

    # Scale the title to roughly three quarters of the image width, capped at 72 pt.
    _, _, w, h = draw.textbbox((0, 0), title, font=font)
    fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
    font = ImageFont.truetype(textfont, fontsize)
    _, _, w, h = draw.textbbox((0, 0), title, font=font)
    draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))

    # Give each footer a third of the width and use the smallest size that fits all three.
    _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
    fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
    _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
    fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
    _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
    fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)

    font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))

    draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
    draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
    draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))

    return image


if __name__ == '__main__':

    testEmbed = Image.open('test_embedding.png')
    data = extract_image_data_embed(testEmbed)
    assert data is not None

    data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
    assert data is not None

    image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
    cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')

    test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}

    embedded_image = insert_image_data_embed(cap_image, test_embed)

    retrieved_embed = extract_image_data_embed(embedded_image)
    assert str(retrieved_embed) == str(test_embed)

    embedded_image2 = insert_image_data_embed(cap_image, retrieved_embed)
    assert embedded_image == embedded_image2

    g = lcg()
    shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()

    reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
                        95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
                        160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
                        38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
                        30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
                        41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
                        66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
                        204, 86, 73, 222, 44, 198, 118, 240, 97]

    assert shared_random == reference_random

    hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
    assert 12731374 == hunna_kay_random_sum