Skip to content

Commit

Permalink
feat: support empty arrays, improve ibis.array() API
Browse files Browse the repository at this point in the history
Picking out the array stuff from #8666
  • Loading branch information
NickCrews committed Jun 29, 2024
1 parent 33ec754 commit a8e6dd8
Show file tree
Hide file tree
Showing 12 changed files with 192 additions and 43 deletions.
4 changes: 2 additions & 2 deletions ibis/backends/dask/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def concat(cls, dfs, **kwargs):

@classmethod
def asseries(cls, value, like=None):
"""Ensure that value is a pandas Series object, broadcast if necessary."""
"""Ensure that value is a dask Series object, broadcast if necessary."""

if isinstance(value, dd.Series):
return value
Expand All @@ -50,7 +50,7 @@ def asseries(cls, value, like=None):
elif isinstance(value, pd.Series):
return dd.from_pandas(value, npartitions=1)
elif like is not None:
if isinstance(value, (tuple, list, dict)):
if isinstance(value, (tuple, list, dict, np.ndarray)):
fn = lambda df: pd.Series([value] * len(df), index=df.index)
else:
fn = lambda df: pd.Series(value, index=df.index)
Expand Down
9 changes: 8 additions & 1 deletion ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def visit(cls, op: ops.Node, **kwargs):

@classmethod
def visit(cls, op: ops.Literal, value, dtype):
if dtype.is_interval():
if value is None:
value = None
elif dtype.is_interval():
value = pd.Timedelta(value, dtype.unit.short)
elif dtype.is_array():
value = np.array(value)
Expand Down Expand Up @@ -219,6 +221,11 @@ def visit(cls, op: ops.FindInSet, needle, values):
result = np.select(condlist, choicelist, default=-1)
return pd.Series(result, name=op.name)

@classmethod
def visit(cls, op: ops.EmptyArray, dtype):
pdt = PandasType.from_ibis(dtype)
return np.array([], dtype=pdt)

@classmethod
def visit(cls, op: ops.Array, exprs):
return cls.rowwise(lambda row: np.array(row, dtype=object), exprs)
Expand Down
24 changes: 16 additions & 8 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,25 +87,27 @@ def literal(op, **_):
value = op.value
dtype = op.dtype

if dtype.is_array():
value = pl.Series("", value)
typ = PolarsType.from_ibis(dtype)
val = pl.lit(value, dtype=typ)
return val.implode()
# There are some interval types that _make_duration() can handle,
# but PolarsType.from_ibis can't, so we need to handle them here.
if dtype.is_interval():
return _make_duration(value, dtype)

Check warning on line 93 in ibis/backends/polars/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/polars/compiler.py#L93

Added line #L93 was not covered by tests

typ = PolarsType.from_ibis(dtype)

Check warning on line 95 in ibis/backends/polars/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/polars/compiler.py#L95

Added line #L95 was not covered by tests
if value is None:
return pl.lit(None, dtype=typ)

Check warning on line 97 in ibis/backends/polars/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/polars/compiler.py#L97

Added line #L97 was not covered by tests
elif dtype.is_array():
return pl.lit(pl.Series("", value).implode(), dtype=typ)

Check warning on line 99 in ibis/backends/polars/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/polars/compiler.py#L99

Added line #L99 was not covered by tests
elif dtype.is_struct():
values = [
pl.lit(v, dtype=PolarsType.from_ibis(dtype[k])).alias(k)
for k, v in value.items()
]
return pl.struct(values)
elif dtype.is_interval():
return _make_duration(value, dtype)
elif dtype.is_null():
return pl.lit(value)
elif dtype.is_binary():
return pl.lit(value)
else:
typ = PolarsType.from_ibis(dtype)
return pl.lit(op.value, dtype=typ)


Expand Down Expand Up @@ -973,6 +975,12 @@ def array_concat(op, **kw):
return result


@translate.register(ops.EmptyArray)

