Эх сурвалжийг харах

Merge branch 'master' into cpu-cmdline-opt

brkirch 2 жил өмнө
parent
commit
e9e2a7ec9a

+ 10 - 2
modules/codeformer_model.py

@@ -69,10 +69,14 @@ def setup_model(dirname):
 
                 self.net = net
                 self.face_helper = face_helper
-                self.net.to(devices.device_codeformer)
 
                 return net, face_helper
 
+            def send_model_to(self, device):
+                self.net.to(device)
+                self.face_helper.face_det.to(device)
+                self.face_helper.face_parse.to(device)
+
             def restore(self, np_image, w=None):
                 np_image = np_image[:, :, ::-1]
 
@@ -82,6 +86,8 @@ def setup_model(dirname):
                 if self.net is None or self.face_helper is None:
                     return np_image
 
+                self.send_model_to(devices.device_codeformer)
+
                 self.face_helper.clean_all()
                 self.face_helper.read_image(np_image)
                 self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
@@ -113,8 +119,10 @@ def setup_model(dirname):
                 if original_resolution != restored_img.shape[0:2]:
                     restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
 
+                self.face_helper.clean_all()
+
                 if shared.opts.face_restoration_unload:
-                    self.net.to(devices.cpu)
+                    self.send_model_to(devices.cpu)
 
                 return restored_img
 

+ 10 - 0
modules/devices.py

@@ -1,3 +1,5 @@
+import contextlib
+
 import torch
 
 from modules import errors
@@ -56,3 +58,11 @@ def randn_without_seed(shape):
 
     return torch.randn(shape, device=device)
 
+
+def autocast():
+    from modules import shared
+
+    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+        return contextlib.nullcontext()
+
+    return torch.autocast("cuda")

+ 13 - 3
modules/gfpgan_model.py

@@ -36,23 +36,33 @@ def gfpgann():
     else:
         print("Unable to load gfpgan model!")
         return None
-    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
-    model.gfpgan.to(devices.device_gfpgan)
+    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
     loaded_gfpgan_model = model
 
     return model
 
 
+def send_model_to(model, device):
+    model.gfpgan.to(device)
+    model.face_helper.face_det.to(device)
+    model.face_helper.face_parse.to(device)
+
+
 def gfpgan_fix_faces(np_image):
     model = gfpgann()
     if model is None:
         return np_image
+
+    send_model_to(model, devices.device_gfpgan)
+
     np_image_bgr = np_image[:, :, ::-1]
     cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
     np_image = gfpgan_output_bgr[:, :, ::-1]
 
+    model.face_helper.clean_all()
+
     if shared.opts.face_restoration_unload:
-        model.gfpgan.to(devices.cpu)
+        send_model_to(model, devices.cpu)
 
     return np_image
 

+ 11 - 7
modules/processing.py

@@ -1,4 +1,3 @@
-import contextlib
 import json
 import math
 import os
@@ -330,9 +329,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 
     infotexts = []
     output_images = []
-    precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
-    ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
-    with torch.no_grad(), precision_scope("cuda"), ema_scope():
+
+    with torch.no_grad():
         p.init(all_prompts, all_seeds, all_subseeds)
 
         if state.job_count == -1:
@@ -351,8 +349,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
 
             #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
             #c = p.sd_model.get_learned_conditioning(prompts)
-            uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
-            c = prompt_parser.get_learned_conditioning(prompts, p.steps)
+            with devices.autocast():
+                uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
+                c = prompt_parser.get_learned_conditioning(prompts, p.steps)
 
             if len(model_hijack.comments) > 0:
                 for comment in model_hijack.comments:
@@ -361,13 +360,17 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
-            samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+            with devices.autocast():
+                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
+
             if state.interrupted:
 
                 # if we are interruped, sample returns just noise
                 # use the image collected previously in sampler loop
                 samples_ddim = shared.state.current_latent
 
+            samples_ddim = samples_ddim.to(devices.dtype)
+
             x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
@@ -386,6 +389,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
                     devices.torch_gc()
 
                     x_sample = modules.face_restoration.restore_faces(x_sample)
+                    devices.torch_gc()
 
                 image = Image.fromarray(x_sample)
 

+ 53 - 66
modules/prompt_parser.py

@@ -1,20 +1,11 @@
 import re
 from collections import namedtuple
 import torch
+from lark import Lark, Transformer, Visitor
+import functools
 
 import modules.shared as shared
 
-re_prompt = re.compile(r'''
-(.*?)
-\[
-    ([^]:]+):
-    (?:([^]:]*):)?
-    ([0-9]*\.?[0-9]+)
-]
-|
-(.+)
-''', re.X)
-
 # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"
 # will be represented with prompt_schedule like this (assuming steps=100):
 # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']
