Skip to content

Commit

Permalink
make dask client fixture in pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Aug 6, 2024
1 parent 0b99fd2 commit 0a33797
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
14 changes: 11 additions & 3 deletions tests/test_rf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@
from botocore.config import Config
from datacube.utils.dask import start_local_dask

client = start_local_dask(n_workers=1, threads_per_worker=2)

project_root = Path(__file__).parents[1]
data_dir = f"{project_root}/tests/data/"


@pytest.fixture(scope="module")
def dask_client():
client = start_local_dask(n_workers=1, threads_per_worker=2)
yield client
client.close()


@pytest.fixture(scope="module")
def cultivated_model_path():
s3_bucket = "dea-public-data-dev"
Expand Down Expand Up @@ -399,14 +405,15 @@ def test_cultivated_predict(
mask_bands,
cultivated_classes,
input_datasets,
dask_client,
):
cultivated = StatsCultivatedClass(
cultivated_classes,
cultivated_model_path,
mask_bands,
input_bands=cultivated_input_bands,
)
client.register_plugin(cultivated.dask_worker_plugin)
dask_client.register_plugin(cultivated.dask_worker_plugin)
imgs = cultivated.preprocess_predict_input(input_datasets)
res = [cultivated.predict(img).compute() for img in imgs]
assert (
Expand Down Expand Up @@ -440,14 +447,15 @@ def test_cultivated_reduce(
mask_bands,
cultivated_classes,
input_datasets,
dask_client,
):
cultivated = StatsCultivatedClass(
cultivated_classes,
cultivated_model_path,
mask_bands,
input_bands=cultivated_input_bands,
)
client.register_plugin(cultivated.dask_worker_plugin)
dask_client.register_plugin(cultivated.dask_worker_plugin)
res = cultivated.reduce(input_datasets)
assert res["cultivated_class"].attrs["nodata"] == 255
assert (
Expand Down
11 changes: 9 additions & 2 deletions tests/test_urban_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
data_dir = f"{project_root}/tests/data/"


@pytest.fixture(scope="module")
def dask_client():
client = start_local_dask(n_workers=1, threads_per_worker=2)
yield client
client.close()


@pytest.fixture(scope="module")
def tflite_model_path():
s3_bucket = "dea-public-data-dev"
Expand Down Expand Up @@ -137,11 +144,11 @@ def test_impute_missing_values(output_classes, tflite_model_path, image_groups):
assert (res[1][:3, :3, :] == image_groups["ga_ls8"][:3, :3, :]).all()


def test_urban_class(output_classes, tflite_model_path, image_groups):
def test_urban_class(output_classes, tflite_model_path, image_groups, dask_client):
# test better than random for a prediction
# check correctness in integration test
stats_urban = StatsUrbanClass(output_classes, tflite_model_path)
client.register_plugin(stats_urban.dask_worker_plugin)
dask_client.register_plugin(stats_urban.dask_worker_plugin)
input_img = stats_urban.impute_missing_values_from_group(image_groups)
input_img[0][1, 1, :] = np.nan
input_img[1][1, 1, :] = np.nan
Expand Down

0 comments on commit 0a33797

Please sign in to comment.