File tree Expand file tree Collapse file tree 2 files changed +9
-8
lines changed Expand file tree Collapse file tree 2 files changed +9
-8
lines changed Original file line number Diff line number Diff line change @@ -660,11 +660,10 @@ def __call__(
660660 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
661661 timestep = t .expand (latents .shape [0 ]).to (latents .dtype )
662662
663- print (t )
664663 noise_pred = self .transformer (
665664 hidden_states = latent_model_input ,
666665 encoder_hidden_states = prompt_embeds ,
667- timestep = 1000 - timestep ,
666+ timestep = timestep ,
668667 encoder_attention_mask = prompt_attention_mask ,
669668 return_dict = False ,
670669 )[0 ]
Original file line number Diff line number Diff line change @@ -205,9 +205,15 @@ def set_timesteps(
205205 sigmas = torch .from_numpy (sigmas ).to (dtype = torch .float32 , device = device )
206206 timesteps = sigmas * self .config .num_train_timesteps
207207
208- self .timesteps = timesteps .to (device = device )
209- self .sigmas = torch .cat ([sigmas , torch .zeros (1 , device = sigmas .device )])
208+ if self .config .invert_sigmas :
209+ sigmas = 1.0 - sigmas
210+ timesteps = sigmas * self .config .num_train_timesteps
211+ sigmas = torch .cat ([sigmas , torch .ones (1 , device = sigmas .device )])
212+ else :
213+ sigmas = torch .cat ([sigmas , torch .zeros (1 , device = sigmas .device )])
210214
215+ self .timesteps = timesteps .to (device = device )
216+ self .sigmas = sigmas
211217 self ._step_index = None
212218 self ._begin_index = None
213219
@@ -295,10 +301,6 @@ def step(
295301 sigma = self .sigmas [self .step_index ]
296302 sigma_next = self .sigmas [self .step_index + 1 ]
297303
298- if self .config .invert_sigmas :
299- print ("inverting" )
300- sigma , sigma_next = sigma_next , sigma
301-
302304 prev_sample = sample + (sigma_next - sigma ) * model_output
303305
304306 # Cast sample back to model compatible dtype
You can’t perform that action at this time.
0 commit comments