Skip to content

Commit bab08f4

Browse files
authored
v3 nodes (part a) (#9149)
1 parent bc49106 commit bab08f4

File tree

4 files changed

+247
-163
lines changed

4 files changed

+247
-163
lines changed

comfy_extras/nodes_ace.py

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,63 @@
11
import torch
2+
from typing_extensions import override
3+
24
import comfy.model_management
35
import node_helpers
6+
from comfy_api.latest import ComfyExtension, io
7+
8+
9+
class TextEncodeAceStepAudio(io.ComfyNode):
10+
@classmethod
11+
def define_schema(cls):
12+
return io.Schema(
13+
node_id="TextEncodeAceStepAudio",
14+
category="conditioning",
15+
inputs=[
16+
io.Clip.Input("clip"),
17+
io.String.Input("tags", multiline=True, dynamic_prompts=True),
18+
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
19+
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
20+
],
21+
outputs=[io.Conditioning.Output()],
22+
)
423

5-
class TextEncodeAceStepAudio:
624
@classmethod
7-
def INPUT_TYPES(s):
8-
return {"required": {
9-
"clip": ("CLIP", ),
10-
"tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
11-
"lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
12-
"lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
13-
}}
14-
RETURN_TYPES = ("CONDITIONING",)
15-
FUNCTION = "encode"
16-
17-
CATEGORY = "conditioning"
18-
19-
def encode(self, clip, tags, lyrics, lyrics_strength):
25+
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput:
2026
tokens = clip.tokenize(tags, lyrics=lyrics)
2127
conditioning = clip.encode_from_tokens_scheduled(tokens)
2228
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
23-
return (conditioning, )
29+
return io.NodeOutput(conditioning)
2430

2531

26-
class EmptyAceStepLatentAudio:
27-
def __init__(self):
28-
self.device = comfy.model_management.intermediate_device()
29-
32+
class EmptyAceStepLatentAudio(io.ComfyNode):
3033
@classmethod
31-
def INPUT_TYPES(s):
32-
return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
33-
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
34-
}}
35-
RETURN_TYPES = ("LATENT",)
36-
FUNCTION = "generate"
37-
38-
CATEGORY = "latent/audio"
34+
def define_schema(cls):
35+
return io.Schema(
36+
node_id="EmptyAceStepLatentAudio",
37+
category="latent/audio",
38+
inputs=[
39+
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
40+
io.Int.Input(
41+
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
42+
),
43+
],
44+
outputs=[io.Latent.Output()],
45+
)
3946

40-
def generate(self, seconds, batch_size):
47+
@classmethod
48+
def execute(cls, seconds, batch_size) -> io.NodeOutput:
4149
length = int(seconds * 44100 / 512 / 8)
42-
latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
43-
return ({"samples": latent, "type": "audio"}, )
50+
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device())
51+
return io.NodeOutput({"samples": latent, "type": "audio"})
52+
4453

54+
class AceExtension(ComfyExtension):
55+
@override
56+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
57+
return [
58+
TextEncodeAceStepAudio,
59+
EmptyAceStepLatentAudio,
60+
]
4561

46-
NODE_CLASS_MAPPINGS = {
47-
"TextEncodeAceStepAudio": TextEncodeAceStepAudio,
48-
"EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
49-
}
62+
async def comfy_entrypoint() -> AceExtension:
63+
return AceExtension()

comfy_extras/nodes_advanced_samplers.py

Lines changed: 50 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
import comfy.samplers
2-
import comfy.utils
3-
import torch
41
import numpy as np
2+
import torch
53
from tqdm.auto import trange
4+
from typing_extensions import override
5+
6+
import comfy.model_patcher
7+
import comfy.samplers
8+
import comfy.utils
9+
from comfy.k_diffusion.sampling import to_d
10+
from comfy_api.latest import ComfyExtension, io
611

712

813
@torch.no_grad()
@@ -33,30 +38,29 @@ def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable
3338
return x
3439

3540

36-
class SamplerLCMUpscale:
37-
upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
41+
class SamplerLCMUpscale(io.ComfyNode):
42+
UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
43+
44+
@classmethod
45+
def define_schema(cls) -> io.Schema:
46+
return io.Schema(
47+
node_id="SamplerLCMUpscale",
48+
category="sampling/custom_sampling/samplers",
49+
inputs=[
50+
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01),
51+
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1),
52+
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
53+
],
54+
outputs=[io.Sampler.Output()],
55+
)
3856

3957
@classmethod
40-
def INPUT_TYPES(s):
41-
return {"required":
42-
{"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}),
43-
"scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}),
44-
"upscale_method": (s.upscale_methods,),
45-
}
46-
}
47-
RETURN_TYPES = ("SAMPLER",)
48-
CATEGORY = "sampling/custom_sampling/samplers"
49-
50-
FUNCTION = "get_sampler"
51-
52-
def get_sampler(self, scale_ratio, scale_steps, upscale_method):
58+
def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput:
5359
if scale_steps < 0:
5460
scale_steps = None
5561
sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
56-
return (sampler, )
62+
return io.NodeOutput(sampler)
5763

