From d86902db7a5c43845ef6d7966915bff52814b2cc Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sun, 12 May 2024 06:05:17 -0800 Subject: [PATCH] feat: Improve array(), map(), and struct fixes https://github.com/ibis-project/ibis/issues/8289 This does a lot of changes. It was hard for me to separate them out as I implemented them. But now that it's all hashed out, I can try to split this up into separate commits if you want. But that might be sorta hard in some cases. One thing is adding support for passing in None to all these constructors. These use the new `ibis.null()` API to return `op.Literal(None, dtype)`s. Make these constructors idempotent: you can pass in existing Expressions into array(), etc. The type argument for all of these now always has an effect, not just when passing in python literals. So basically it acts like a cast. A big structural change is that now ops.Array has an optional attribute "dtype", so if you pass in a 0-length sequence of values the op still knows what dtype it is. Several of the backends were always broken here, they just weren't getting caught. I marked them as broken, we can fix them in a followup. You can test this locally with eg `pytest -m -k factory ibis/backends/tests/test_array.py ibis/backends/tests/test_map.py ibis/backends/tests/test_struct.py` Also, fix a typing bug: map() can accept ArrayValues, not just ArrayColumns. Also, fix executing Literal(None) on pandas and polars, 0-length arrays on polars. Also, fix converting dtypes on clickhouse, Structs should be converted to nonnullable dtypes. 
Also, implement ops.StructColumn on pandas and dask --- ibis/backends/dask/executor.py | 8 ++- ibis/backends/exasol/compiler.py | 1 + ibis/backends/pandas/executor.py | 18 ++++-- ibis/backends/polars/compiler.py | 19 +++--- ibis/backends/risingwave/compiler.py | 7 ++- ibis/backends/sql/compiler.py | 7 ++- ibis/backends/sql/datatypes.py | 6 +- ibis/backends/sqlite/compiler.py | 1 + ibis/backends/tests/test_array.py | 58 +++++++++++++++++-- ibis/backends/tests/test_generic.py | 4 +- ibis/backends/tests/test_map.py | 16 ++++- ibis/backends/tests/test_sql.py | 76 ++++++++++++------------ ibis/backends/tests/test_struct.py | 87 ++++++++++++++++++++-------- ibis/expr/operations/arrays.py | 19 ++++-- ibis/expr/operations/structs.py | 9 +-- ibis/expr/types/arrays.py | 58 ++++++++++++++----- ibis/expr/types/maps.py | 69 +++++++++++++++++----- ibis/expr/types/structs.py | 69 ++++++++++++++-------- 18 files changed, 381 insertions(+), 151 deletions(-) diff --git a/ibis/backends/dask/executor.py b/ibis/backends/dask/executor.py index e35a4140481ce..ac0c28473187e 100644 --- a/ibis/backends/dask/executor.py +++ b/ibis/backends/dask/executor.py @@ -155,11 +155,17 @@ def mapper(df, cases): return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype) @classmethod - def visit(cls, op: ops.Array, exprs): + def visit(cls, op: ops.Array, exprs, dtype): return cls.rowwise( lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object ) + @classmethod + def visit(cls, op: ops.StructColumn, names, values): + return cls.rowwise( + lambda row: dict(zip(names, row)), values, name=op.name, dtype=object + ) + @classmethod def visit(cls, op: ops.ArrayConcat, arg): dtype = PandasType.from_ibis(op.dtype) diff --git a/ibis/backends/exasol/compiler.py b/ibis/backends/exasol/compiler.py index 0940c80a182e7..11ac0b5f5bbee 100644 --- a/ibis/backends/exasol/compiler.py +++ b/ibis/backends/exasol/compiler.py @@ -75,6 +75,7 @@ class ExasolCompiler(SQLGlotCompiler): ops.StringSplit, 
ops.StringToDate, ops.StringToTimestamp, + ops.StructColumn, ops.TimeDelta, ops.TimestampAdd, ops.TimestampBucket, diff --git a/ibis/backends/pandas/executor.py b/ibis/backends/pandas/executor.py index 858f49173464b..2ab04d52791b5 100644 --- a/ibis/backends/pandas/executor.py +++ b/ibis/backends/pandas/executor.py @@ -49,12 +49,14 @@ def visit(cls, op: ops.Node, **kwargs): @classmethod def visit(cls, op: ops.Literal, value, dtype): + if value is None: + return None if dtype.is_interval(): - value = pd.Timedelta(value, dtype.unit.short) - elif dtype.is_array(): - value = np.array(value) - elif dtype.is_date(): - value = pd.Timestamp(value, tz="UTC").tz_localize(None) + return pd.Timedelta(value, dtype.unit.short) + if dtype.is_array(): + return np.array(value) + if dtype.is_date(): + return pd.Timestamp(value, tz="UTC").tz_localize(None) return value @classmethod @@ -220,9 +222,13 @@ def visit(cls, op: ops.FindInSet, needle, values): return pd.Series(result, name=op.name) @classmethod - def visit(cls, op: ops.Array, exprs): + def visit(cls, op: ops.Array, exprs, dtype): return cls.rowwise(lambda row: np.array(row, dtype=object), exprs) + @classmethod + def visit(cls, op: ops.StructColumn, names, values): + return cls.rowwise(lambda row: dict(zip(names, row)), values) + @classmethod def visit(cls, op: ops.ArrayConcat, arg): return cls.rowwise(lambda row: np.concatenate(row.values), arg) diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index a793db0c609b8..4f4ab8cbc2e7e 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -86,10 +86,14 @@ def _make_duration(value, dtype): def literal(op, **_): value = op.value dtype = op.dtype + if dtype.is_interval(): + return _make_duration(value, dtype) - if dtype.is_array(): + typ = PolarsType.from_ibis(dtype) + if value is None: + return pl.lit(None, dtype=typ) + elif dtype.is_array(): value = pl.Series("", value) - typ = PolarsType.from_ibis(dtype) val = 
pl.lit(value, dtype=typ) return val.implode() elif dtype.is_struct(): @@ -98,14 +102,11 @@ def literal(op, **_): for k, v in value.items() ] return pl.struct(values) - elif dtype.is_interval(): - return _make_duration(value, dtype) elif dtype.is_null(): return pl.lit(value) elif dtype.is_binary(): return pl.lit(value) else: - typ = PolarsType.from_ibis(dtype) return pl.lit(op.value, dtype=typ) @@ -980,9 +981,11 @@ def array_concat(op, **kw): @translate.register(ops.Array) -def array_column(op, **kw): - cols = [translate(col, **kw) for col in op.exprs] - return pl.concat_list(cols) +def array_literal(op, **kw): + if len(op.exprs) > 0: + return pl.concat_list([translate(col, **kw) for col in op.exprs]) + else: + return pl.lit([], dtype=PolarsType.from_ibis(op.dtype)) @translate.register(ops.ArrayCollect) diff --git a/ibis/backends/risingwave/compiler.py b/ibis/backends/risingwave/compiler.py index 052b2c8fdea7d..d540d104facb1 100644 --- a/ibis/backends/risingwave/compiler.py +++ b/ibis/backends/risingwave/compiler.py @@ -8,7 +8,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.postgres.compiler import PostgresCompiler -from ibis.backends.sql.compiler import ALL_OPERATIONS +from ibis.backends.sql.compiler import ALL_OPERATIONS, SQLGlotCompiler from ibis.backends.sql.datatypes import RisingWaveType from ibis.backends.sql.dialects import RisingWave @@ -51,6 +51,11 @@ def visit_Correlation(self, op, *, left, right, how, where): op, left=left, right=right, how=how, where=where ) + def visit_StructColumn(self, op, *, names, values): + # The parent Postgres compiler uses the ROW() function, + # but the grandparent SQLGlot compiler uses the correct syntax + return SQLGlotCompiler.visit_StructColumn(self, op, names=names, values=values) + def visit_TimestampTruncate(self, op, *, arg, unit): unit_mapping = { "Y": "year", diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index fb0c5a974c8b7..158b0ebc0d362 100644 
--- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -970,8 +970,11 @@ def visit_InSubquery(self, op, *, rel, needle): query = sg.select(STAR).from_(query) return needle.isin(query=query) - def visit_Array(self, op, *, exprs): - return self.f.array(*exprs) + def visit_Array(self, op, *, exprs, dtype): + result = self.f.array(*exprs) + if len(exprs) == 0: + return self.cast(result, dtype) + return result def visit_StructColumn(self, op, *, names, values): return sge.Struct.from_arg_list( diff --git a/ibis/backends/sql/datatypes.py b/ibis/backends/sql/datatypes.py index 4e492fb15f527..80c57d876b04b 100644 --- a/ibis/backends/sql/datatypes.py +++ b/ibis/backends/sql/datatypes.py @@ -1007,8 +1007,10 @@ class ClickHouseType(SqlglotType): def from_ibis(cls, dtype: dt.DataType) -> sge.DataType: """Convert a sqlglot type to an ibis type.""" typ = super().from_ibis(dtype) - if dtype.nullable and not (dtype.is_map() or dtype.is_array()): - # map cannot be nullable in clickhouse + # nested types cannot be nullable in clickhouse + if dtype.nullable and not ( + dtype.is_map() or dtype.is_array() or dtype.is_struct() + ): return sge.DataType(this=typecode.NULLABLE, expressions=[typ]) else: return typ diff --git a/ibis/backends/sqlite/compiler.py b/ibis/backends/sqlite/compiler.py index 2e7e6b279c16d..ecad6df16150e 100644 --- a/ibis/backends/sqlite/compiler.py +++ b/ibis/backends/sqlite/compiler.py @@ -60,6 +60,7 @@ class SQLiteCompiler(SQLGlotCompiler): ops.TimestampDiff, ops.StringToDate, ops.StringToTimestamp, + ops.StructColumn, ops.TimeDelta, ops.DateDelta, ops.TimestampDelta, diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 7340540f7269d..5a88913ab2922 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -30,6 +30,7 @@ PySparkAnalysisException, TrinoUserError, ) +from ibis.common.annotations import ValidationError from ibis.common.collections import frozendict pytestmark = [ 
@@ -66,11 +67,63 @@ pytest.mark.notimpl(["druid", "oracle"], raises=Exception), ] +mark_notyet_datafusion = pytest.mark.notyet( + "datafusion", + raises=Exception, + reason="datafusion can't handle array casts yet. https://github.com/apache/datafusion/issues/10464", +) + # NB: We don't check whether results are numpy arrays or lists because this # varies across backends. At some point we should unify the result type to be # list. +def test_array_factory(con): + a = ibis.array([1, 2, 3]) + assert con.execute(a) == [1, 2, 3] + + a2 = ibis.array(a) + assert con.execute(a2) == [1, 2, 3] + + +@mark_notyet_datafusion +def test_array_factory_typed(con): + typed = ibis.array([1, 2, 3], type="array") + assert con.execute(typed) == ["1", "2", "3"] + + typed2 = ibis.array(ibis.array([1, 2, 3]), type="array") + assert con.execute(typed2) == ["1", "2", "3"] + + +@mark_notyet_datafusion +@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError) +def test_array_factory_empty(con): + with pytest.raises(ValidationError): + ibis.array([]) + + empty_typed = ibis.array([], type="array") + assert empty_typed.type() == dt.Array(value_type=dt.string) + assert con.execute(empty_typed) == [] + + +@mark_notyet_datafusion +@pytest.mark.notyet( + "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL" +) +def test_array_factory_null(con): + with pytest.raises(ValidationError): + ibis.array(None) + with pytest.raises(ValidationError): + ibis.array(None, type="int64") + none_typed = ibis.array(None, type="array") + assert none_typed.type() == dt.Array(value_type=dt.string) + assert con.execute(none_typed) is None + # Execute a real value here, so the backends that don't support arrays + # actually xfail as we expect them to. + # Otherwise would have to @mark.xfail every test in this file besides this one. 
+ assert con.execute(ibis.array([1, 2])) == [1, 2] + + def test_array_column(backend, alltypes, df): expr = ibis.array( [alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)] @@ -913,11 +966,6 @@ def test_zip_null(con, fn): @builtin_array -@pytest.mark.notyet( - ["clickhouse"], - raises=ClickHouseDatabaseError, - reason="https://github.com/ClickHouse/ClickHouse/issues/41112", -) @pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError) @pytest.mark.notimpl(["risingwave"], raises=PsycoPg2ProgrammingError) @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index b6677dc4944c8..9ccbd1e3a565c 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -1230,9 +1230,7 @@ def query(t, group_cols): snapshot.assert_match(str(ibis.to_sql(t3, dialect=con.name)), "out.sql") -@pytest.mark.notimpl( - ["dask", "pandas", "oracle", "exasol"], raises=com.OperationNotDefinedError -) +@pytest.mark.notimpl(["oracle", "exasol"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["druid"], raises=AssertionError) @pytest.mark.notyet( ["datafusion", "impala", "mssql", "mysql", "sqlite"], diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py index 63cdf728aabf4..b09efb46fa289 100644 --- a/ibis/backends/tests/test_map.py +++ b/ibis/backends/tests/test_map.py @@ -11,6 +11,7 @@ import ibis.common.exceptions as exc import ibis.expr.datatypes as dt from ibis.backends.tests.errors import PsycoPg2InternalError, Py4JJavaError +from ibis.common.annotations import ValidationError pytestmark = [ pytest.mark.never( @@ -39,6 +40,19 @@ ) +@pytest.mark.notyet("clickhouse", reason="nested types can't be NULL") +@mark_notimpl_risingwave_hstore +def test_map_factory(con): + assert con.execute(ibis.map(None, type="map")) is None + assert con.execute(ibis.map({"a": "b"}, type="map")) == {"a": "b"} + with 
pytest.raises(ValidationError): + ibis.map(None) + with pytest.raises(ValidationError): + ibis.map(None, type="array") + with pytest.raises(ValidationError): + ibis.map({1: 2}, type="array") + + @pytest.mark.notyet("clickhouse", reason="nested types can't be NULL") @pytest.mark.broken(["pandas", "dask"], reason="TypeError: iteration over a 0-d array") @pytest.mark.notimpl( @@ -669,6 +683,6 @@ def test_map_keys_unnest(backend): @mark_notimpl_risingwave_hstore def test_map_contains_null(con): - expr = ibis.map(["a"], ibis.literal([None], type="array")) + expr = ibis.map(["a"], ibis.array([None], type="array")) assert con.execute(expr.contains("a")) assert not con.execute(expr.contains("b")) diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index 097bbb9cb4577..33ec63264e7ac 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -10,52 +10,52 @@ sg = pytest.importorskip("sqlglot") -simple_literal = param(ibis.literal(1), id="simple_literal") -array_literal = param( - ibis.array([1]), - marks=[ - pytest.mark.never( - ["mysql", "mssql", "oracle", "impala", "sqlite"], - raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType), - reason="arrays not supported in the backend", - ), - ], - id="array_literal", -) -no_structs = pytest.mark.never( - ["impala", "mysql", "sqlite", "mssql", "exasol"], - raises=(NotImplementedError, exc.UnsupportedBackendType), - reason="structs not supported in the backend", -) -no_struct_literals = pytest.mark.notimpl( - ["mssql"], reason="struct literals are not yet implemented" -) -not_sql = pytest.mark.never( - ["pandas", "dask"], - raises=(exc.IbisError, NotImplementedError, ValueError), - reason="Not a SQL backend", -) -no_sql_extraction = pytest.mark.notimpl( - ["polars"], reason="Not clear how to extract SQL from the backend" -) - @pytest.mark.parametrize( - "expr", + "expr,contains", [ - simple_literal, - array_literal, + param(ibis.literal(432), "432", 
id="simple_literal"), param( - ibis.struct(dict(a=1)), - marks=[no_structs, no_struct_literals], + ibis.array([432]), + "432", + marks=[ + pytest.mark.never( + ["mysql", "mssql", "oracle", "impala", "sqlite"], + raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType), + reason="arrays not supported in the backend", + ), + ], + id="array_literal", + ), + param( + ibis.struct(dict(abc=432)), + "432", + marks=[ + pytest.mark.never( + ["impala", "mysql", "sqlite", "mssql", "exasol"], + raises=( + exc.OperationNotDefinedError, + NotImplementedError, + exc.UnsupportedBackendType, + ), + reason="structs not supported in the backend", + ), + pytest.mark.notimpl( + ["mssql"], reason="struct literals are not yet implemented" + ), + ], id="struct_literal", ), ], ) -@not_sql -@no_sql_extraction -def test_literal(backend, expr): - assert ibis.to_sql(expr, dialect=backend.name()) +@pytest.mark.never( + ["pandas", "dask"], + raises=(exc.IbisError, NotImplementedError, ValueError), + reason="Not a SQL backend", +) +@pytest.mark.notimpl(["polars"], reason="Not clear how to extract SQL from the backend") +def test_literal(backend, expr, contains): + assert contains in ibis.to_sql(expr, dialect=backend.name()) @pytest.mark.never(["pandas", "dask", "polars"], reason="not SQL") diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py index e7f078f7591a1..7bdb10c48b2bf 100644 --- a/ibis/backends/tests/test_struct.py +++ b/ibis/backends/tests/test_struct.py @@ -21,7 +21,8 @@ Py4JJavaError, PySparkAnalysisException, ) -from ibis.common.exceptions import IbisError, OperationNotDefinedError +from ibis.common.annotations import ValidationError +from ibis.common.exceptions import IbisError pytestmark = [ pytest.mark.never(["mysql", "sqlite", "mssql"], reason="No struct support"), @@ -29,6 +30,57 @@ pytest.mark.notimpl(["datafusion", "druid", "oracle", "exasol"]), ] +mark_notimpl_postgres_literals = pytest.mark.notimpl( + "postgres", reason="struct 
literals not implemented", raises=PsycoPg2SyntaxError +) + + +@pytest.mark.broken("postgres", reason="JSON handling is buggy") +def test_struct_factory(con): + s = ibis.struct({"a": 1, "b": 2}) + assert con.execute(s) == {"a": 1, "b": 2} + + s2 = ibis.struct(s) + assert con.execute(s2) == {"a": 1, "b": 2} + + typed = ibis.struct({"a": 1, "b": 2}, type="struct") + assert con.execute(typed) == {"a": "1", "b": "2"} + + typed2 = ibis.struct(s, type="struct") + assert con.execute(typed2) == {"a": "1", "b": "2"} + + items = ibis.struct([("a", 1), ("b", 2)]) + assert con.execute(items) == {"a": 1, "b": 2} + + +@pytest.mark.parametrize("val", [{}, []]) +def test_struct_factory_empty(val): + with pytest.raises(ValidationError): + ibis.struct(val) + with pytest.raises(ValidationError): + ibis.struct(val, type="struct<>") + with pytest.raises(ValidationError): + ibis.struct(val, type="struct") + + +@mark_notimpl_postgres_literals +@pytest.mark.notyet( + "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL" +) +@pytest.mark.broken( + "polars", reason=r"pl.lit(None, type='struct') gives {'a': None}" +) +def test_struct_factory_null(con): + with pytest.raises(ValidationError): + ibis.struct(None) + none_typed = ibis.struct(None, type="struct") + assert none_typed.type() == dt.Struct(fields={"a": dt.float64, "b": dt.float64}) + assert con.execute(none_typed) is None + # Execute a real value here, so the backends that don't support structs + # actually xfail as we expect them to. + # Otherwise would have to @mark.xfail every test in this file besides this one. 
+ assert con.execute(ibis.struct({"a": 1, "b": 2})) == {"a": 1, "b": 2} + @pytest.mark.notimpl(["dask"]) @pytest.mark.parametrize( @@ -79,6 +131,9 @@ def test_all_fields(struct, struct_df): @pytest.mark.notimpl(["postgres", "risingwave"]) @pytest.mark.parametrize("field", ["a", "b", "c"]) +@pytest.mark.notyet( + ["flink"], reason="flink doesn't support creating struct columns from literals" +) def test_literal(backend, con, field): query = _STRUCT_LITERAL[field] dtype = query.type().to_pandas() @@ -88,7 +143,7 @@ def test_literal(backend, con, field): backend.assert_series_equal(result, expected.astype(dtype)) -@pytest.mark.notimpl(["postgres"]) +@mark_notimpl_postgres_literals @pytest.mark.parametrize("field", ["a", "b", "c"]) @pytest.mark.notyet( ["clickhouse"], reason="clickhouse doesn't support nullable nested types" @@ -101,7 +156,7 @@ def test_null_literal(backend, con, field): backend.assert_series_equal(result, expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave"]) +@pytest.mark.notimpl(["postgres", "risingwave"]) def test_struct_column(alltypes, df): t = alltypes expr = t.select(s=ibis.struct(dict(a=t.string_col, b=1, c=t.bigint_col))) @@ -113,7 +168,7 @@ def test_struct_column(alltypes, df): tm.assert_frame_equal(result, expected) -@pytest.mark.notimpl(["dask", "pandas", "postgres", "risingwave", "polars"]) +@pytest.mark.notimpl(["postgres", "risingwave", "polars"]) @pytest.mark.notyet( ["flink"], reason="flink doesn't support creating struct columns from collect" ) @@ -138,9 +193,6 @@ def test_collect_into_struct(alltypes): assert len(val.loc[result.group == "1"].iat[0]["key"]) == 730 -@pytest.mark.notimpl( - ["postgres"], reason="struct literals not implemented", raises=PsycoPg2SyntaxError -) @pytest.mark.notimpl( ["risingwave"], reason="struct literals not implemented", @@ -158,28 +210,16 @@ def test_field_access_after_case(con): ["postgres"], reason="struct literals not implemented", raises=PsycoPg2SyntaxError ) 
@pytest.mark.notimpl(["flink"], raises=IbisError, reason="not implemented in ibis") +@pytest.mark.notyet( + ["clickhouse"], raises=sg.ParseError, reason="sqlglot fails to parse" +) @pytest.mark.parametrize( "nullable", [ - param( - True, - marks=[ - pytest.mark.notyet( - ["clickhouse"], - raises=ClickHouseDatabaseError, - reason="ClickHouse doesn't support nested nullable types", - ) - ], - id="nullable", - ), + param(True, id="nullable"), param( False, marks=[ - pytest.mark.notyet( - ["clickhouse"], - raises=sg.ParseError, - reason="sqlglot fails to parse", - ), pytest.mark.notyet( ["polars"], raises=AssertionError, @@ -253,7 +293,6 @@ def test_keyword_fields(con, nullable): raises=PolarsColumnNotFoundError, reason="doesn't seem to support IN-style subqueries on structs", ) -@pytest.mark.notimpl(["pandas", "dask"], raises=OperationNotDefinedError) @pytest.mark.xfail_version( pyspark=["pyspark<3.5"], reason="requires pyspark 3.5", diff --git a/ibis/expr/operations/arrays.py b/ibis/expr/operations/arrays.py index 68ee711a2da6b..75c9648da47fa 100644 --- a/ibis/expr/operations/arrays.py +++ b/ibis/expr/operations/arrays.py @@ -7,7 +7,7 @@ import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.rules as rlz -from ibis.common.annotations import attribute +from ibis.common.annotations import ValidationError, attribute from ibis.common.typing import VarTuple # noqa: TCH001 from ibis.expr.operations.core import Unary, Value @@ -15,15 +15,24 @@ @public class Array(Value): exprs: VarTuple[Value] + dtype: Optional[dt.Array] = None + + def __init__(self, exprs, dtype: dt.Array | None = None): + if len(exprs) == 0: + if dtype is None: + raise ValidationError("If values is empty, dtype must be provided") + if not isinstance(dtype, dt.Array): + raise ValidationError(f"dtype must be an array, got {dtype}") + elif dtype is None: + dtype = dt.Array(rlz.highest_precedence_dtype(exprs)) + super().__init__(exprs=exprs, dtype=dtype) @attribute def shape(self): 
+ if len(self.exprs) == 0: + return ds.scalar return rlz.highest_precedence_shape(self.exprs) - @attribute - def dtype(self): - return dt.Array(rlz.highest_precedence_dtype(self.exprs)) - @public class ArrayLength(Unary): diff --git a/ibis/expr/operations/structs.py b/ibis/expr/operations/structs.py index 20c0c3dc0a4ef..2b3de036948ac 100644 --- a/ibis/expr/operations/structs.py +++ b/ibis/expr/operations/structs.py @@ -34,7 +34,9 @@ class StructColumn(Value): shape = rlz.shape_like("values") - def __init__(self, names, values): + def __init__(self, names: VarTuple[str], values: VarTuple[Value]): + if len(names) == 0: + raise ValidationError("StructColumn must have at least one field") if len(names) != len(values): raise ValidationError( f"Length of names ({len(names)}) does not match length of " @@ -43,6 +45,5 @@ def __init__(self, names, values): super().__init__(names=names, values=values) @attribute - def dtype(self) -> dt.DataType: - dtypes = (value.dtype for value in self.values) - return dt.Struct.from_tuples(zip(self.names, dtypes)) + def dtype(self) -> dt.Struct: + return dt.Struct.from_tuples(zip(self.names, (v.dtype for v in self.values))) diff --git a/ibis/expr/types/arrays.py b/ibis/expr/types/arrays.py index bea22cf537925..f1f11849c0156 100644 --- a/ibis/expr/types/arrays.py +++ b/ibis/expr/types/arrays.py @@ -5,14 +5,16 @@ from public import public +import ibis.expr.datatypes as dt import ibis.expr.operations as ops +import ibis.expr.types as ir +from ibis.common.annotations import ValidationError from ibis.common.deferred import Deferred, deferrable from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: from collections.abc import Iterable - import ibis.expr.types as ir from ibis.expr.types.typing import V import ibis.common.exceptions as com @@ -1067,7 +1069,11 @@ def __getitem__(self, index: int | ir.IntegerValue | slice) -> ir.Column: @public @deferrable -def array(values: Iterable[V]) -> ArrayValue: +def array( + values: 
ArrayValue | Iterable[V] | ir.NullValue | None, + *, + type: str | dt.DataType | None = None, +) -> ArrayValue: """Create an array expression. If any values are [column expressions](../concepts/datatypes.qmd) the @@ -1078,6 +1084,9 @@ def array(values: Iterable[V]) -> ArrayValue: ---------- values An iterable of Ibis expressions or Python literals + type + An instance of `ibis.expr.datatypes.DataType` or a string indicating + the Ibis type of `value`. eg `array`. Returns ------- @@ -1106,15 +1115,38 @@ def array(values: Iterable[V]) -> ArrayValue: │ [3, 42, ... +1] │ └──────────────────────┘ - >>> ibis.array([t.a, 42 + ibis.literal(5)]) - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ Array() ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ array │ - ├──────────────────────┤ - │ [1, 47] │ - │ [2, 47] │ - │ [3, 47] │ - └──────────────────────┘ + >>> ibis.array([t.a, 42 + ibis.literal(5)], type="array") + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ Cast(Array(), array) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ array │ + ├───────────────────────────────┤ + │ [1.0, 47.0] │ + │ [2.0, 47.0] │ + │ [3.0, 47.0] │ + └───────────────────────────────┘ """ - return ops.Array(tuple(values)).to_expr() + type = dt.dtype(type) if type is not None else None + if type is not None and not isinstance(type, dt.Array): + raise ValidationError(f"dtype must be an array, got {type}") + + if isinstance(values, ir.Value): + if type is not None: + return values.cast(type) + elif isinstance(values, ArrayValue): + return values + else: + raise ValidationError( + f"If no type passed, values must be an array, got {values.type()}" + ) + + if values is None: + if type is None: + raise ValidationError("If values is None/NULL, dtype must be provided") + return ir.null(type) + + values = tuple(values) + if len(values) > 0 and type is not None: + return ops.Array(values).to_expr().cast(type) + else: + return ops.Array(values, type).to_expr() diff --git a/ibis/expr/types/maps.py b/ibis/expr/types/maps.py index 79d45d16e188e..f8fd6244c48e4 
100644 --- a/ibis/expr/types/maps.py +++ b/ibis/expr/types/maps.py @@ -1,18 +1,21 @@ from __future__ import annotations +from collections.abc import Mapping from typing import TYPE_CHECKING, Any from public import public +import ibis import ibis.expr.operations as ops +import ibis.expr.types as ir +from ibis.common.annotations import ValidationError from ibis.common.deferred import deferrable from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: - from collections.abc import Iterable, Mapping + from collections.abc import Iterable - import ibis.expr.types as ir - from ibis.expr.types.arrays import ArrayColumn + import ibis.expr.datatypes as dt @public @@ -435,8 +438,15 @@ def __getitem__(self, key: ir.Value) -> ir.Column: @public @deferrable def map( - keys: Iterable[Any] | Mapping[Any, Any] | ArrayColumn, - values: Iterable[Any] | ArrayColumn | None = None, + keys: Iterable[Any] + | Mapping[Any, Any] + | ir.ArrayValue + | MapValue + | ir.NullValue + | None, + values: Iterable[Any] | ir.ArrayValue | None = None, + *, + type: str | dt.DataType | None = None, ) -> MapValue: """Create a MapValue. @@ -449,6 +459,9 @@ def map( Keys of the map or `Mapping`. If `keys` is a `Mapping`, `values` must be `None`. values Values of the map or `None`. If `None`, the `keys` argument must be a `Mapping`. + type + An instance of `ibis.expr.datatypes.DataType` or a string indicating + the Ibis type of `value`. eg `map`. 
Returns ------- @@ -476,16 +489,42 @@ def map( │ ['a', 'b'] │ [1, 2] │ │ ['b'] │ [3] │ └──────────────────────┴──────────────────────┘ - >>> ibis.map(t.keys, t.values) - ┏━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ Map(keys, values) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━┩ - │ map │ - ├──────────────────────┤ - │ {'a': 1, 'b': 2} │ - │ {'b': 3} │ - └──────────────────────┘ + >>> ibis.map(t.keys, t.values, type="map") + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ Map(keys, Cast(values, array)) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ map │ + ├─────────────────────────────────────────┤ + │ {'a': 1.0, 'b': 2.0} │ + │ {'b': 3.0} │ + └─────────────────────────────────────────┘ """ - if values is None: + from ibis.expr import datatypes as dt + + type = dt.dtype(type) if type is not None else None + if type is not None and not isinstance(type, dt.Map): + raise ValidationError(f"dtype must be a map, got {type}") + + if isinstance(keys, Mapping) and values is None: keys, values = tuple(keys.keys()), tuple(keys.values()) + + if isinstance(keys, ir.Value) and values is None: + if type is not None: + return keys.cast(type) + elif isinstance(keys, MapValue): + return keys + else: + raise ValidationError( + f"If no type passed, value must be a map, got {keys.type()}" + ) + + if keys is None or values is None: + if type is None: + raise ValidationError("If keys is None/NULL, dtype must be provided") + return ir.null(type) + + k_type = dt.Array(value_type=type.key_type) if type is not None else None + v_type = dt.Array(value_type=type.value_type) if type is not None else None + keys = ibis.array(keys, type=k_type) + values = ibis.array(values, type=v_type) return ops.Map(keys, values).to_expr() diff --git a/ibis/expr/types/structs.py b/ibis/expr/types/structs.py index 65a16700318a8..8876b459b2da3 100644 --- a/ibis/expr/types/structs.py +++ b/ibis/expr/types/structs.py @@ -1,28 +1,33 @@ from __future__ import annotations -import collections from keyword import iskeyword from typing import 
TYPE_CHECKING from public import public +import ibis.expr.datatypes as dt import ibis.expr.operations as ops +import ibis.expr.types as ir +from ibis.common.annotations import ValidationError from ibis.common.deferred import deferrable from ibis.common.exceptions import IbisError -from ibis.expr.types.generic import Column, Scalar, Value, literal +from ibis.expr.types.generic import Column, Scalar, Value if TYPE_CHECKING: from collections.abc import Iterable, Mapping, Sequence - import ibis.expr.datatypes as dt - import ibis.expr.types as ir from ibis.expr.types.typing import V @public @deferrable def struct( - value: Iterable[tuple[str, V]] | Mapping[str, V], + value: Iterable[tuple[str, V]] + | Mapping[str, V] + | StructValue + | ir.NullValue + | None, + *, type: str | dt.DataType | None = None, ) -> StructValue: """Create a struct expression. @@ -37,8 +42,7 @@ def struct( `(str, Value)`. type An instance of `ibis.expr.datatypes.DataType` or a string indicating - the Ibis type of `value`. This is only used if all of the input values - are Python literals. eg `struct`. + the Ibis type of `value`. eg `struct`. 
Returns ------- @@ -62,26 +66,45 @@ def struct( Create a struct column from a column and a scalar literal >>> t = ibis.memtable({"a": [1, 2, 3]}) - >>> ibis.struct([("a", t.a), ("b", "foo")]) - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ StructColumn() ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ struct │ - ├─────────────────────────────┤ - │ {'a': 1, 'b': 'foo'} │ - │ {'a': 2, 'b': 'foo'} │ - │ {'a': 3, 'b': 'foo'} │ - └─────────────────────────────┘ + >>> ibis.struct([("a", t.a), ("b", "foo")], type="struct") + ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ + ┃ Cast(StructColumn(), struct) ┃ + ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ + │ struct │ + ├─────────────────────────────────────────────────────┤ + │ {'a': 1.0, 'b': 'foo'} │ + │ {'a': 2.0, 'b': 'foo'} │ + │ {'a': 3.0, 'b': 'foo'} │ + └─────────────────────────────────────────────────────┘ """ import ibis.expr.operations as ops + type = dt.dtype(type) if type is not None else None + if type is not None and not isinstance(type, dt.Struct): + raise ValidationError(f"dtype must be an struct, got {type}") + + if isinstance(value, ir.Value): + if type is not None: + return value.cast(type) + elif isinstance(value, StructValue): + return value + else: + raise ValidationError( + f"If no type passed, value must be a struct, got {value.type()}" + ) + + if value is None: + if type is None: + raise ValidationError("If values is None/NULL, dtype must be provided") + return ir.null(type) + fields = dict(value) - if any(isinstance(value, Value) for value in fields.values()): - names = tuple(fields.keys()) - values = tuple(fields.values()) - return ops.StructColumn(names=names, values=values).to_expr() - else: - return literal(collections.OrderedDict(fields), type=type) + names = tuple(fields.keys()) + values = tuple(fields.values()) + result = ops.StructColumn(names=names, values=values).to_expr() + if type is not None: + result = result.cast(type) + return result @public