Skip to content

Commit 66dd4c0

Browse files
committed
v3 nodes (part a)
1 parent 3dfefc8 commit 66dd4c0

File tree

6 files changed

+476
-451
lines changed

6 files changed

+476
-451
lines changed

comfy_extras/nodes_ace.py

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,63 @@
11
import torch
2+
from typing_extensions import override
3+
24
import comfy.model_management
35
import node_helpers
6+
from comfy_api.latest import ComfyExtension, io
7+
8+
9+
class TextEncodeAceStepAudio(io.ComfyNode):
10+
@classmethod
11+
def define_schema(cls):
12+
return io.Schema(
13+
node_id="TextEncodeAceStepAudio",
14+
category="conditioning",
15+
inputs=[
16+
io.Clip.Input("clip"),
17+
io.String.Input("tags", multiline=True, dynamic_prompts=True),
18+
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
19+
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
20+
],
21+
outputs=[io.Conditioning.Output()],
22+
)
423

5-
class TextEncodeAceStepAudio:
624
@classmethod
7-
def INPUT_TYPES(s):
8-
return {"required": {
9-
"clip": ("CLIP", ),
10-
"tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
11-
"lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
12-
"lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
13-
}}
14-
RETURN_TYPES = ("CONDITIONING",)
15-
FUNCTION = "encode"
16-
17-
CATEGORY = "conditioning"
18-
19-
def encode(self, clip, tags, lyrics, lyrics_strength):
25+
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput:
2026
tokens = clip.tokenize(tags, lyrics=lyrics)
2127
conditioning = clip.encode_from_tokens_scheduled(tokens)
2228
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
23-
return (conditioning, )
29+
return io.NodeOutput(conditioning)
2430

2531

26-
class EmptyAceStepLatentAudio:
27-
def __init__(self):
28-
self.device = comfy.model_management.intermediate_device()
29-
32+
class EmptyAceStepLatentAudio(io.ComfyNode):
3033
@classmethod
31-
def INPUT_TYPES(s):
32-
return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
33-
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
34-
}}
35-
RETURN_TYPES = ("LATENT",)
36-
FUNCTION = "generate"
37-
38-
CATEGORY = "latent/audio"
34+
def define_schema(cls):
35+
return io.Schema(
36+
node_id="EmptyAceStepLatentAudio",
37+
category="latent/audio",
38+
inputs=[
39+
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
40+
io.Int.Input(
41+
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
42+
),
43+
],
44+
outputs=[io.Latent.Output()],
45+
)
3946

40-
def generate(self, seconds, batch_size):
47+
@classmethod
48+
def execute(cls, seconds, batch_size) -> io.NodeOutput:
4149
length = int(seconds * 44100 / 512 / 8)
42-
latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
43-
return ({"samples": latent, "type": "audio"}, )
50+
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device())
51+
return io.NodeOutput({"samples": latent, "type": "audio"})
52+
4453

54+
class AceExtension(ComfyExtension):
55+
@override
56+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
57+
return [
58+
TextEncodeAceStepAudio,
59+
EmptyAceStepLatentAudio,
60+
]
4561

46-
NODE_CLASS_MAPPINGS = {
47-
"TextEncodeAceStepAudio": TextEncodeAceStepAudio,
48-
"EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
49-
}
62+
async def comfy_entrypoint() -> AceExtension:
63+
return AceExtension()
Lines changed: 51 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
import comfy.samplers
2-
import comfy.utils
3-
import torch
41
import numpy as np
2+
import torch
53
from tqdm.auto import trange
4+
from typing_extensions import override
5+
6+
import comfy.model_patcher
7+
import comfy.samplers
8+
import comfy.utils
9+
from comfy.k_diffusion.sampling import to_d
10+
from comfy_api.latest import ComfyExtension, io
611

712

813
@torch.no_grad()
@@ -33,30 +38,29 @@ def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable
3338
return x
3439

3540

36-
class SamplerLCMUpscale:
37-
upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
41+
class SamplerLCMUpscale(io.ComfyNode):
42+
UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
43+
44+
@classmethod
45+
def define_schema(cls) -> io.Schema:
46+
return io.Schema(
47+
node_id="SamplerLCMUpscale",
48+
category="sampling/custom_sampling/samplers",
49+
inputs=[
50+
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01),
51+
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1),
52+
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
53+
],
54+
outputs=[io.Sampler.Output()],
55+
)
3856

