Browse Source

allow add middleware after app has started

this should completely fix "Cannot add middleware after an application has started" which can occur due to a race condition
w-e-w 7 months ago
parent
commit
e936dbb43b
2 changed files with 38 additions and 3 deletions
  1. 1 0
      modules/initialize.py
  2. 37 3
      modules/initialize_util.py

+ 1 - 0
modules/initialize.py

@@ -50,6 +50,7 @@ def check_versions():
 
 def initialize():
     from modules import initialize_util
+    initialize_util.allow_add_middleware_after_start()
     initialize_util.fix_torch_version()
     initialize_util.fix_pytorch_lightning()
     initialize_util.fix_asyncio_event_loop_policy()

+ 37 - 3
modules/initialize_util.py

@@ -5,6 +5,8 @@ import sys
 import re
 
 from modules.timer import startup_timer
+from modules import patches
+from functools import wraps
 
 
 def gradio_server_name():
@@ -191,11 +193,8 @@ def configure_opts_onchange():
 
 def setup_middleware(app):
     from starlette.middleware.gzip import GZipMiddleware
-
-    app.middleware_stack = None  # reset current middleware to allow modifying user provided list
     app.add_middleware(GZipMiddleware, minimum_size=1000)
     configure_cors_middleware(app)
-    app.build_middleware_stack()  # rebuild middleware stack on-the-fly
 
 
 def configure_cors_middleware(app):
@@ -213,3 +212,38 @@ def configure_cors_middleware(app):
         cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
     app.add_middleware(CORSMiddleware, **cors_options)
 
+
+def allow_add_middleware_after_start():
+    from starlette.applications import Starlette
+
+    def add_middleware_wrapper(func):
+        """Patch Starlette.add_middleware to allow for middleware to be added after the app has started
+
+        Starlette.add_middleware raises RuntimeError("Cannot add middleware after an application has started") if middleware_stack is not None.
+        We can force add new middleware by first setting middleware_stack to None, then adding the middleware.
+        When middleware_stack is None, it will rebuild the middleware_stack on the next request (Lazily build middleware stack).
+
+        If packages are updated in the future, things may break, so we have two ways to add middleware after the app has started:
+        the first way is to just set middleware_stack to None and then retry
+        the second manually insert the middleware into the user_middleware list without calling add_middleware
+        """
+
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            res = None
+            try:
+                res = func(self, *args, **kwargs)
+            except RuntimeError as _:
+                try:
+                    self.middleware_stack = None
+                    res = func(self, *args, **kwargs)
+                except RuntimeError as e:
+                    print(f'Warning: "{e}", Retrying...')
+                    from starlette.middleware import Middleware
+                    self.user_middleware.insert(0, Middleware(*args, **kwargs))
+                self.middleware_stack = None  # ensure middleware_stack in the event of concurrent requests
+            return res
+
+        return wrapper
+
+    patches.patch(__name__, obj=Starlette, field="add_middleware", replacement=add_middleware_wrapper(Starlette.add_middleware))