Woody cover and cultivated model plugins for landcover (#146)
* add plugin for woody cover

* clean up loading and aggregate results conservatively

* fix memsink name of urban model

* update odc-algo head

* add cultivated model plugin for landcover

* add test for cultivated model

* update dependencies

* change docker-compose to docker compose

* make dask client fixture in pytest

* add test for woody cover

* add treelite as dependency

* revise the docstring

---------

Co-authored-by: Emma Ai <[email protected]>
emmaai and Emma Ai committed Aug 8, 2024
1 parent f482ce7 commit db360b0
Showing 15 changed files with 1,187 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/statistician-image.yml
@@ -44,7 +44,7 @@ jobs:
       shell: bash
       run: |
         cd docker
-        docker-compose build
+        docker compose build
     - name: Run Dockerized Tests for Statistician
       shell: bash
1 change: 0 additions & 1 deletion docker/env.yaml
@@ -16,7 +16,6 @@ dependencies:
   - vim
   - rio-cogeo
   - aiobotocore
-  - awscliv2
   - boto3
   - affine
   - aiohttp
3 changes: 2 additions & 1 deletion docker/requirements.txt
@@ -2,7 +2,7 @@
 --extra-index-url https://packages.dea.ga.gov.au/
 datacube[performance,s3]>=1.8.17
 hdstats==0.1.8.post1
-odc-algo @ git+https://github.com/opendatacube/odc-algo@bb662fe
+odc-algo @ git+https://github.com/opendatacube/odc-algo@adb1856
 odc-apps-cloud>=0.2.2
 # For testing
 odc-apps-dc-tools>=0.2.12
@@ -15,3 +15,4 @@ odc-stats[ows]
 
 # For ML
 tflite-runtime
+tl2cgen
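
tl2cgen is the runtime that executes treelite decision-forest models compiled to shared libraries; it is what the new woody-cover and cultivated plugins call at predict time. A minimal sketch of the calling pattern, using only the API calls that appear in this commit (the model path and input shape are hypothetical):

    import numpy as np
    import tl2cgen

    # Load a treelite model compiled ahead of time into a shared library
    # (path is illustrative, not one shipped with odc-stats)
    predictor = tl2cgen.Predictor("./cultivated_model.so")

    # Score a (n_samples, n_features) float32 array
    features = np.random.rand(16, 8).astype("float32")
    scores = predictor.predict(tl2cgen.DMatrix(features))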
8 changes: 4 additions & 4 deletions docker/test_statistician.sh
@@ -1,11 +1,11 @@
 #!/usr/bin/env bash
 set -ex
 
-docker-compose up -d
+docker compose up -d
 sleep 5
 
-docker-compose exec -T stats odc-stats --version
+docker compose exec -T stats odc-stats --version
 echo "Indexing some data"
-docker-compose exec -e AWS_DEFAULT_REGION=ap-southeast-2 -T stats ./tests/init_db.sh
+docker compose exec -e AWS_DEFAULT_REGION=ap-southeast-2 -T stats ./tests/init_db.sh
 echo "Data regression test"
-docker-compose exec -T stats ./tests/integration_test.sh
+docker compose exec -T stats ./tests/integration_test.sh
11 changes: 8 additions & 3 deletions odc/stats/model.py
@@ -440,10 +440,15 @@ def render_assembler_metadata(
             inherit_skip_properties=self.product.inherit_skip_properties,
         )
 
-        if "eo:platform" in source_datasetdoc.properties:
+        if source_datasetdoc.properties.get("eo:platform") is not None:
             platforms.append(source_datasetdoc.properties["eo:platform"])
-        if "eo:instrument" in source_datasetdoc.properties:
-            instruments.append(source_datasetdoc.properties["eo:instrument"])
+        if source_datasetdoc.properties.get("eo:instrument") is not None:
+            if isinstance(source_datasetdoc.properties["eo:instrument"], list):
+                instruments += source_datasetdoc.properties["eo:instrument"]
+            else:
+                instruments.append(
+                    source_datasetdoc.properties["eo:instrument"]
+                )
 
         dataset_assembler.platform = ",".join(sorted(set(platforms)))
         dataset_assembler.instrument = "_".join(sorted(set(instruments)))
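
The switch matters because `key in properties` is true even when the stored value is None, while `.get(key) is not None` also filters those out, and the instrument branch now tolerates list-valued `eo:instrument`. A plain-dict illustration of both behaviours (standard Python semantics, nothing odc-specific):

    props = {"eo:platform": None, "eo:instrument": ["TM", "ETM"]}

    "eo:platform" in props                 # True  -> old check would append None
    props.get("eo:platform") is not None   # False -> new check skips it

    instruments = []
    value = props["eo:instrument"]
    if isinstance(value, list):
        instruments += value               # extend rather than nest a list
    else:
        instruments.append(value)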
2 changes: 2 additions & 0 deletions odc/stats/plugins/_registry.py
@@ -39,6 +39,8 @@ def import_all():
 
     # TODO: make that more automatic
     modules = [
+        "odc.stats.plugins.lc_treelite_cultivated",
+        "odc.stats.plugins.lc_treelite_woody",
         "odc.stats.plugins.lc_tf_urban",
         "odc.stats.plugins.lc_veg_class_a1",
         "odc.stats.plugins.lc_fc_wo_a0",
25 changes: 25 additions & 0 deletions odc/stats/plugins/_worker.py
@@ -6,6 +6,7 @@
 
 from dask.distributed import WorkerPlugin, get_worker
 import tflite_runtime.interpreter as tflite
+import tl2cgen
 import threading
 import logging
 
@@ -33,3 +34,27 @@ def get_interpreter(self):
                 thread_id,
             )
         return worker.interpreters[thread_id]
+
+
+class TreeliteModelPlugin(WorkerPlugin):
+    def __init__(self, model_path):
+        self.model_path = model_path
+        self._log = logging.getLogger(__name__)
+
+    def setup(self, worker):
+        worker.plugin_instance = self
+        worker.predictors = {}
+        print(f"registered worker {worker}")
+
+    def get_predictor(self):
+        worker = get_worker()
+        thread_id = threading.get_ident()
+        if thread_id not in worker.predictors:
+            predictor = tl2cgen.Predictor(self.model_path)
+            worker.predictors[thread_id] = predictor
+            self._log.info(
+                "Predictor created on worker %s for thread %s",
+                worker.address,
+                thread_id,
+            )
+        return worker.predictors[thread_id]
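
TreeliteModelPlugin mirrors the existing tflite plugin: registered as a dask WorkerPlugin, it attaches itself to every worker at setup, and get_predictor lazily creates and caches one tl2cgen.Predictor per thread so concurrent tasks never share a predictor. A minimal registration sketch (the local cluster and model path here are illustrative assumptions):

    from dask.distributed import Client

    client = Client(processes=False)  # in-process cluster, for demonstration
    client.register_worker_plugin(TreeliteModelPlugin("/tmp/model.so"))

    # inside a task running on a worker, each thread gets its cached predictor:
    #     predictor = get_worker().plugin_instance.get_predictor()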
171 changes: 171 additions & 0 deletions odc/stats/plugins/lc_ml_treelite.py
@@ -0,0 +1,171 @@
+"""
+Base class for treelite models in the Land Cover pipeline
+"""
+
+from abc import abstractmethod
+from typing import Dict, Sequence, Optional
+
+import os
+import numpy as np
+import numexpr as ne
+import xarray as xr
+import dask.array as da
+from dask.distributed import get_worker
+
+from datacube.model import Dataset
+from datacube.utils.geometry import GeoBox
+from odc.algo._memsink import yxbt_sink, yxt_sink
+from odc.algo.io import load_with_native_transform
+
+from odc.stats._algebra import expr_eval
+from ._registry import StatsPluginInterface
+from ._worker import TreeliteModelPlugin
+import tl2cgen
+
+
+def mask_and_predict(
+    block, block_info=None, ptype="categorical", nodata=np.nan, output_dtype="float32"
+):
+    worker = get_worker()
+    plugin_instance = worker.plugin_instance
+    predictor = plugin_instance.get_predictor()
+
+    block_flat = block.reshape(-1, block.shape[-1])
+    # mask nodata and non-veg
+    mask_flat = ne.evaluate(
+        "where((a==a)&(b>0), 1, 0)",
+        local_dict={"a": block_flat[:, 0], "b": block_flat[:, -1]},
+    ).astype("bool")
+    block_masked = block_flat[mask_flat, :-1]
+
+    prediction = np.full(
+        (block.shape[0] * block.shape[1], 1), nodata, dtype=output_dtype
+    )
+    if block_masked.shape[0] > 0:
+        dmat = tl2cgen.DMatrix(block_masked)
+        output_data = predictor.predict(dmat).squeeze(axis=1)
+        if ptype == "categorical":
+            prediction[mask_flat] = output_data.argmax(axis=-1)[..., np.newaxis]
+        else:
+            prediction[mask_flat] = output_data
+    return prediction.reshape(*block.shape[:-1])
+
+
+class StatsMLTree(StatsPluginInterface):
+    NAME = "ga_ls_ml_tree"
+    SHORT_NAME = NAME
+    VERSION = "0.0.1"
+    PRODUCT_FAMILY = "lccs"
+
+    def __init__(
+        self,
+        output_classes: Dict,
+        model_path: str,
+        mask_bands: Optional[Dict] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if not os.path.exists(model_path):
+            raise FileNotFoundError(f"{model_path} not found")
+        self.dask_worker_plugin = TreeliteModelPlugin(model_path)
+        self.output_classes = output_classes
+        self.mask_bands = mask_bands
+
+    def input_data(
+        self, datasets: Sequence[Dataset], geobox: GeoBox, **kwargs
+    ) -> xr.Dataset:
+        # load data in the same time and location but different sensors
+        data_vars = {}
+
+        for ds in datasets:
+            if "gm" in ds.type.name:
+                input_bands = self.input_bands[:-1]
+            else:
+                input_bands = self.input_bands[-1:]
+
+            xx = load_with_native_transform(
+                [ds],
+                bands=input_bands,
+                geobox=geobox,
+                native_transform=self.native_transform,
+                basis=self.basis,
+                groupby=None,
+                fuser=None,
+                resampling=self.resampling,
+                chunks={"y": -1, "x": -1},
+                optional_bands=self.optional_bands,
+                **kwargs,
+            )
+            if "gm" in ds.type.name:
+                input_array = yxbt_sink(
+                    xx,
+                    (self.chunks["x"], self.chunks["y"], -1, -1),
+                    dtype="float32",
+                    name=ds.type.name + "_yxbt",
+                ).squeeze("spec", drop=True)
+                data_vars[ds.type.name] = input_array
+            else:
+                for var in xx.data_vars:
+                    data_vars[var] = yxt_sink(
+                        xx[var].astype("uint8"),
+                        (self.chunks["x"], self.chunks["y"], -1),
+                        name=ds.type.name + "_yxt",
+                    ).squeeze("spec", drop=True)
+
+        coords = dict((dim, input_array.coords[dim]) for dim in input_array.dims)
+        return xr.Dataset(data_vars=data_vars, coords=coords)
+
+    def preprocess_predict_input(self, xx: xr.Dataset):
+        images = []
+        for var in xx.data_vars:
+            image = xx[var].data
+            if var not in self.mask_bands:
+                nodata = xx[var].attrs.get("nodata", -999)
+                image = expr_eval(
+                    "where((a<=nodata), _nan, a)",
+                    {
+                        "a": image,
+                    },
+                    name="convert_dtype",
+                    dtype="float32",
+                    **{"nodata": nodata, "_nan": np.nan},
+                )
+                images += [image]
+            else:
+                veg_mask = expr_eval(
+                    "where(a==_v, 1, 0)",
+                    {"a": image},
+                    name="make_mask",
+                    dtype="float32",
+                    **{"_v": int(self.mask_bands[var])},
+                )
+
+        images = [
+            da.concatenate([image, veg_mask[..., np.newaxis]], axis=-1).rechunk(
+                (None, None, image.shape[-1] + veg_mask.shape[-1])
+            )
+            for image in images
+        ]
+        return images
+
+    @abstractmethod
+    def predict(self, input_array):
+        pass
+
+    @abstractmethod
+    def aggregate_results_from_group(self, predict_output):
+        pass
+
+    def reduce(self, xx: xr.Dataset) -> xr.Dataset:
+        images = self.preprocess_predict_input(xx)
+        res = []
+
+        for image in images:
+            res += [self.predict(image)]
+
+        res = self.aggregate_results_from_group(res)
+        attrs = xx.attrs.copy()
+        dims = list(xx.dims.keys())[:2]
+        data_vars = {"predict_output": xr.DataArray(res, dims=dims, attrs=attrs)}
+        coords = {dim: xx.coords[dim] for dim in dims}
+        return xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs)
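
StatsMLTree deliberately leaves predict and aggregate_results_from_group abstract; the new woody and cultivated plugins fill them in. The sketch below shows one plausible wiring of mask_and_predict through da.map_blocks plus a conservative agreement-based merge; the class name, chunk handling, and aggregation rule are illustrative assumptions, not the shipped implementations:

    class StatsTreeSketch(StatsMLTree):
        NAME = "ga_ls_tree_sketch"

        def predict(self, input_array):
            # lazily run the per-block treelite prediction; axis 2 holds the
            # stacked bands plus the veg mask and is consumed by the model
            return da.map_blocks(
                mask_and_predict,
                input_array,
                ptype="categorical",
                drop_axis=2,
                dtype="float32",
            )

        def aggregate_results_from_group(self, predict_output):
            # conservative merge across sensors: keep a class only where all
            # group members agree; NaN (nodata) never agrees with itself
            if len(predict_output) == 1:
                return predict_output[0]
            stacked = da.stack(predict_output)
            agreed = da.all(stacked == stacked[0], axis=0)
            return da.where(agreed, stacked[0], np.nan)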
4 changes: 3 additions & 1 deletion odc/stats/plugins/lc_tf_urban.py
@@ -117,7 +117,9 @@ def input_data(
                 **kwargs,
             )
             input_array = yxbt_sink(
-                xx, (self.crop_size[0], self.crop_size[0], -1, -1)
+                xx,
+                (self.crop_size[0], self.crop_size[0], -1, -1),
+                name=ds.type.name + "_yxbt",
             ).squeeze("spec", drop=True)
             data_vars[ds.type.name] = input_array
 