lora_script.py

import re

import torch
import gradio as gr
from fastapi import FastAPI

import lora
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
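
# unload() restores the original torch layer methods saved below, undoing the
# Lora monkeypatching when the script is unloaded (e.g. on a webui reload).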
def unload():
    torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
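
# Register the Lora page in the extra-networks UI and the handler for the
# <lora:...> prompt syntax before the UI is built.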
def before_ui():
    ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
    extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
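
# Save the original forward/_load_from_state_dict methods on torch.nn exactly
# once; the hasattr guards keep a script reload from overwriting the saved
# originals with already-patched versions.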
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward

if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
    torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict

if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
    torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward

if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
    torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict

if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
    torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward

if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
    torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
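
# Replace the torch layer methods with Lora-aware versions so that active
# Loras are applied on every forward pass and whenever a state dict is loaded.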
torch.nn.Linear.forward = lora.lora_Linear_forward
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
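
# Wire the extension into the webui lifecycle: tag model modules with Lora
# names after a checkpoint loads, restore the original torch methods on
# unload, register the UI pieces, and process pasted infotext.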
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
script_callbacks.on_infotext_pasted(lora.infotext_pasted)
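
# Lora options shown in the webui's "Extra Networks" settings section.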
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
    "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
    "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
}))
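
# Compatibility toggle: the older Lora method that is slower with multiple
# active Loras but matches the kohya-ss additional-networks extension.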
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
    "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
}))
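
# Serialize a LoraOnDisk entry for the JSON API below.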
def create_lora_json(obj: lora.LoraOnDisk):
    return {
        "name": obj.name,
        "alias": obj.alias,
        "path": obj.filename,
        "metadata": obj.metadata,
    }
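
# REST endpoints mounted on the webui's FastAPI app: GET /sdapi/v1/loras lists
# the available Loras; POST /sdapi/v1/refresh-loras rescans the Lora directory.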
def api_loras(_: gr.Blocks, app: FastAPI):
    @app.get("/sdapi/v1/loras")
    async def get_loras():
        return [create_lora_json(obj) for obj in lora.available_loras.values()]

    @app.post("/sdapi/v1/refresh-loras")
    async def refresh_loras():
        return lora.list_available_loras()


script_callbacks.on_app_started(api_loras)
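
# Example usage once the API is up (assuming the default localhost:7860):
#   curl http://localhost:7860/sdapi/v1/loras
#   curl -X POST http://localhost:7860/sdapi/v1/refresh-loras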

re_lora = re.compile("<lora:([^:]+):")
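
# On "paste generation parameters", remap <lora:alias:...> prompt references
# using the "Lora hashes" infotext entry, which (as implied by the parsing
# below) has the form:
#   Lora hashes: "alias: shorthash, alias2: shorthash2"
# Each alias is resolved to the Lora on disk with that short hash, and the
# prompt is rewritten to use that Lora's current alias.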
def infotext_pasted(infotext, d):
    hashes = d.get("Lora hashes")
    if not hashes:
        return

    hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
    hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}

    def lora_replacement(m):
        alias = m.group(1)
        shorthash = hashes.get(alias)
        if shorthash is None:
            return m.group(0)

        lora_on_disk = lora.available_lora_hash_lookup.get(shorthash)
        if lora_on_disk is None:
            return m.group(0)

        return f'<lora:{lora_on_disk.get_alias()}:'

    d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"])


script_callbacks.on_infotext_pasted(infotext_pasted)