3957
@classmethod
40-
def INPUT_TYPES(s):
41-
return {"required":
42-
{"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}),
43-
"scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}),
44-
"upscale_method": (s.upscale_methods,),
45-
}
46-
}
47-
RETURN_TYPES = ("SAMPLER",)
48-
CATEGORY = "sampling/custom_sampling/samplers"
49-
50-
FUNCTION = "get_sampler"
51-
52-
def get_sampler(self, scale_ratio, scale_steps, upscale_method):
58+
def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput:
5359
if scale_steps < 0:
5460
scale_steps = None
5561
sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
56-
return (sampler, )
62+
return io.NodeOutput(sampler)
5763

58-
from comfy.k_diffusion.sampling import to_d
59-
import comfy.model_patcher
6064

6165
@torch.no_grad()
6266
def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
@@ -76,36 +80,42 @@ def post_cfg_function(args):
7680
denoised = model(x, sigma_hat * s_in, **extra_args)
7781
d = to_d(x - denoised + temp[0], sigmas[i], denoised)
7882
if callback is not None:
79-
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
83+
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigma_hat, "denoised": denoised})
8084
dt = sigmas[i + 1] - sigma_hat
8185
x = x + d * dt
8286
return x
8387

8488

85-
class SamplerEulerCFGpp:
89+
class SamplerEulerCFGpp(io.ComfyNode):
8690
@classmethod
87-
def INPUT_TYPES(s):
88-
return {"required":
89-
{"version": (["regular", "alternative"],),}
90-
}
91-
RETURN_TYPES = ("SAMPLER",)
92-
# CATEGORY = "sampling/custom_sampling/samplers"
93-
CATEGORY = "_for_testing"
94-
95-
FUNCTION = "get_sampler"
91+
def define_schema(cls) -> io.Schema:
92+
return io.Schema(
93+
node_id="SamplerEulerCFGpp",
94+
display_name="SamplerEulerCFG++",
95+
category="_for_testing", # "sampling/custom_sampling/samplers"
96+
inputs=[
97+
io.Combo.Input("version", options=["regular", "alternative"]),
98+
],
99+
outputs=[io.Sampler.Output()],
100+
is_experimental=True,
101+
)
96102

97-
def get_sampler(self, version):
103+
@classmethod
104+
def execute(cls, version) -> io.NodeOutput:
98105
if version == "alternative":
99106
sampler = comfy.samplers.KSAMPLER(sample_euler_pp)
100107
else:
101108
sampler = comfy.samplers.ksampler("euler_cfg_pp")
102-
return (sampler, )
109+
return io.NodeOutput(sampler)
110+
103111

104-
NODE_CLASS_MAPPINGS = {
105-
"SamplerLCMUpscale": SamplerLCMUpscale,
106-
"SamplerEulerCFGpp": SamplerEulerCFGpp,
107-
}
112+
class AdvancedSamplersExtension(ComfyExtension):
113+
@override
114+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
115+
return [
116+
SamplerLCMUpscale,
117+
SamplerEulerCFGpp,
118+
]
108119

109-
NODE_DISPLAY_NAME_MAPPINGS = {
110-
"SamplerEulerCFGpp": "SamplerEulerCFG++",
111-
}
120+
async def comfy_entrypoint() -> AdvancedSamplersExtension:
121+
return AdvancedSamplersExtension()
Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
2+
23
import numpy as np
34
import torch
5+
from typing_extensions import override
6+
7+
from comfy_api.latest import ComfyExtension, io
8+
49

