diff --git a/dsmr_parser/clients/protocol.py b/dsmr_parser/clients/protocol.py index e5e6a66..9b4536e 100644 --- a/dsmr_parser/clients/protocol.py +++ b/dsmr_parser/clients/protocol.py @@ -91,7 +91,7 @@ class DSMRProtocol(asyncio.Protocol): self.transport = transport self.log.debug('connected') self._active = False - if self._keep_alive_interval: + if self.loop and self._keep_alive_interval: self.loop.call_later(self._keep_alive_interval, self.keep_alive) def data_received(self, data): @@ -108,10 +108,12 @@ class DSMRProtocol(asyncio.Protocol): if self._active: self.log.debug('keep-alive checked') self._active = False - self.loop.call_later(self._keep_alive_interval, self.keep_alive) + if self.loop: + self.loop.call_later(self._keep_alive_interval, self.keep_alive) else: self.log.warning('keep-alive check failed') - self.transport.close() + if self.transport: + self.transport.close() def connection_lost(self, exc): """Stop when connection is lost.""" diff --git a/test/test_protocol.py b/test/test_protocol.py index 2fb14e0..c298d5c 100644 --- a/test/test_protocol.py +++ b/test/test_protocol.py @@ -5,7 +5,7 @@ import unittest from dsmr_parser import obis_references as obis from dsmr_parser import telegram_specifications from dsmr_parser.parsers import TelegramParser -from dsmr_parser.clients.protocol import DSMRProtocol +from dsmr_parser.clients.protocol import create_dsmr_protocol TELEGRAM_V2_2 = ( @@ -35,9 +35,10 @@ TELEGRAM_V2_2 = ( class ProtocolTest(unittest.TestCase): def setUp(self): - telegram_parser = TelegramParser(telegram_specifications.V2_2) - self.protocol = DSMRProtocol(None, telegram_parser, - telegram_callback=Mock()) + new_protocol, _ = create_dsmr_protocol('2.2', + telegram_callback=Mock(), + keep_alive_interval=1) + self.protocol = new_protocol() def test_complete_packet(self): """Protocol should assemble incoming lines into complete packet.""" @@ -52,3 +53,23 @@ class ProtocolTest(unittest.TestCase): assert float(telegram[obis.GAS_METER_READING].value) == 1.001 assert telegram[obis.GAS_METER_READING].unit == 'm3' + + def test_receive_packet(self): + """Protocol packet reception.""" + + mock_transport = Mock() + self.protocol.connection_made(mock_transport) + assert not self.protocol._active + + self.protocol.data_received(TELEGRAM_V2_2.encode('ascii')) + assert self.protocol._active + + # 1st call of keep_alive resets 'active' flag + self.protocol.keep_alive() + assert not self.protocol._active + + # 2nd call of keep_alive should close the transport + self.protocol.keep_alive() + assert mock_transport.close.called_once() + + self.protocol.connection_lost(None)