Check warning on line 978 in ibis/backends/polars/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/polars/compiler.py#L978

Added line #L978 was not covered by tests
def empty_array(op, **kw):
pdt = PolarsType.from_ibis(op.dtype)
return pl.lit([], dtype=pdt)

Check warning on line 981 in ibis/backends/polars/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/polars/compiler.py#L980-L981

Added lines #L980 - L981 were not covered by tests


@translate.register(ops.Array)
def array_column(op, **kw):
cols = [translate(col, **kw) for col in op.exprs]
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,9 @@ def visit_InSubquery(self, op, *, rel, needle):
query = sg.select(STAR).from_(query)
return needle.isin(query=query)

def visit_EmptyArray(self, op, *, dtype):
return self.cast(self.f.array(), dtype)

def visit_Array(self, op, *, exprs):
return self.f.array(*exprs)

Expand Down
74 changes: 69 additions & 5 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
PySparkAnalysisException,
TrinoUserError,
)
from ibis.common.annotations import ValidationError
from ibis.common.collections import frozendict

pytestmark = [
Expand Down Expand Up @@ -72,6 +73,74 @@
# list.


def test_array_factory(con):
a = ibis.array([1, 2, 3])
assert a.type() == dt.Array(value_type=dt.Int8)
assert con.execute(a) == [1, 2, 3]

a2 = ibis.array(a)
assert a.type() == dt.Array(value_type=dt.Int8)
assert con.execute(a2) == [1, 2, 3]


def test_array_factory_typed(con):
typed = ibis.array([1, 2, 3], type="array<string>")
assert con.execute(typed) == ["1", "2", "3"]

typed2 = ibis.array(ibis.array([1, 2, 3]), type="array<string>")
assert con.execute(typed2) == ["1", "2", "3"]


@pytest.mark.notimpl("flink", raises=Py4JJavaError)
def test_array_factory_empty(con):
with pytest.raises(ValidationError):
ibis.array([])

empty_typed = ibis.array([], type="array<string>")
assert empty_typed.type() == dt.Array(value_type=dt.string)
assert con.execute(empty_typed) == []


@pytest.mark.notyet(
"clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
)
@pytest.mark.notyet(
"flink", raises=Py4JJavaError, reason="Parameters must be of the same type"
)
def test_array_factory_null(con):
with pytest.raises(ValidationError):
ibis.array(None)
with pytest.raises(ValidationError):
ibis.array(None, type="int64")
none_typed = ibis.array(None, type="array<string>")
assert none_typed.type() == dt.Array(value_type=dt.string)
assert con.execute(none_typed) is None

nones = ibis.array([None, None], type="array<string>")
assert nones.type() == dt.Array(value_type=dt.string)
assert con.execute(nones) == [None, None]

# Execute a real value here, so the backends that don't support arrays
# actually xfail as we expect them to.
# Otherwise would have to @mark.xfail every test in this file besides this one.
assert con.execute(ibis.array([1, 2])) == [1, 2]


@pytest.mark.broken(
["datafusion", "flink", "polars"],
raises=AssertionError,
reason="[None, 1] executes to [np.nan, 1.0]",
)
def test_array_factory_null_mixed(con):
none_and_val = ibis.array([None, 1])
assert none_and_val.type() == dt.Array(value_type=dt.Int8)
assert con.execute(none_and_val) == [None, 1]

none_and_val_typed = ibis.array([None, 1], type="array<string>")
assert none_and_val_typed.type() == dt.Array(value_type=dt.String)
assert con.execute(none_and_val_typed) == [None, "1"]


def test_array_column(backend, alltypes, df):
expr = ibis.array(
[alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)]
Expand Down Expand Up @@ -1354,11 +1423,6 @@ def test_unnest_range(con):
id="array",
marks=[
pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest),
pytest.mark.broken(
["polars"],
reason="expression input not supported with nested arrays",
raises=TypeError,
),
],
),
],
Expand Down
7 changes: 3 additions & 4 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,13 +1431,12 @@ def query(t, group_cols):
snapshot.assert_match(str(ibis.to_sql(t3, dialect=con.name)), "out.sql")


