Skip to content

Commit

Permalink
[CALCITE-6408] Not-null ThreadLocal
Browse files Browse the repository at this point in the history
Make various ThreadLocal instances non-nullable. They must
have an initializer, but the caller can use the value without
checking whether it is null.
  • Loading branch information
julianhyde committed Sep 16, 2024
1 parent c1b0727 commit 740f2ee
Show file tree
Hide file tree
Showing 17 changed files with 269 additions and 114 deletions.
13 changes: 8 additions & 5 deletions core/src/main/java/org/apache/calcite/jdbc/CalcitePrepare.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.tools.RelRunner;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.TryThreadLocal;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.google.common.collect.ImmutableList;
Expand All @@ -70,8 +71,8 @@
*/
public interface CalcitePrepare {
Function0<CalcitePrepare> DEFAULT_FACTORY = CalcitePrepareImpl::new;
ThreadLocal<@Nullable Deque<Context>> THREAD_CONTEXT_STACK =
ThreadLocal.withInitial(ArrayDeque::new);
TryThreadLocal<Deque<Context>> THREAD_CONTEXT_STACK =
TryThreadLocal.withInitial(ArrayDeque::new);

ParseResult parse(Context context, String sql);

Expand Down Expand Up @@ -193,7 +194,7 @@ private static SparkHandler createHandler() {
}

public static void push(Context context) {
final Deque<Context> stack = castNonNull(THREAD_CONTEXT_STACK.get());
final Deque<Context> stack = THREAD_CONTEXT_STACK.get();
final List<String> path = context.getObjectPath();
if (path != null) {
for (Context context1 : stack) {
Expand All @@ -207,11 +208,13 @@ public static void push(Context context) {
}

public static Context peek() {
return castNonNull(castNonNull(THREAD_CONTEXT_STACK.get()).peek());
final Deque<Context> stack = THREAD_CONTEXT_STACK.get();
return castNonNull(stack.peek());
}

public static void pop(Context context) {
Context x = castNonNull(THREAD_CONTEXT_STACK.get()).pop();
final Deque<Context> stack = THREAD_CONTEXT_STACK.get();
Context x = castNonNull(stack).pop();
assert x == context;
}

Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/org/apache/calcite/prepare/Prepare.java
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ private static boolean shouldTrim(RelNode rootRel) {
// For now, don't trim if there are more than 3 joins. The projects
// near the leaves created by trim migrate past joins and seem to
// prevent join-reordering.
return castNonNull(THREAD_TRIM.get()) || RelOptUtil.countJoins(rootRel) < 2;
return THREAD_TRIM.get() || RelOptUtil.countJoins(rootRel) < 2;
}

protected abstract void init(Class runtimeContextClass);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.calcite.util.DateString;
import org.apache.calcite.util.TimestampString;
import org.apache.calcite.util.TimestampWithTimeZoneString;
import org.apache.calcite.util.TryThreadLocal;
import org.apache.calcite.util.Util;

import com.google.common.annotations.VisibleForTesting;
Expand Down Expand Up @@ -137,7 +138,7 @@ private static int calendarUnitFor(TimeUnitRange timeUnitRange) {
* generate hundreds of ranges we'll later throw away. */
static ImmutableSortedSet<TimeUnitRange> extractTimeUnits(RexNode e) {
try (ExtractFinder finder = ExtractFinder.THREAD_INSTANCES.get()) {
assert requireNonNull(finder, "finder").timeUnits.isEmpty() && finder.opKinds.isEmpty()
assert finder.timeUnits.isEmpty() && finder.opKinds.isEmpty()
: "previous user did not clean up";
e.accept(finder);
return ImmutableSortedSet.copyOf(finder.timeUnits);
Expand Down Expand Up @@ -190,7 +191,7 @@ public FilterDateRangeRule(RelBuilderFactory relBuilderFactory) {
* If none of these, we cannot apply the rule. */
private static boolean containsRoundingExpression(Filter filter) {
try (ExtractFinder finder = ExtractFinder.THREAD_INSTANCES.get()) {
assert requireNonNull(finder, "finder").timeUnits.isEmpty() && finder.opKinds.isEmpty()
assert finder.timeUnits.isEmpty() && finder.opKinds.isEmpty()
: "previous user did not clean up";
filter.getCondition().accept(finder);
return finder.timeUnits.contains(TimeUnitRange.YEAR)
Expand Down Expand Up @@ -239,8 +240,8 @@ private static class ExtractFinder extends RexVisitorImpl<Void>
EnumSet.noneOf(TimeUnitRange.class);
private final Set<SqlKind> opKinds = EnumSet.noneOf(SqlKind.class);

private static final ThreadLocal<@Nullable ExtractFinder> THREAD_INSTANCES =
ThreadLocal.withInitial(ExtractFinder::new);
private static final TryThreadLocal<ExtractFinder> THREAD_INSTANCES =
TryThreadLocal.withInitial(ExtractFinder::new);

private ExtractFinder() {
super(true);
Expand Down
15 changes: 7 additions & 8 deletions core/src/main/java/org/apache/calcite/runtime/Hook.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,17 @@

import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.util.Holder;
import org.apache.calcite.util.TryThreadLocal;
import org.apache.calcite.util.Util;

import org.apiguardian.api.API;
import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;
import java.util.function.Function;

import static org.apache.calcite.linq4j.Nullness.castNonNull;

/**
* Collection of hooks that can be set by observers and are executed at various
* parts of the query preparation process.
Expand Down Expand Up @@ -111,8 +109,8 @@ public enum Hook {
new CopyOnWriteArrayList<>();

@SuppressWarnings("ImmutableEnumChecker")
private final ThreadLocal<@Nullable List<Consumer<Object>>> threadHandlers =
ThreadLocal.withInitial(ArrayList::new);
private final TryThreadLocal<List<Consumer<Object>>> threadHandlers =
TryThreadLocal.withInitial(ArrayList::new);

/** Adds a handler for this Hook.
*
Expand Down Expand Up @@ -156,7 +154,7 @@ private boolean remove(Consumer handler) {
/** Adds a handler for this thread. */
public <T> Closeable addThread(final Consumer<T> handler) {
//noinspection unchecked
castNonNull(threadHandlers.get()).add((Consumer<Object>) handler);
threadHandlers.get().add((Consumer<Object>) handler);
return () -> removeThread(handler);
}

Expand All @@ -182,8 +180,9 @@ private static <T, R> Consumer<T> functionConsumer(
}

/** Removes a thread handler from this Hook. */
@SuppressWarnings({"rawtypes", "UnusedReturnValue"})
private boolean removeThread(Consumer handler) {
return castNonNull(threadHandlers.get()).remove(handler);
return threadHandlers.get().remove(handler);
}

// CHECKSTYLE: IGNORE 1
Expand Down Expand Up @@ -211,7 +210,7 @@ public void run(Object arg) {
for (Consumer<Object> handler : handlers) {
handler.accept(arg);
}
for (Consumer<Object> handler : castNonNull(threadHandlers.get())) {
for (Consumer<Object> handler : threadHandlers.get()) {
handler.accept(arg);
}
}
Expand Down
15 changes: 5 additions & 10 deletions core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.apache.calcite.util.NumberUtil;
import org.apache.calcite.util.TimeWithTimeZoneString;
import org.apache.calcite.util.TimestampWithTimeZoneString;
import org.apache.calcite.util.TryThreadLocal;
import org.apache.calcite.util.Unsafe;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.format.FormatElement;
Expand Down Expand Up @@ -213,8 +214,8 @@ public class SqlFunctions {
* <p>This is a straw man of an implementation whose main goal is to prove
* that sequences can be parsed, validated and planned. A real application
* will want persistent values for sequences, shared among threads. */
private static final ThreadLocal<@Nullable Map<String, AtomicLong>> THREAD_SEQUENCES =
ThreadLocal.withInitial(HashMap::new);
private static final TryThreadLocal<Map<String, AtomicLong>> THREAD_SEQUENCES =
TryThreadLocal.withInitial(HashMap::new);

/** A byte string consisting of a single byte that is the ASCII space
* character (0x20). */
Expand Down Expand Up @@ -5597,14 +5598,8 @@ public static long sequenceNextValue(String key) {
}

private static AtomicLong getAtomicLong(String key) {
final Map<String, AtomicLong> map =
requireNonNull(THREAD_SEQUENCES.get(), "THREAD_SEQUENCES.get()");
AtomicLong atomic = map.get(key);
if (atomic == null) {
atomic = new AtomicLong();
map.put(key, atomic);
}
return atomic;
final Map<String, AtomicLong> map = THREAD_SEQUENCES.get();
return map.computeIfAbsent(key, key_ -> new AtomicLong());
}

/** Support the ARRAYS_OVERLAP function. */
Expand Down
30 changes: 17 additions & 13 deletions core/src/main/java/org/apache/calcite/runtime/XmlFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.calcite.runtime;

import org.apache.calcite.util.SimpleNamespaceContext;
import org.apache.calcite.util.TryThreadLocal;

import org.apache.commons.lang3.StringUtils;

Expand Down Expand Up @@ -66,8 +67,8 @@
*/
public class XmlFunctions {

private static final ThreadLocal<@Nullable XPathFactory> XPATH_FACTORY =
ThreadLocal.withInitial(() -> {
private static final TryThreadLocal<XPathFactory> XPATH_FACTORY =
TryThreadLocal.withInitial(() -> {
final XPathFactory xPathFactory = XPathFactory.newInstance();
try {
xPathFactory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);
Expand All @@ -76,8 +77,9 @@ public class XmlFunctions {
}
return xPathFactory;
});
private static final ThreadLocal<@Nullable TransformerFactory> TRANSFORMER_FACTORY =
ThreadLocal.withInitial(() -> {

private static final TryThreadLocal<TransformerFactory> TRANSFORMER_FACTORY =
TryThreadLocal.withInitial(() -> {
final TransformerFactory transformerFactory = TransformerFactory.newInstance();
transformerFactory.setErrorListener(new InternalErrorListener());
try {
Expand All @@ -87,8 +89,9 @@ public class XmlFunctions {
}
return transformerFactory;
});
private static final ThreadLocal<@Nullable DocumentBuilderFactory> DOCUMENT_BUILDER_FACTORY =
ThreadLocal.withInitial(() -> {

private static final TryThreadLocal<DocumentBuilderFactory> DOCUMENT_BUILDER_FACTORY =
TryThreadLocal.withInitial(() -> {
final DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance();
documentBuilderFactory.setXIncludeAware(false);
documentBuilderFactory.setExpandEntityReferences(false);
Expand Down Expand Up @@ -117,7 +120,8 @@ private XmlFunctions() {
}
try {
final Node documentNode = getDocumentNode(input);
XPathExpression xpathExpression = castNonNull(XPATH_FACTORY.get()).newXPath().compile(xpath);
XPathExpression xpathExpression =
XPATH_FACTORY.get().newXPath().compile(xpath);
try {
NodeList nodes = (NodeList) xpathExpression
.evaluate(documentNode, XPathConstants.NODESET);
Expand Down Expand Up @@ -145,8 +149,8 @@ private XmlFunctions() {
try {
final Source xsltSource = new StreamSource(new StringReader(xslt));
final Source xmlSource = new StreamSource(new StringReader(xml));
final Transformer transformer = castNonNull(TRANSFORMER_FACTORY.get())
.newTransformer(xsltSource);
final Transformer transformer =
TRANSFORMER_FACTORY.get().newTransformer(xsltSource);
final StringWriter writer = new StringWriter();
final StreamResult result = new StreamResult(writer);
transformer.setErrorListener(new InternalErrorListener());
Expand All @@ -169,7 +173,7 @@ private XmlFunctions() {
return null;
}
try {
XPath xPath = castNonNull(XPATH_FACTORY.get()).newXPath();
XPath xPath = XPATH_FACTORY.get().newXPath();

if (namespace != null) {
xPath.setNamespaceContext(extractNamespaceContext(namespace));
Expand Down Expand Up @@ -206,7 +210,7 @@ private XmlFunctions() {
return null;
}
try {
XPath xPath = castNonNull(XPATH_FACTORY.get()).newXPath();
XPath xPath = XPATH_FACTORY.get().newXPath();
if (namespace != null) {
xPath.setNamespaceContext(extractNamespaceContext(namespace));
}
Expand Down Expand Up @@ -247,7 +251,7 @@ private static SimpleNamespaceContext extractNamespaceContext(String namespace)

private static String convertNodeToString(Node node) throws TransformerException {
StringWriter writer = new StringWriter();
Transformer transformer = castNonNull(TRANSFORMER_FACTORY.get()).newTransformer();
Transformer transformer = TRANSFORMER_FACTORY.get().newTransformer();
transformer.setErrorListener(new InternalErrorListener());
transformer.setOutputProperty(OutputKeys.OMIT_XML_DECLARATION, "yes");
transformer.transform(new DOMSource(node), new StreamResult(writer));
Expand All @@ -257,7 +261,7 @@ private static String convertNodeToString(Node node) throws TransformerException
private static Node getDocumentNode(final String xml) {
try {
final DocumentBuilder documentBuilder =
castNonNull(DOCUMENT_BUILDER_FACTORY.get()).newDocumentBuilder();
DOCUMENT_BUILDER_FACTORY.get().newDocumentBuilder();
final InputSource inputSource = new InputSource(new StringReader(xml));
return documentBuilder.parse(inputSource);
} catch (final ParserConfigurationException | SAXException | IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.calcite.util.TimeWithTimeZoneString;
import org.apache.calcite.util.TimestampString;
import org.apache.calcite.util.TimestampWithTimeZoneString;
import org.apache.calcite.util.TryThreadLocal;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.trace.CalciteTrace;

Expand Down Expand Up @@ -1214,11 +1215,11 @@ private OldTokenSequenceImpl(List<@Nullable Object> list) {
/** Pre-initialized {@link DateFormat} objects, to be used within the current
* thread, because {@code DateFormat} is not thread-safe. */
private static class Format {
private static final ThreadLocal<@Nullable Format> PER_THREAD =
ThreadLocal.withInitial(Format::new);
private static final TryThreadLocal<Format> PER_THREAD =
TryThreadLocal.withInitial(Format::new);

private static Format get() {
return requireNonNull(PER_THREAD.get(), "PER_THREAD.get()");
return PER_THREAD.get();
}

final DateFormat timestamp =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
*/
package org.apache.calcite.sql.type;

import org.apache.calcite.util.TryThreadLocal;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -78,8 +78,7 @@ public class SqlTypeCoercionRule implements SqlTypeMappingRule {

private static final SqlTypeCoercionRule LENIENT_INSTANCE;

public static final ThreadLocal<@Nullable SqlTypeCoercionRule> THREAD_PROVIDERS =
ThreadLocal.withInitial(() -> SqlTypeCoercionRule.INSTANCE);
public static final TryThreadLocal<SqlTypeCoercionRule> THREAD_PROVIDERS;

//~ Instance fields --------------------------------------------------------

Expand Down Expand Up @@ -352,6 +351,7 @@ private SqlTypeCoercionRule(Map<SqlTypeName, ImmutableSet<SqlTypeName>> map) {
.build());

LENIENT_INSTANCE = new SqlTypeCoercionRule(coerceRules.map);
THREAD_PROVIDERS = TryThreadLocal.of(SqlTypeCoercionRule.INSTANCE);
}

//~ Methods ----------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,8 @@ protected SqlValidatorImpl(

if (config.conformance().allowLenientCoercion()) {
final SqlTypeCoercionRule rules =
requireNonNull(
config.typeCoercionRules() != null
? config.typeCoercionRules()
: SqlTypeCoercionRule.THREAD_PROVIDERS.get(),
"rules");
first(config.typeCoercionRules(),
SqlTypeCoercionRule.instance());

final ImmutableSet<SqlTypeName> arrayMapping =
ImmutableSet.<SqlTypeName>builder()
Expand Down
Loading

0 comments on commit 740f2ee

Please sign in to comment.