Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support mTLS authentication for MoP #1414

Merged
merged 23 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ jobs:
run: ./scripts/retry.sh mvn -B -ntp test -Dtest=${{ matrix.test }} -DfailIfNoTests=false

- name: Upload jacoco artifact
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: ${{ matrix.test }}-jacoco-artifact
path: '**/*.exec'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.pulsar.broker.authentication.AuthenticationDataSource;

Expand All @@ -69,7 +70,8 @@ public class Connection {
@Getter
private final TopicSubscriptionManager topicSubscriptionManager;
@Getter
private final MqttConnectMessage connectMessage;
@Setter
private MqttConnectMessage connectMessage;
@Getter
private final ClientRestrictions clientRestrictions;
@Getter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public final class Constants {
public static final String AUTH_BASIC = "basic";
public static final String AUTH_TOKEN = "token";

public static final String AUTH_MTLS = "mTls";

public static final String ATTR_TOPIC_SUBS = "topicSubs";

public static final String MQTT_PROPERTIES = "MQTT_PROPERTIES_%d_";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,25 @@
package io.streamnative.pulsar.handlers.mqtt;

import static io.streamnative.pulsar.handlers.mqtt.Constants.AUTH_BASIC;
import static io.streamnative.pulsar.handlers.mqtt.Constants.AUTH_MTLS;
import static io.streamnative.pulsar.handlers.mqtt.Constants.AUTH_TOKEN;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
import io.netty.handler.codec.mqtt.MqttConnectPayload;
import io.streamnative.pulsar.handlers.mqtt.identitypool.AuthenticationProviderMTls;
import io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.naming.AuthenticationException;
import javax.net.ssl.SSLSession;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.pulsar.broker.authentication.AuthenticationDataCommand;
import org.apache.pulsar.broker.authentication.AuthenticationDataSource;
import org.apache.pulsar.broker.authentication.AuthenticationProvider;
import org.apache.pulsar.broker.authentication.AuthenticationService;
import org.apache.pulsar.broker.service.BrokerService;

/**
* MQTT authentication service.
Expand All @@ -42,8 +46,11 @@ public class MQTTAuthenticationService {
@Getter
private final Map<String, AuthenticationProvider> authenticationProviders;

public MQTTAuthenticationService(AuthenticationService authenticationService, List<String> authenticationMethods) {
this.authenticationService = authenticationService;
private final BrokerService brokerService;

public MQTTAuthenticationService(BrokerService brokerService, List<String> authenticationMethods) {
this.brokerService = brokerService;
this.authenticationService = brokerService.getAuthenticationService();
this.authenticationProviders = getAuthenticationProviders(authenticationMethods);
}

Expand All @@ -53,6 +60,14 @@ private Map<String, AuthenticationProvider> getAuthenticationProviders(List<Stri
final AuthenticationProvider authProvider = authenticationService.getAuthenticationProvider(method);
if (authProvider != null) {
providers.put(method, authProvider);
} else if (AUTH_MTLS.equalsIgnoreCase(method)) {
AuthenticationProviderMTls providerMTls = new AuthenticationProviderMTls();
try {
providerMTls.initialize(brokerService.pulsar().getLocalMetadataStore());
providers.put(method, providerMTls);
} catch (Exception e) {
log.error("Failed to initialize MQTT authentication method {} ", method, e);
}
} else {
log.error("MQTT authentication method {} is not enabled in Pulsar configuration!", method);
}
Expand All @@ -64,27 +79,32 @@ private Map<String, AuthenticationProvider> getAuthenticationProviders(List<Stri
return providers;
}

public AuthenticationResult authenticate(MqttConnectMessage connectMessage) {
public AuthenticationResult authenticate(boolean fromProxy,
SSLSession session, MqttConnectMessage connectMessage) {
String authMethod = MqttMessageUtils.getAuthMethod(connectMessage);
if (authMethod != null) {
byte[] authData = MqttMessageUtils.getAuthData(connectMessage);
if (authData == null) {
return AuthenticationResult.FAILED;
}
if (fromProxy && AUTH_MTLS.equalsIgnoreCase(authMethod)) {
return new AuthenticationResult(true, new String(authData),
new AuthenticationDataCommand(new String(authData), null, session));
}
return authenticate(connectMessage.payload().clientIdentifier(), authMethod,
new AuthenticationDataCommand(new String(authData)));
new AuthenticationDataCommand(new String(authData), null, session));
}
return authenticate(connectMessage.payload());
return authenticate(connectMessage.payload(), session);
}

public AuthenticationResult authenticate(MqttConnectPayload payload) {
public AuthenticationResult authenticate(MqttConnectPayload payload, SSLSession session) {
String userRole = null;
boolean authenticated = false;
AuthenticationDataSource authenticationDataSource = null;
for (Map.Entry<String, AuthenticationProvider> entry : authenticationProviders.entrySet()) {
String authMethod = entry.getKey();
try {
AuthenticationDataSource authData = getAuthData(authMethod, payload);
AuthenticationDataSource authData = getAuthData(authMethod, payload, session);
userRole = entry.getValue().authenticate(authData);
authenticated = true;
authenticationDataSource = authData;
Expand Down Expand Up @@ -116,12 +136,14 @@ public AuthenticationResult authenticate(String clientIdentifier,
return new AuthenticationResult(authenticated, userRole, command);
}

public AuthenticationDataSource getAuthData(String authMethod, MqttConnectPayload payload) {
public AuthenticationDataSource getAuthData(String authMethod, MqttConnectPayload payload, SSLSession session) {
switch (authMethod) {
case AUTH_BASIC:
return new AuthenticationDataCommand(payload.userName() + ":" + payload.password());
case AUTH_TOKEN:
return new AuthenticationDataCommand(payload.password());
case AUTH_MTLS:
return new AuthenticationDataCommand(null, null, session);
default:
throw new IllegalArgumentException(
String.format("Unsupported authentication method : %s!", authMethod));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ public class MQTTCommonConfiguration extends ServiceConfiguration {
)
private boolean mqttProxyTlsEnabled = false;

@FieldContext(
category = CATEGORY_MQTT_PROXY,
required = false,
doc = "Whether start mqtt protocol handler with proxy mtls"
coderzc marked this conversation as resolved.
Show resolved Hide resolved
)
private boolean mqttProxyMTlsAuthenticationEnabled = false;

@FieldContext(
category = CATEGORY_MQTT_PROXY,
required = false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public MQTTService(BrokerService brokerService, MQTTServerConfiguration serverCo
this.metricsProvider = new MQTTMetricsProvider(metricsCollector);
this.pulsarService.addPrometheusRawMetricsProvider(metricsProvider);
this.authenticationService = serverConfiguration.isMqttAuthenticationEnabled()
? new MQTTAuthenticationService(brokerService.getAuthenticationService(),
? new MQTTAuthenticationService(brokerService,
serverConfiguration.getMqttAuthenticationMethods()) : null;
this.connectionManager = new MQTTConnectionManager(pulsarService.getAdvertisedAddress());
this.subscriptionManager = new MQTTSubscriptionManager();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public CompletableFuture<Void> writeAndFlush(final MqttAdapterMessage adapterMsg
});
future.exceptionally(ex -> {
log.warn("[AdapterChannel][{}] Proxy write to broker {} failed."
+ " error message: {}", clientId, broker, ex.getMessage());
+ " adapterMsg message: {}", clientId, broker, adapterMsg, ex);
return null;
});
return future;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.streamnative.pulsar.handlers.mqtt.exception;

/**
* Internal server exception.
*/
public class MQTTAuthException extends Exception {

public MQTTAuthException() {
}

public MQTTAuthException(String message) {
super(message);
}

public MQTTAuthException(String message, Throwable cause) {
super(message, cause);
}

public MQTTAuthException(Throwable cause) {
super(cause);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
*/
package io.streamnative.pulsar.handlers.mqtt.identitypool;


import static io.streamnative.pulsar.handlers.mqtt.identitypool.ExpressionCompiler.DN;
import static io.streamnative.pulsar.handlers.mqtt.identitypool.ExpressionCompiler.DN_KEYS;
import static io.streamnative.pulsar.handlers.mqtt.identitypool.ExpressionCompiler.SAN;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
*/
package io.streamnative.pulsar.handlers.mqtt.proxy;

import static io.streamnative.pulsar.handlers.mqtt.Constants.AUTH_MTLS;
import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.createMqttConnectMessage;
import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.createMqttPublishMessage;
import static io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils.createMqttSubscribeMessage;
import com.google.common.collect.Lists;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.mqtt.MqttConnectMessage;
Expand Down Expand Up @@ -64,17 +68,19 @@
import org.apache.pulsar.common.naming.TopicName;
import org.apache.pulsar.common.util.Codec;
import org.apache.pulsar.common.util.FutureUtil;

/**
* Proxy inbound handler is the bridge between proxy and MoP.
*/
@Slf4j
public class MQTTProxyProtocolMethodProcessor extends AbstractCommonProtocolMethodProcessor {

private final PulsarService pulsarService;

@Getter
private Connection connection;
private final LookupHandler lookupHandler;
private final MQTTProxyConfiguration proxyConfig;
private final PulsarService pulsarService;
private final Map<String, CompletableFuture<AdapterChannel>> topicBrokers;
private final Map<InetSocketAddress, AdapterChannel> adapterChannels;
@Getter
Expand All @@ -86,6 +92,7 @@ public class MQTTProxyProtocolMethodProcessor extends AbstractCommonProtocolMeth
private final MQTTConnectionManager connectionManager;
private final SystemEventService eventService;
private final MQTTProxyAdapter proxyAdapter;

private final AtomicBoolean isDisconnected = new AtomicBoolean(false);
private final AutoSubscribeHandler autoSubscribeHandler;

Expand All @@ -95,8 +102,9 @@ public class MQTTProxyProtocolMethodProcessor extends AbstractCommonProtocolMeth

public MQTTProxyProtocolMethodProcessor(MQTTProxyService proxyService, ChannelHandlerContext ctx) {
super(proxyService.getAuthenticationService(),
proxyService.getProxyConfig().isMqttAuthenticationEnabled(), ctx);
this.pulsarService = proxyService.getPulsarService();
proxyService.getProxyConfig().isMqttAuthenticationEnabled(),
ctx);
pulsarService = proxyService.getPulsarService();
this.lookupHandler = proxyService.getLookupHandler();
this.proxyConfig = proxyService.getProxyConfig();
this.connectionManager = proxyService.getConnectionManager();
Expand All @@ -115,7 +123,7 @@ public MQTTProxyProtocolMethodProcessor(MQTTProxyService proxyService, ChannelHa
@Override
public void doProcessConnect(MqttAdapterMessage adapter, String userRole,
AuthenticationDataSource authData, ClientRestrictions clientRestrictions) {
final MqttConnectMessage msg = (MqttConnectMessage) adapter.getMqttMessage();
MqttConnectMessage msg = (MqttConnectMessage) adapter.getMqttMessage();
final ServerRestrictions serverRestrictions = ServerRestrictions.builder()
.receiveMaximum(proxyConfig.getReceiveMaximum())
.maximumPacketSize(proxyConfig.getMqttMessageMaxLength())
Expand All @@ -133,6 +141,12 @@ public void doProcessConnect(MqttAdapterMessage adapter, String userRole,
.processor(this)
.build();
connection.sendConnAck();
if (proxyConfig.isMqttProxyMTlsAuthenticationEnabled()) {
MqttConnectMessage connectMessage = createMqttConnectMessage(msg, AUTH_MTLS, userRole);
msg = connectMessage;
connection.setConnectMessage(msg);
}

ConnectEvent connectEvent = ConnectEvent.builder()
.clientId(connection.getClientId())
.address(pulsarService.getAdvertisedAddress())
Expand All @@ -152,6 +166,10 @@ public void processPublish(MqttAdapterMessage adapter) {
proxyConfig.getDefaultTenant(), proxyConfig.getDefaultNamespace(),
TopicDomain.getEnum(proxyConfig.getDefaultTopicDomain()));
adapter.setClientId(connection.getClientId());
if (proxyConfig.isMqttProxyMTlsAuthenticationEnabled()) {
MqttPublishMessage mqttMessage = createMqttPublishMessage(msg, AUTH_MTLS, connection.getUserRole());
adapter.setMqttMessage(mqttMessage);
}
startPublish()
.thenCompose(__ -> writeToBroker(pulsarTopicName, adapter))
.whenComplete((unused, ex) -> {
Expand Down Expand Up @@ -282,6 +300,10 @@ public void processSubscribe(final MqttAdapterMessage adapter) {
log.debug("[Proxy Subscribe] [{}] msg: {}", clientId, msg);
}
registerTopicListener(adapter);
if (proxyConfig.isMqttProxyMTlsAuthenticationEnabled()) {
MqttSubscribeMessage mqttMessage = createMqttSubscribeMessage(msg, AUTH_MTLS, connection.getUserRole());
adapter.setMqttMessage(mqttMessage);
}
doSubscribe(adapter, false)
.exceptionally(ex -> {
Throwable realCause = FutureUtil.unwrapCompletionException(ex);
Expand Down Expand Up @@ -447,8 +469,10 @@ private CompletableFuture<AdapterChannel> connectToBroker(final String topic) {
key -> lookupHandler.findBroker(TopicName.get(topic)).thenApply(mqttBroker ->
adapterChannels.computeIfAbsent(mqttBroker, key1 -> {
AdapterChannel adapterChannel = proxyAdapter.getAdapterChannel(mqttBroker);
final MqttConnectMessage connectMessage = connection.getConnectMessage();

adapterChannel.writeAndFlush(new MqttAdapterMessage(connection.getClientId(),
connection.getConnectMessage()));
connectMessage));
return adapterChannel;
})
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public void start() throws MQTTProxyException {
throw new MQTTProxyException(e);
}

if (proxyConfig.isMqttProxyTlsEnabled()) {
if (proxyConfig.isMqttProxyTlsEnabled() || proxyConfig.isMqttProxyMTlsAuthenticationEnabled()) {
ServerBootstrap tlsBootstrap = serverBootstrap.clone();
tlsBootstrap.childHandler(new MQTTProxyChannelInitializer(
this, proxyConfig, true, sslContextRefresher));
Expand Down Expand Up @@ -148,7 +148,6 @@ public void start() throws MQTTProxyException {
throw new MQTTProxyException(e);
}
}

this.lookupHandler = new PulsarServiceLookupHandler(pulsarService, proxyConfig);
this.eventService.start();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
import io.netty.handler.codec.mqtt.MqttMessageBuilders;
import io.netty.handler.codec.mqtt.MqttProperties;
import io.netty.handler.codec.mqtt.MqttReasonCodeAndPropertiesVariableHeader;
import io.netty.handler.ssl.SslHandler;
import io.streamnative.pulsar.handlers.mqtt.Connection;
import io.streamnative.pulsar.handlers.mqtt.MQTTAuthenticationService;
import io.streamnative.pulsar.handlers.mqtt.ProtocolMethodProcessor;
import io.streamnative.pulsar.handlers.mqtt.adapter.MqttAdapterMessage;
import io.streamnative.pulsar.handlers.mqtt.exception.MQTTAuthException;
import io.streamnative.pulsar.handlers.mqtt.exception.restrictions.InvalidReceiveMaximumException;
import io.streamnative.pulsar.handlers.mqtt.messages.MqttPropertyUtils;
import io.streamnative.pulsar.handlers.mqtt.messages.ack.MqttConnectAck;
Expand All @@ -33,6 +35,7 @@
import io.streamnative.pulsar.handlers.mqtt.utils.MqttMessageUtils;
import io.streamnative.pulsar.handlers.mqtt.utils.MqttUtils;
import io.streamnative.pulsar.handlers.mqtt.utils.NettyUtils;
import javax.net.ssl.SSLSession;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -123,8 +126,10 @@ public void processConnect(MqttAdapterMessage adapter) {
clientId, username);
}
} else {
MQTTAuthenticationService.AuthenticationResult authResult = authenticationService
.authenticate(connectMessage);
MQTTAuthenticationService.AuthenticationResult authResult;
SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
SSLSession session = (sslHandler != null) ? sslHandler.engine().getSession() : null;
authResult = authenticationService.authenticate(adapter.fromProxy(), session, connectMessage);
if (authResult.isFailed()) {
MqttMessage mqttMessage = MqttConnectAck.errorBuilder().authFail(protocolVersion);
log.error("[CONNECT] Invalid or incorrect authentication. CId={}, username={}", clientId, username);
Expand Down Expand Up @@ -157,6 +162,10 @@ public void processConnect(MqttAdapterMessage adapter) {
}
}

protected MQTTAuthenticationService.AuthenticationResult mtlsAuth(boolean fromProxy) throws MQTTAuthException {
return MQTTAuthenticationService.AuthenticationResult.FAILED;
}

Comment on lines +165 to +168
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unless code

@Override
public void processPubAck(MqttAdapterMessage msg) {
if (log.isDebugEnabled()) {
Expand Down
Loading
Loading