|
@@ -9,6 +9,7 @@ import shlex
|
|
import modules.scripts as scripts
|
|
import modules.scripts as scripts
|
|
import gradio as gr
|
|
import gradio as gr
|
|
|
|
|
|
|
|
+from modules import sd_samplers
|
|
from modules.processing import Processed, process_images
|
|
from modules.processing import Processed, process_images
|
|
from PIL import Image
|
|
from PIL import Image
|
|
from modules.shared import opts, cmd_opts, state
|
|
from modules.shared import opts, cmd_opts, state
|
|
@@ -44,6 +45,7 @@ prompt_tags = {
|
|
"seed_resize_from_h": process_int_tag,
|
|
"seed_resize_from_h": process_int_tag,
|
|
"seed_resize_from_w": process_int_tag,
|
|
"seed_resize_from_w": process_int_tag,
|
|
"sampler_index": process_int_tag,
|
|
"sampler_index": process_int_tag,
|
|
|
|
+ "sampler_name": process_string_tag,
|
|
"batch_size": process_int_tag,
|
|
"batch_size": process_int_tag,
|
|
"n_iter": process_int_tag,
|
|
"n_iter": process_int_tag,
|
|
"steps": process_int_tag,
|
|
"steps": process_int_tag,
|
|
@@ -66,14 +68,28 @@ def cmdargs(line):
|
|
arg = args[pos]
|
|
arg = args[pos]
|
|
|
|
|
|
assert arg.startswith("--"), f'must start with "--": {arg}'
|
|
assert arg.startswith("--"), f'must start with "--": {arg}'
|
|
|
|
+ assert pos+1 < len(args), f'missing argument for command line option {arg}'
|
|
|
|
+
|
|
tag = arg[2:]
|
|
tag = arg[2:]
|
|
|
|
|
|
|
|
+ if tag == "prompt" or tag == "negative_prompt":
|
|
|
|
+ pos += 1
|
|
|
|
+ prompt = args[pos]
|
|
|
|
+ pos += 1
|
|
|
|
+ while pos < len(args) and not args[pos].startswith("--"):
|
|
|
|
+ prompt += " "
|
|
|
|
+ prompt += args[pos]
|
|
|
|
+ pos += 1
|
|
|
|
+ res[tag] = prompt
|
|
|
|
+ continue
|
|
|
|
+
|
|
|
|
+
|
|
func = prompt_tags.get(tag, None)
|
|
func = prompt_tags.get(tag, None)
|
|
assert func, f'unknown commandline option: {arg}'
|
|
assert func, f'unknown commandline option: {arg}'
|
|
|
|
|
|
- assert pos+1 < len(args), f'missing argument for command line option {arg}'
|
|
|
|
-
|
|
|
|
val = args[pos+1]
|
|
val = args[pos+1]
|
|
|
|
+ if tag == "sampler_name":
|
|
|
|
+ val = sd_samplers.samplers_map.get(val.lower(), None)
|
|
|
|
|
|
res[tag] = func(val)
|
|
res[tag] = func(val)
|
|
|
|
|