@@ -25,61 +16,57 @@ re_prompt = re.compile(r'''
 
 
 def get_learned_conditioning_prompt_schedules(prompts, steps):
-    res = []
-    cache = {}
-
-    for prompt in prompts:
-        prompt_schedule: list[list[str | int]] = [[steps, ""]]
-
-        cached = cache.get(prompt, None)
-        if cached is not None:
-            res.append(cached)
-            continue
-
-        for m in re_prompt.finditer(prompt):
-            plaintext = m.group(1) if m.group(5) is None else m.group(5)
-            concept_from = m.group(2)
-            concept_to = m.group(3)
-            if concept_to is None:
-                concept_to = concept_from
-                concept_from = ""
-            swap_position = float(m.group(4)) if m.group(4) is not None else None
-
-            if swap_position is not None:
-                if swap_position < 1:
-                    swap_position = swap_position * steps
-                swap_position = int(min(swap_position, steps))
-
-            swap_index = None
-            found_exact_index = False
-            for i in range(len(prompt_schedule)):
-                end_step = prompt_schedule[i][0]
-                prompt_schedule[i][1] += plaintext
-
-                if swap_position is not None and swap_index is None:
-                    if swap_position == end_step:
-                        swap_index = i
-                        found_exact_index = True
-
-                    if swap_position < end_step:
-                        swap_index = i
-
-            if swap_index is not None:
-                if not found_exact_index:
-                    prompt_schedule.insert(swap_index, [swap_position, prompt_schedule[swap_index][1]])
-
-                for i in range(len(prompt_schedule)):
-                    end_step = prompt_schedule[i][0]
-                    must_replace = swap_position < end_step
-
-                    prompt_schedule[i][1] += concept_to if must_replace else concept_from
-
-        res.append(prompt_schedule)
-        cache[prompt] = prompt_schedule
-        #for t in prompt_schedule:
-        #    print(t)
-
-    return res
+    grammar = r"""
+    start: prompt
+    prompt: (emphasized | scheduled | weighted | plain)*
+    !emphasized: "(" prompt ")"
+            | "(" prompt ":" prompt ")"
+            | "[" prompt "]"
+    scheduled: "[" (prompt ":")? prompt ":" NUMBER "]"
+    !weighted: "{" weighted_item ("|" weighted_item)* "}"
+    !weighted_item: prompt (":" prompt)?
+    plain: /([^\\\[\](){}:|]|\\.)+/
+    %import common.SIGNED_NUMBER -> NUMBER
+    """
+    parser = Lark(grammar, parser='lalr')
+    def collect_steps(steps, tree):
+        l = [steps]
+        class CollectSteps(Visitor):
+            def scheduled(self, tree):
+                tree.children[-1] = float(tree.children[-1])
+                if tree.children[-1] < 1:
+                    tree.children[-1] *= steps
+                tree.children[-1] = min(steps, int(tree.children[-1]))
+                l.append(tree.children[-1])
+        CollectSteps().visit(tree)
+        return sorted(set(l))
+    def at_step(step, tree):
+        class AtStep(Transformer):
+            def scheduled(self, args):
+                if len(args) == 2:
+                    before, after, when = (), *args
+                else:
+                    before, after, when = args
+                yield before if step <= when else after
+            def start(self, args):
+                def flatten(x):
+                    if type(x) == str:
+                        yield x
+                    else:
+                        for gen in x:
+                            yield from flatten(gen)
+                return ''.join(flatten(args[0]))
+            def plain(self, args):
+                yield args[0].value
+            def __default__(self, data, children, meta):
+                for child in children:
+                    yield from child
+        return AtStep().transform(tree)
+    @functools.cache
+    def get_schedule(prompt):
+        tree = parser.parse(prompt)
+        return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]
+    return [get_schedule(prompt) for prompt in prompts]
 
 
 ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])

+ 10 - 2
modules/ui.py

@@ -386,14 +386,22 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
         outputs=[seed, dummy_component]
     )
 
+
 def update_token_counter(text, steps):
-    prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
+    try:
+        prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
+    except Exception:
+        # a parsing error can happen here during typing, and we don't want to bother the user with
+        # messages related to it in console
+        prompt_schedules = [[[steps, text]]]
+
     flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
-    prompts = [prompt_text for step,prompt_text in flat_prompts]
+    prompts = [prompt_text for step, prompt_text in flat_prompts]
     tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
     style_class = ' class="red"' if (token_count > max_length) else ""
     return f"<span {style_class}>{token_count}/{max_length}</span>"
 
+
 def create_toprow(is_img2img):
     id_part = "img2img" if is_img2img else "txt2img"
 

+ 1 - 0
requirements.txt

@@ -22,3 +22,4 @@ clean-fid
 resize-right
 torchdiffeq
 kornia
+lark

+ 1 - 0
requirements_versions.txt

@@ -21,3 +21,4 @@ clean-fid==0.1.29
 resize-right==0.0.2
 torchdiffeq==0.2.3
 kornia==0.6.7
+lark==1.1.2