diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 0ce9626bd364..939ac9ec27bb 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -6,19 +6,20 @@ class ScoreSdeVePipeline(DiffusionPipeline): - def __init__(self, model, scheduler): + def __init__(self, unet, scheduler): super().__init__() - self.register_modules(model=model, scheduler=scheduler) + self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"): + if torch_device is None: torch_device = "cuda" if torch.cuda.is_available() else "cpu" - img_size = self.model.config.sample_size + img_size = self.unet.config.sample_size shape = (batch_size, 3, img_size, img_size) - model = self.model.to(torch_device) + model = self.unet.to(torch_device) sample = torch.randn(*shape) * self.scheduler.config.sigma_max sample = sample.to(torch_device) @@ -31,7 +32,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch # correction step for _ in range(self.scheduler.correct_steps): - model_output = self.model(sample, sigma_t)["sample"] + model_output = self.unet(sample, sigma_t)["sample"] sample = self.scheduler.step_correct(model_output, sample)["prev_sample"] # prediction step @@ -40,7 +41,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch sample, sample_mean = output["prev_sample"], output["prev_sample_mean"] - sample = sample.clamp(0, 1) + sample = sample_mean.clamp(0, 1) sample = sample.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": sample = self.numpy_to_pil(sample) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 693d8932040c..ce4f9958d504 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -848,15 +848,12 @@ def test_ldm_text2img_fast(self): @slow def test_score_sde_ve_pipeline(self): - model = UNet2DModel.from_pretrained("google/ncsnpp-church-256") + model_id = "google/ncsnpp-church-256" + model = UNet2DModel.from_pretrained(model_id) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256") + scheduler = ScoreSdeVeScheduler.from_config(model_id) - sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler) + sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler) torch.manual_seed(0) image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]