510
def loglinear_interp(t_steps, num_steps):
611
"""
@@ -19,25 +24,26 @@ def loglinear_interp(t_steps, num_steps):
1924
"SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
2025
"SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
2126

22-
class AlignYourStepsScheduler:
27+
class AlignYourStepsScheduler(io.ComfyNode):
2328
@classmethod
24-
def INPUT_TYPES(s):
25-
return {"required":
26-
{"model_type": (["SD1", "SDXL", "SVD"], ),
27-
"steps": ("INT", {"default": 10, "min": 1, "max": 10000}),
28-
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
29-
}
30-
}
31-
RETURN_TYPES = ("SIGMAS",)
32-
CATEGORY = "sampling/custom_sampling/schedulers"
33-
34-
FUNCTION = "get_sigmas"
35-
36-
def get_sigmas(self, model_type, steps, denoise):
29+
def define_schema(cls) -> io.Schema:
30+
return io.Schema(
31+
node_id="AlignYourStepsScheduler",
32+
category="sampling/custom_sampling/schedulers",
33+
inputs=[
34+
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
35+
io.Int.Input("steps", default=10, min=1, max=10000),
36+
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
37+
],
38+
outputs=[io.Sigmas.Output()],
39+
)
40+
41+
@classmethod
42+
def execute(cls, model_type, steps, denoise) -> io.NodeOutput:
3743
total_steps = steps
3844
if denoise < 1.0:
3945
if denoise <= 0.0:
40-
return (torch.FloatTensor([]),)
46+
return io.NodeOutput(torch.FloatTensor([]))
4147
total_steps = round(steps * denoise)
4248

4349
sigmas = NOISE_LEVELS[model_type][:]
@@ -46,8 +52,15 @@ def get_sigmas(self, model_type, steps, denoise):
4652

4753
sigmas = sigmas[-(total_steps + 1):]
4854
sigmas[-1] = 0
49-
return (torch.FloatTensor(sigmas), )
55+
return io.NodeOutput(torch.FloatTensor(sigmas))
56+
57+
58+
class AlignYourStepsExtension(ComfyExtension):
59+
@override
60+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
61+
return [
62+
AlignYourStepsScheduler,
63+
]
5064

51-
NODE_CLASS_MAPPINGS = {
52-
"AlignYourStepsScheduler": AlignYourStepsScheduler,
53-
}
65+
async def comfy_entrypoint() -> AlignYourStepsExtension:
66+
return AlignYourStepsExtension()

comfy_extras/nodes_apg.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,54 @@
11
import torch
2+
from typing_extensions import override
3+
4+
from comfy_api.latest import ComfyExtension, io
5+
26

37
def project(v0, v1):
48
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
59
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
610
v0_orthogonal = v0 - v0_parallel
711
return v0_parallel, v0_orthogonal
812

9-
class APG:
13+
class APG(io.ComfyNode):
1014
@classmethod
11-
def INPUT_TYPES(s):
12-
return {
13-
"required": {
14-
"model": ("MODEL",),
15-
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
16-
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
17-
"momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
18-
}
19-
}
20-
RETURN_TYPES = ("MODEL",)
21-
FUNCTION = "patch"
22-
CATEGORY = "sampling/custom_sampling"
23-
24-
def patch(self, model, eta, norm_threshold, momentum):
15+
def define_schema(cls) -> io.Schema:
16+
return io.Schema(
17+
node_id="APG",
18+
display_name="Adaptive Projected Guidance",
19+
category="sampling/custom_sampling",
20+
inputs=[
21+
io.Model.Input("model"),
22+
io.Float.Input(
23+
"eta",
24+
default=1.0,
25+
min=-10.0,
26+
max=10.0,
27+
step=0.01,
28+
tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
29+
),
30+
io.Float.Input(
31+
"norm_threshold",
32+
default=5.0,
33+
min=0.0,
34+
max=50.0,
35+
step=0.1,
36+
tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
37+
),
38+
io.Float.Input(
39+
"momentum",
40+
default=0.0,
41+
min=-5.0,
42+
max=1.0,
43+
step=0.01,
44+
tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
45+
),
46+
],
47+
outputs=[io.Model.Output()],
48+
)
49+
50+
@classmethod
51+
def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
2552
running_avg = 0
2653
prev_sigma = None
2754

@@ -65,12 +92,15 @@ def pre_cfg_function(args):
6592

6693
m = model.clone()
6794
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
68-
return (m,)
95+
return io.NodeOutput(m)
96+
6997

70-
NODE_CLASS_MAPPINGS = {
71-
"APG": APG,
72-
}
98+
class ApgExtension(ComfyExtension):
99+
@override
100+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
101+
return [
102+
APG,
103+
]
73104

74-
NODE_DISPLAY_NAME_MAPPINGS = {
75-
"APG": "Adaptive Projected Guidance",
76-
}
105+
async def comfy_entrypoint() -> ApgExtension:
106+
return ApgExtension()

0 commit comments

Comments
 (0)