@pytest.mark.notimpl(["oracle", "exasol"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["druid"], raises=AssertionError)
@pytest.mark.notyet(
["datafusion", "impala", "mssql", "mysql", "sqlite"],
["datafusion", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"],
reason="backend doesn't support arrays and we don't implement pivot_longer with unions yet",
raises=com.OperationNotDefinedError,
raises=(com.OperationNotDefinedError, com.UnsupportedBackendType),
)
@pytest.mark.notimpl(["druid"], raises=AssertionError)
@pytest.mark.broken(
["trino"],
reason="invalid code generated for unnesting a struct",
Expand Down
18 changes: 14 additions & 4 deletions ibis/backends/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ibis.array([432]),
marks=[
pytest.mark.never(
["mysql", "mssql", "oracle", "impala", "sqlite"],
["exasol", "mysql", "mssql", "oracle", "impala", "sqlite"],
raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType),
reason="arrays not supported in the backend",
),
Expand All @@ -30,8 +30,18 @@
ibis.struct(dict(abc=432)),
marks=[
pytest.mark.never(
["impala", "mysql", "sqlite", "mssql", "exasol"],
raises=(NotImplementedError, exc.UnsupportedBackendType),
[
"exasol",
"impala",
"mysql",
"sqlite",
"mssql",
],
raises=(
exc.OperationNotDefinedError,
NotImplementedError,
exc.UnsupportedBackendType,
),
reason="structs not supported in the backend",
),
pytest.mark.notimpl(
Expand Down Expand Up @@ -104,7 +114,7 @@ def test_isin_bug(con, snapshot):
@pytest.mark.notyet(
["datafusion", "exasol", "oracle", "flink", "risingwave"],
reason="no unnest support",
raises=exc.OperationNotDefinedError,
raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType),
)
@pytest.mark.notyet(
["sqlite", "mysql", "druid", "impala", "mssql"], reason="no unnest support upstream"
Expand Down
10 changes: 5 additions & 5 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,18 +835,18 @@ def test_capitalize(con, inp, expected):
assert pd.isnull(result)


@pytest.mark.never(
["exasol", "impala", "mssql", "mysql", "sqlite"],
reason="Backend doesn't support arrays",
raises=(com.OperationNotDefinedError, com.UnsupportedBackendType),
)
@pytest.mark.notimpl(
[
"dask",
"pandas",
"polars",
"oracle",
"flink",
"sqlite",
"mssql",
"mysql",
"exasol",
"impala",
],
raises=com.OperationNotDefinedError,
)
Expand Down
16 changes: 13 additions & 3 deletions ibis/expr/operations/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,25 @@
from ibis.expr.operations.core import Unary, Value


@public
class EmptyArray(Value):
"""Construct an Empty array."""

dtype: dt.Array
shape = ds.scalar


@public
class Array(Value):
"""Construct an array."""

exprs: VarTuple[Value]

@attribute
def shape(self):
return rlz.highest_precedence_shape(self.exprs)
shape = rlz.shape_like("exprs")

def __init__(self, exprs):
assert len(exprs) > 0, "Use EmptyArray to create an empty array"
super().__init__(exprs=exprs)

@attribute
def dtype(self):
Expand Down
4 changes: 4 additions & 0 deletions ibis/expr/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from public import public

import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import util
Expand All @@ -16,6 +17,9 @@

@public
def highest_precedence_shape(nodes):
nodes = tuple(nodes)
if len(nodes) == 0:
return ds.scalar

Check warning on line 22 in ibis/expr/rules.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/rules.py#L22

Added line #L22 was not covered by tests
return max(node.shape for node in nodes)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
DummyTable
foo: Array([1])
foo: Array(exprs=[1], dtype=array<int8>)
Loading

0 comments on commit a8e6dd8

Please sign in to comment.