Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various test fixes for EMQ X #73

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion interoperability/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def connectionLost(self, cause):
logging.info("connectionLost %s", str(cause))

def publishArrived(self, topicName, payload, qos, retained, msgid):
if topicName.startswith("$SYS/"):
return True
logging.info("publishArrived %s %s %d %d %d", topicName, payload, qos, retained, msgid)
self.messages.append((topicName, payload, qos, retained, msgid))
return True
Expand Down Expand Up @@ -163,7 +165,7 @@ def test_retained_messages(self):
time.sleep(1)
aclient.disconnect()

assert len(callback.messages) == 3
self.assertEqual(len(callback.messages), 3)

# clear retained messages
callback.clear()
Expand Down Expand Up @@ -448,6 +450,8 @@ def test_unsubscribe(self):
host = a
elif o in ("-p", "--port"):
port = int(a)
sys.argv.remove("-p") if "-p" in sys.argv else sys.argv.remove("--port")
sys.argv.remove(a)
elif o in ("--iterations"):
iterations = int(a)
else:
Expand Down
18 changes: 12 additions & 6 deletions interoperability/client_test5.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def connectionLost(self, cause):
logging.info("connectionLost %s" % str(cause))

def publishArrived(self, topicName, payload, qos, retained, msgid, properties=None):
if topicName.startswith("$SYS/"):
return True
logging.info("publishArrived %s %s %d %s %d %s", topicName, payload, qos, retained, msgid, str(properties))
self.messages.append((topicName, payload, qos, retained, msgid, properties))
self.messagedicts.append({"topicname" : topicName, "payload" : payload,
Expand All @@ -69,10 +71,12 @@ def cleanRetained():
curclient = mqtt_client.Client("clean retained".encode("utf-8"))
curclient.registerCallback(callback)
curclient.connect(host=host, port=port, cleanstart=True)
curclient.subscribe(["#"], [MQTTV5.SubscribeOptions(0)])
# Not all brokers (EMQ X) allow us to subscribe to #, so subscribe to + and +/# to accomplish the same
curclient.subscribe(["+"], [MQTTV5.SubscribeOptions(0)])
curclient.subscribe(["+/#"], [MQTTV5.SubscribeOptions(0)])
time.sleep(2) # wait for all retained messages to arrive
for message in callback.messages:
logging.info("deleting retained message for topic", message[0])
logging.info("deleting retained message for topic %s", message[0])
curclient.publish(message[0], b"", 0, retained=True)
curclient.disconnect()
time.sleep(.1)
Expand Down Expand Up @@ -339,7 +343,8 @@ def test_subscribe_failure(self):
time.sleep(1)
# subscribeds is a list of (msgid, [qos])
logging.info(callback.subscribeds)
assert callback.subscribeds[0][1][0].value == 0x80, "return code should be 0x80 %s" % callback.subscribeds
self.assertEqual(callback.subscribeds[0][1][0].value, 0x80,
"return code should be 0x80 %s" % callback.subscribeds)
except:
traceback.print_exc()
succeeded = False
Expand Down Expand Up @@ -551,7 +556,7 @@ def test_subscribe_options(self):
aclient.subscribe([topics[0]], [MQTTV5.SubscribeOptions(2, noLocal=True)])
self.waitfor(callback.subscribeds, 1, 3)
bclient.subscribe([topics[0]], [MQTTV5.SubscribeOptions(2, noLocal=True)])
self.waitfor(callback.subscribeds, 1, 3)
self.waitfor(callback2.subscribeds, 1, 3)
aclient.publish(topics[0], b"noLocal test", 1, retained=False)

self.waitfor(callback2.messages, 1, 3)
Expand Down Expand Up @@ -651,6 +656,7 @@ def test_subscribe_identifiers(self):
sub_properties.clear()
sub_properties.SubscriptionIdentifier = 3
bclient.subscribe([topics[0]+"/#"], [MQTTV5.SubscribeOptions(2)], properties=sub_properties)
self.waitfor(callback2.subscribeds, 2, 3)

bclient.publish(topics[0], b"sub identifier test", 1, retained=False)

Expand All @@ -661,7 +667,7 @@ def test_subscribe_identifiers(self):

self.waitfor(callback2.messages, 1, 3)
self.assertEqual(len(callback2.messages), 1, callback2.messages)
expected_subsids = set([2, 3])
expected_subsids = {2, 3}
received_subsids = set(callback2.messages[0][5].SubscriptionIdentifier)
self.assertEqual(received_subsids, expected_subsids, received_subsids)
bclient.disconnect()
Expand All @@ -679,7 +685,7 @@ def test_request_response(self):
self.waitfor(callback.subscribeds, 1, 3)

bclient.subscribe([topics[0]], [MQTTV5.SubscribeOptions(2, noLocal=True)])
self.waitfor(callback.subscribeds, 1, 3)
self.waitfor(callback2.subscribeds, 1, 3)

publish_properties = MQTTV5.Properties(MQTTV5.PacketTypes.PUBLISH)
publish_properties.ResponseTopic = topics[0]
Expand Down
13 changes: 11 additions & 2 deletions interoperability/mqtt/formats/MQTTV5/MQTTV5.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

"""

import logging, struct
import logging, traceback

logger = logging.getLogger('MQTT broker')

Expand Down Expand Up @@ -57,6 +57,10 @@ class PacketTypes:
# Dummy packet type for properties use - will delay only applies to will
WILLMESSAGE = 99

@staticmethod
def fromInt(i):
return list(filter(lambda x: getattr(PacketTypes, x) == i, dir(PacketTypes)))[0]


class Packets(object):

Expand Down Expand Up @@ -107,7 +111,9 @@ def __getName__(self, packetType, identifier):
assert identifier in self.names.keys(), identifier
names = self.names[identifier]
namelist = [name for name in names.keys() if packetType in names[name]]
assert len(namelist) == 1
if not len(namelist) == 1:
raise ValueError("Reason code %s (%s) invalid for packet type %s" % (identifier, list(names.keys()),
PacketTypes.fromInt(packetType)))
return namelist[0]

def getId(self, name):
Expand Down Expand Up @@ -144,6 +150,9 @@ def json(self):
def pack(self):
return bytes([self.value])

def __repr__(self):
return str(self)

def __init__(self, packetType, aName="Success", identifier=-1):
self.packetType = packetType
self.names = {
Expand Down