From 0a33797a32ef96984a062cf756a265245cda3e11 Mon Sep 17 00:00:00 2001 From: Emma Ai Date: Tue, 6 Aug 2024 10:41:33 +0000 Subject: [PATCH] make dask client fixture in pytest --- tests/test_rf_models.py | 14 +++++++++++--- tests/test_urban_model.py | 11 +++++++++-- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/test_rf_models.py b/tests/test_rf_models.py index 259fd41..e10f65b 100644 --- a/tests/test_rf_models.py +++ b/tests/test_rf_models.py @@ -12,12 +12,18 @@ from botocore.config import Config from datacube.utils.dask import start_local_dask -client = start_local_dask(n_workers=1, threads_per_worker=2) project_root = Path(__file__).parents[1] data_dir = f"{project_root}/tests/data/" +@pytest.fixture(scope="module") +def dask_client(): + client = start_local_dask(n_workers=1, threads_per_worker=2) + yield client + client.close() + + @pytest.fixture(scope="module") def cultivated_model_path(): s3_bucket = "dea-public-data-dev" @@ -399,6 +405,7 @@ def test_cultivated_predict( mask_bands, cultivated_classes, input_datasets, + dask_client, ): cultivated = StatsCultivatedClass( cultivated_classes, @@ -406,7 +413,7 @@ def test_cultivated_predict( mask_bands, input_bands=cultivated_input_bands, ) - client.register_plugin(cultivated.dask_worker_plugin) + dask_client.register_plugin(cultivated.dask_worker_plugin) imgs = cultivated.preprocess_predict_input(input_datasets) res = [cultivated.predict(img).compute() for img in imgs] assert ( @@ -440,6 +447,7 @@ def test_cultivated_reduce( mask_bands, cultivated_classes, input_datasets, + dask_client, ): cultivated = StatsCultivatedClass( cultivated_classes, @@ -447,7 +455,7 @@ def test_cultivated_reduce( mask_bands, input_bands=cultivated_input_bands, ) - client.register_plugin(cultivated.dask_worker_plugin) + dask_client.register_plugin(cultivated.dask_worker_plugin) res = cultivated.reduce(input_datasets) assert res["cultivated_class"].attrs["nodata"] == 255 assert ( diff --git a/tests/test_urban_model.py b/tests/test_urban_model.py index 496ec74..c81f28c 100644 --- a/tests/test_urban_model.py +++ b/tests/test_urban_model.py @@ -15,6 +15,13 @@ data_dir = f"{project_root}/tests/data/" +@pytest.fixture(scope="module") +def dask_client(): + client = start_local_dask(n_workers=1, threads_per_worker=2) + yield client + client.close() + + @pytest.fixture(scope="module") def tflite_model_path(): s3_bucket = "dea-public-data-dev" @@ -137,11 +144,11 @@ def test_impute_missing_values(output_classes, tflite_model_path, image_groups): assert (res[1][:3, :3, :] == image_groups["ga_ls8"][:3, :3, :]).all() -def test_urban_class(output_classes, tflite_model_path, image_groups): +def test_urban_class(output_classes, tflite_model_path, image_groups, dask_client): # test better than random for a prediction # check correctness in integration test stats_urban = StatsUrbanClass(output_classes, tflite_model_path) - client.register_plugin(stats_urban.dask_worker_plugin) + dask_client.register_plugin(stats_urban.dask_worker_plugin) input_img = stats_urban.impute_missing_values_from_group(image_groups) input_img[0][1, 1, :] = np.nan input_img[1][1, 1, :] = np.nan