gradio_extensons.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import gradio as gr
  2. from modules import scripts, ui_tempdir, patches
  3. def add_classes_to_gradio_component(comp):
  4. """
  5. this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
  6. """
  7. comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
  8. if getattr(comp, 'multiselect', False):
  9. comp.elem_classes.append('multiselect')
  10. def IOComponent_init(self, *args, **kwargs):
  11. self.webui_tooltip = kwargs.pop('tooltip', None)
  12. if scripts.scripts_current is not None:
  13. scripts.scripts_current.before_component(self, **kwargs)
  14. scripts.script_callbacks.before_component_callback(self, **kwargs)
  15. res = original_IOComponent_init(self, *args, **kwargs)
  16. add_classes_to_gradio_component(self)
  17. scripts.script_callbacks.after_component_callback(self, **kwargs)
  18. if scripts.scripts_current is not None:
  19. scripts.scripts_current.after_component(self, **kwargs)
  20. return res
  21. def Block_get_config(self):
  22. config = original_Block_get_config(self)
  23. webui_tooltip = getattr(self, 'webui_tooltip', None)
  24. if webui_tooltip:
  25. config["webui_tooltip"] = webui_tooltip
  26. config.pop('example_inputs', None)
  27. return config
  28. def BlockContext_init(self, *args, **kwargs):
  29. if scripts.scripts_current is not None:
  30. scripts.scripts_current.before_component(self, **kwargs)
  31. scripts.script_callbacks.before_component_callback(self, **kwargs)
  32. res = original_BlockContext_init(self, *args, **kwargs)
  33. add_classes_to_gradio_component(self)
  34. scripts.script_callbacks.after_component_callback(self, **kwargs)
  35. if scripts.scripts_current is not None:
  36. scripts.scripts_current.after_component(self, **kwargs)
  37. return res
  38. def Blocks_get_config_file(self, *args, **kwargs):
  39. config = original_Blocks_get_config_file(self, *args, **kwargs)
  40. for comp_config in config["components"]:
  41. if "example_inputs" in comp_config:
  42. comp_config["example_inputs"] = {"serialized": []}
  43. return config
  44. original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init)
  45. original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config)
  46. original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init)
  47. original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file)
  48. ui_tempdir.install_ui_tempdir_override()