Skip to content

Commit 4533c37

Browse files
committed
feat: Add timestep shift and two new schedulers
1 parent f6b9aa1 commit 4533c37

File tree

4 files changed

+322
-163
lines changed

4 files changed

+322
-163
lines changed

denoiser.hpp

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,25 @@ struct GITSSchedule : SigmaSchedule {
232232
}
233233
};
234234

235+
struct SGMUniformSchedule : SigmaSchedule {
236+
std::vector<float> get_sigmas(uint32_t n, float sigma_min_in, float sigma_max_in, t_to_sigma_t t_to_sigma_func) override {
237+
238+
std::vector<float> result;
239+
if (n == 0) {
240+
result.push_back(0.0f);
241+
return result;
242+
}
243+
result.reserve(n + 1);
244+
int t_max = TIMESTEPS -1;
245+
float step = static_cast<float>(t_max) / static_cast<float>(n > 1 ? (n -1) : 1) ;
246+
for(uint32_t i=0; i<n; ++i) {
247+
result.push_back(t_to_sigma_func(t_max - step * i));
248+
}
249+
result.push_back(0.0f);
250+
return result;
251+
}
252+
};
253+
235254
struct KarrasSchedule : SigmaSchedule {
236255
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
237256
// These *COULD* be function arguments here,
@@ -251,6 +270,36 @@ struct KarrasSchedule : SigmaSchedule {
251270
}
252271
};
253272

273+
struct SimpleSchedule : SigmaSchedule {
274+
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
275+
std::vector<float> result_sigmas;
276+
277+
if (n == 0) {
278+
return result_sigmas;
279+
}
280+
281+
result_sigmas.reserve(n + 1);
282+
283+
int model_sigmas_len = TIMESTEPS;
284+
285+
float step_factor = static_cast<float>(model_sigmas_len) / static_cast<float>(n);
286+
287+
for (uint32_t i = 0; i < n; ++i) {
288+
289+
int offset_from_start_of_py_array = static_cast<int>(static_cast<float>(i) * step_factor);
290+
int timestep_index = model_sigmas_len - 1 - offset_from_start_of_py_array;
291+
292+
if (timestep_index < 0) {
293+
timestep_index = 0;
294+
}
295+
296+
result_sigmas.push_back(t_to_sigma(static_cast<float>(timestep_index)));
297+
}
298+
result_sigmas.push_back(0.0f);
299+
return result_sigmas;
300+
}
301+
};
302+
254303
struct Denoiser {
255304
std::shared_ptr<SigmaSchedule> schedule = std::make_shared<DiscreteSchedule>();
256305
virtual float sigma_min() = 0;
@@ -262,8 +311,39 @@ struct Denoiser {
262311
virtual ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) = 0;
263312

264313
virtual std::vector<float> get_sigmas(uint32_t n) {
265-
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
266-
return schedule->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
314+
// Check if the current schedule is SGMUniformSchedule
315+
if (std::dynamic_pointer_cast<SGMUniformSchedule>(schedule)) {
316+
std::vector<float> sigs;
317+
sigs.reserve(n + 1);
318+
319+
if (n == 0) {
320+
sigs.push_back(0.0f);
321+
return sigs;
322+
}
323+
324+
// Use the Denoiser's own sigma_to_t and t_to_sigma methods
325+
float start_t_val = this->sigma_to_t(this->sigma_max());
326+
float end_t_val = this->sigma_to_t(this->sigma_min());
327+
328+
float dt_per_step;
329+
if (n > 0) {
330+
dt_per_step = (end_t_val - start_t_val) / static_cast<float>(n);
331+
} else {
332+
dt_per_step = 0.0f;
333+
}
334+
335+
for (uint32_t i = 0; i < n; ++i) {
336+
float current_t = start_t_val + static_cast<float>(i) * dt_per_step;
337+
sigs.push_back(this->t_to_sigma(current_t));
338+
}
339+
340+
sigs.push_back(0.0f);
341+
return sigs;
342+
343+
} else { // For all other schedules, use the existing virtual dispatch
344+
auto bound_t_to_sigma = std::bind(&Denoiser::t_to_sigma, this, std::placeholders::_1);
345+
return schedule->get_sigmas(n, sigma_min(), sigma_max(), bound_t_to_sigma);
346+
}
267347
}
268348
};
269349

