Skip to content

Commit

Permalink
Final fixes (#118)
Browse files Browse the repository at this point in the history
final fixes before release
  • Loading branch information
patrickvonplaten committed Jul 21, 2022
1 parent 3b7f514 commit 5311f56
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
13 changes: 7 additions & 6 deletions src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@


class ScoreSdeVePipeline(DiffusionPipeline):
def __init__(self, model, scheduler):
def __init__(self, unet, scheduler):
super().__init__()
self.register_modules(model=model, scheduler=scheduler)
self.register_modules(unet=unet, scheduler=scheduler)

@torch.no_grad()
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):

if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

img_size = self.model.config.sample_size
img_size = self.unet.config.sample_size
shape = (batch_size, 3, img_size, img_size)

model = self.model.to(torch_device)
model = self.unet.to(torch_device)

sample = torch.randn(*shape) * self.scheduler.config.sigma_max
sample = sample.to(torch_device)
Expand All @@ -31,7 +32,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch

# correction step
for _ in range(self.scheduler.correct_steps):
model_output = self.model(sample, sigma_t)["sample"]
model_output = self.unet(sample, sigma_t)["sample"]
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]

# prediction step
Expand All @@ -40,7 +41,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch

sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]

sample = sample.clamp(0, 1)
sample = sample_mean.clamp(0, 1)
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
sample = self.numpy_to_pil(sample)
Expand Down
11 changes: 4 additions & 7 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,15 +848,12 @@ def test_ldm_text2img_fast(self):

@slow
def test_score_sde_ve_pipeline(self):
model = UNet2DModel.from_pretrained("google/ncsnpp-church-256")
model_id = "google/ncsnpp-church-256"
model = UNet2DModel.from_pretrained(model_id)

torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)

scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
scheduler = ScoreSdeVeScheduler.from_config(model_id)

sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)

torch.manual_seed(0)
image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
Expand Down

0 comments on commit 5311f56

Please sign in to comment.