58-
from comfy.k_diffusion.sampling import to_d
59-
import comfy.model_patcher
6064

6165
@torch.no_grad()
6266
def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
@@ -82,30 +86,36 @@ def post_cfg_function(args):
8286
return x
8387

8488

85-
class SamplerEulerCFGpp:
89+
class SamplerEulerCFGpp(io.ComfyNode):
8690
@classmethod
87-
def INPUT_TYPES(s):
88-
return {"required":
89-
{"version": (["regular", "alternative"],),}
90-
}
91-
RETURN_TYPES = ("SAMPLER",)
92-
# CATEGORY = "sampling/custom_sampling/samplers"
93-
CATEGORY = "_for_testing"
94-
95-
FUNCTION = "get_sampler"
91+
def define_schema(cls) -> io.Schema:
92+
return io.Schema(
93+
node_id="SamplerEulerCFGpp",
94+
display_name="SamplerEulerCFG++",
95+
category="_for_testing", # "sampling/custom_sampling/samplers"
96+
inputs=[
97+
io.Combo.Input("version", options=["regular", "alternative"]),
98+
],
99+
outputs=[io.Sampler.Output()],
100+
is_experimental=True,
101+
)
96102

97-
def get_sampler(self, version):
103+
@classmethod
104+
def execute(cls, version) -> io.NodeOutput:
98105
if version == "alternative":
99106
sampler = comfy.samplers.KSAMPLER(sample_euler_pp)
100107
else:
101108
sampler = comfy.samplers.ksampler("euler_cfg_pp")
102-
return (sampler, )
109+
return io.NodeOutput(sampler)
110+
103111

104-
NODE_CLASS_MAPPINGS = {
105-
"SamplerLCMUpscale": SamplerLCMUpscale,
106-
"SamplerEulerCFGpp": SamplerEulerCFGpp,
107-
}
112+
class AdvancedSamplersExtension(ComfyExtension):
113+
@override
114+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
115+
return [
116+
SamplerLCMUpscale,
117+
SamplerEulerCFGpp,
118+
]
108119

109-
NODE_DISPLAY_NAME_MAPPINGS = {
110-
"SamplerEulerCFGpp": "SamplerEulerCFG++",
111-
}
120+
async def comfy_entrypoint() -> AdvancedSamplersExtension:
121+
return AdvancedSamplersExtension()

comfy_extras/nodes_apg.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,54 @@
11
import torch
2+
from typing_extensions import override
3+
4+
from comfy_api.latest import ComfyExtension, io
5+
26

37
def project(v0, v1):
48
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
59
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
610
v0_orthogonal = v0 - v0_parallel
711
return v0_parallel, v0_orthogonal
812

9-
class APG:
13+
class APG(io.ComfyNode):
1014
@classmethod
11-
def INPUT_TYPES(s):
12-
return {
13-
"required": {
14-
"model": ("MODEL",),
15-
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
16-
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
17-
"momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
18-
}
19-
}
20-
RETURN_TYPES = ("MODEL",)
21-
FUNCTION = "patch"
22-
CATEGORY = "sampling/custom_sampling"
23-
24-
def patch(self, model, eta, norm_threshold, momentum):
15+
def define_schema(cls) -> io.Schema:
16+
return io.Schema(
17+
node_id="APG",
18+
display_name="Adaptive Projected Guidance",
19+
category="sampling/custom_sampling",
20+
inputs=[
21+
io.Model.Input("model"),
22+
io.Float.Input(
23+
"eta",
24+
default=1.0,
25+
min=-10.0,
26+
max=10.0,
27+
step=0.01,
28+
tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
29+
),
30+
io.Float.Input(
31+
"norm_threshold",
32+
default=5.0,
33+
min=0.0,
34+
max=50.0,
35+
step=0.1,
36+
tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
37+
),
38+
io.Float.Input(
39+
"momentum",
40+
default=0.0,
41+
min=-5.0,
42+
max=1.0,
43+
step=0.01,
44+
tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.",
45+
),
46+
],
47+
outputs=[io.Model.Output()],
48+
)
49+
50+
@classmethod
51+
def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
2552
running_avg = 0
2653
prev_sigma = None
2754

@@ -65,12 +92,15 @@ def pre_cfg_function(args):
6592

6693
m = model.clone()
6794
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
68-
return (m,)
95+
return io.NodeOutput(m)
96+
6997

70-
NODE_CLASS_MAPPINGS = {
71-
"APG": APG,
72-
}
98+
class ApgExtension(ComfyExtension):
99+
@override
100+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
101+
return [
102+
APG,
103+
]
73104

74-
NODE_DISPLAY_NAME_MAPPINGS = {
75-
"APG": "Adaptive Projected Guidance",
76-
}
105+
async def comfy_entrypoint() -> ApgExtension:
106+
return ApgExtension()

0 commit comments

Comments
 (0)