examples/cli/main.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ struct SDParams {
105105
float slg_scale = 0.f;
106106
float skip_layer_start = 0.01f;
107107
float skip_layer_end = 0.2f;
108+
int shifted_timestep = -1;
108109

109110
bool chroma_use_dit_mask = true;
110111
bool chroma_use_t5_mask = false;
@@ -163,6 +164,7 @@ void print_params(SDParams params) {
163164
printf(" batch_count: %d\n", params.batch_count);
164165
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
165166
printf(" upscale_repeats: %d\n", params.upscale_repeats);
167+
printf(" timestep_shift: %d\n", params.shifted_timestep);
166168
printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false");
167169
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
168170
printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
@@ -223,7 +225,7 @@ void print_usage(int argc, const char* argv[]) {
223225
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
224226
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
225227
printf(" -b, --batch-count COUNT number of images to generate\n");
226-
printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n");
228+
printf(" --schedule {discrete, karras, exponential, ays, gits, sgm_uniform, simple} Denoiser sigma schedule (default: discrete)\n");
227229
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
228230
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
229231
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
@@ -235,6 +237,7 @@ void print_usage(int argc, const char* argv[]) {
235237
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
236238
printf(" --canny apply canny preprocessor (edge detection)\n");
237239
printf(" --color colors the logging tags according to level\n");
240+
printf(" --timestep-shift N shift timestep for NitroFusion models, default: -1 off, recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant\n");
238241
printf(" --chroma-disable-dit-mask disable dit mask for chroma\n");
239242
printf(" --chroma-enable-t5-mask enable t5 mask for chroma\n");
240243
printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n");
@@ -487,7 +490,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
487490
const char* arg = argv[index];
488491
params.schedule = str_to_schedule(arg);
489492
if (params.schedule == SCHEDULE_COUNT) {
490-
fprintf(stderr, "error: invalid schedule %s\n",
493+
fprintf(stderr, "error: invalid schedule %s, must be one of [discrete, karras, exponential, ays, gits, sgm_uniform, simple]\n",
491494
arg);
492495
return -1;
493496
}
@@ -568,7 +571,18 @@ void parse_args(int argc, const char** argv, SDParams& params) {
568571
{"-r", "--ref-image", "", on_ref_image_arg},
569572
{"-h", "--help", "", on_help_arg},
570573
};
571-
574+
// Handler for "--timestep-shift N": parses N and stores it in
// params.shifted_timestep. Accepts -1 (disabled) or a value in [1, 1000].
// Returns -1 on a missing or invalid value, 1 on success.
auto on_timestep_shift_arg = [&](int argc, const char** argv, int index) {
    if (++index >= argc) {
        return -1;
    }

    const char* raw_value = argv[index];
    int parsed            = -1;
    try {
        parsed = std::stoi(raw_value);
    } catch (...) {
        // std::stoi throws std::invalid_argument for non-numeric text and
        // std::out_of_range for values that do not fit in int; report the
        // problem instead of letting the exception terminate the program.
        fprintf(stderr, "error: invalid timestep-shift value '%s'\n", raw_value);
        return -1;
    }

    if (parsed != -1 && (parsed < 1 || parsed > 1000)) {
        fprintf(stderr, "error: timestep-shift must be between 1 and 1000, or -1 to disable\n");
        return -1;
    }

    // Only commit the value once it is known to be valid.
    params.shifted_timestep = parsed;
    return 1;
};
585+
options.manual_options.push_back({"", "--timestep-shift", "", on_timestep_shift_arg});
572586
if (!parse_options(argc, argv, options)) {
573587
print_usage(argc, argv);
574588
exit(1);
@@ -979,6 +993,7 @@ int main(int argc, const char* argv[]) {
979993
params.style_ratio,
980994
params.normalize_input,
981995
params.input_id_images_path.c_str(),
996+
params.shifted_timestep,
982997
};
983998

984999
results = generate_image(sd_ctx, &img_gen_params);

0 commit comments

Comments
 (0)