123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596 |
- import asyncio
- from contextlib import contextmanager
- import os
- import socket
- from tempfile import TemporaryDirectory
- import avocado
- from qemu.qmp import ConnectError, Runstate
- from qemu.qmp.protocol import AsyncProtocol, StateError
- from qemu.qmp.util import asyncio_run, create_task
class NullProtocol(AsyncProtocol[None]):
    """
    A test mockup of an AsyncProtocol implementation.

    The ``fake_session`` instance variable, when set, selects a code
    path that bypasses the actual connection logic while still allowing
    the reader/writers to start.

    The message type is None, so nothing would otherwise stop the reader
    from yielding None incessantly.  An asyncio.Event named
    ``trigger_input`` holds it back instead; poke that event to simulate
    one incoming message.

    ``send_msg`` exists for symmetry with ``_do_recv``: it "sends" a
    Null message.

    ``simulate_disconnect`` triggers a bottom half disconnect without
    injecting any real errors into the reader/writer loops; in essence
    it performs exactly half of what disconnect() normally does.
    """

    def __init__(self, name=None):
        self.fake_session = False
        # Declared only; the Event itself is created in
        # _establish_session().
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        """Create the input gate, then defer to the parent class."""
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_start_server(self, address, ssl=None):
        """Start a server, or only pretend to when fake_session is set."""
        if not self.fake_session:
            await super()._do_start_server(address, ssl)
            return

        self._accepted = asyncio.Event()
        self._set_state(Runstate.CONNECTING)
        await asyncio.sleep(0)

    async def _do_accept(self):
        """Accept a peer, or only pretend to when fake_session is set."""
        if not self.fake_session:
            await super()._do_accept()
            return

        self._accepted = None

    async def _do_connect(self, address, ssl=None):
        """Connect, or only pretend to when fake_session is set."""
        if not self.fake_session:
            await super()._do_connect(address, ssl)
            return

        self._set_state(Runstate.CONNECTING)
        await asyncio.sleep(0)

    async def _do_recv(self) -> None:
        """Block until trigger_input is set, then re-arm it."""
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        """Discard the message; there is nothing to write."""

    async def send_msg(self) -> None:
        """Place a Null message on the outgoing queue."""
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulates a bottom-half disconnect.

        A disconnection is scheduled, but this method does not wait for
        it to finish.  The point is to leave the loop in the
        DISCONNECTING state without fully quiescing it back to IDLE.
        AsyncProtocol normally cannot be coaxed into this on purpose,
        but it is similar to what happens when the reader/writer hits
        an unhandled Exception.

        Under normal circumstances the library design requires you to
        await disconnect(), which awaits the disconnect task and
        returns bottom half errors as a pre-condition to letting the
        loop go back to IDLE.
        """
        self._schedule_disconnect()
class LineProtocol(AsyncProtocol[str]):
    """
    An AsyncProtocol implementation whose messages are strings.

    Outgoing messages are encoded and written with a trailing newline;
    incoming messages are read with ``_readline()`` and decoded.  Every
    received message is also kept, in order, in ``rx_history``.
    """

    def __init__(self, name=None):
        super().__init__(name)
        self.rx_history = []

    async def _do_recv(self) -> str:
        """Read one line, record it in rx_history and return it decoded."""
        line = (await self._readline()).decode()
        self.rx_history.append(line)
        return line

    def _do_send(self, msg: str) -> None:
        """Write ``msg`` to the stream writer, followed by a newline."""
        assert self._writer is not None
        payload = msg.encode() + b'\n'
        self._writer.write(payload)

    async def send_msg(self, msg: str) -> None:
        """Place ``msg`` on the outgoing queue."""
        await self._outgoing.put(msg)
def run_as_task(coro, allow_cancellation=False):
    """
    Schedule the given coroutine as a task and return that task.

    When ``allow_cancellation`` is true, a CancelledError raised while
    awaiting the coroutine is absorbed, so awaiting the returned task
    after cancelling it completes quietly.  Otherwise the
    CancelledError propagates as usual.
    """
    async def _guarded():
        try:
            await coro
        except asyncio.CancelledError:
            if not allow_cancellation:
                raise

    task = create_task(_guarded())
    return task
@contextmanager
def jammed_socket():
    """
    Opens up a random unused TCP port on localhost, then jams it.

    A listening socket is bound to 127.0.0.1 on a kernel-chosen port
    and two client connections are made to it.  The listener never
    calls accept(), so those connections are left un-accepted.

    Yields:
        The listener's address as returned by ``getsockname()``.

    Every socket created here is closed when the context exits,
    including when setup fails partway through.
    """
    socks = []
    try:
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Register each socket the moment it exists, so that a failure
        # in any later setup call (bind/listen/connect) cannot leak it.
        socks.append(listener)
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        address = listener.getsockname()

        # I don't *fully* understand why, but it takes *two* un-accepted
        # connections to start jamming the socket.
        for _ in range(2):
            client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            socks.append(client)
            client.connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()
class Smoke(avocado.Test):
    """
    Checks on a NullProtocol that has only been constructed: its repr,
    runstate, name and logger name.
    """

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        expected = "<NullProtocol runstate=IDLE>"
        self.assertEqual(repr(self.proto), expected)

    def testRunstate(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    def testDefaultName(self):
        self.assertEqual(self.proto.name, None)

    def testLogger(self):
        self.assertEqual(self.proto.logger.name, 'qemu.qmp.protocol')

    def testName(self):
        # Replace the unnamed instance from setUp() with a named one.
        self.proto = NullProtocol('Steve')
        named = self.proto

        self.assertEqual(named.name, 'Steve')
        self.assertEqual(named.logger.name, 'qemu.qmp.protocol.Steve')
        self.assertEqual(
            repr(named),
            "<NullProtocol name='Steve' runstate=IDLE>"
        )
class TestBase(avocado.Test):
    """
    Shared scaffolding for the asynchronous protocol tests.

    Each test gets a NullProtocol named after the test class, and the
    protocol is asserted to be IDLE both before and after the test.
    Subclasses may override ``_asyncSetUp`` / ``_asyncTearDown``; these
    run around any test method decorated with ``async_test``.
    """

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        # Set by _watch_runstates(); awaited in _asyncTearDown().
        self.runstate_watcher = None

    def tearDown(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        """Asynchronous setup hook; does nothing by default."""

    async def _asyncTearDown(self):
        """Wait for the runstate watcher, if one was launched."""
        if self.runstate_watcher:
            await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; adds SetUp and TearDown to async tests.

        The returned coroutine function enables debug mode on the
        event loop, then awaits ``_asyncSetUp()``, the test itself and
        ``_asyncTearDown()`` in that order.  Teardown only runs when
        the test body finishes without raising.
        """
        # Imported locally so this block does not depend on the
        # module's import list.
        import functools

        # Preserve the test's __name__ and docstring on the wrapper.
        @functools.wraps(async_test_method)
        async def _wrapper(self, *args, **kwargs):
            loop = asyncio.get_event_loop()
            loop.set_debug(True)

            await self._asyncSetUp()
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # The states we expect a "bad" connect/accept attempt to transition through
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # The states we expect a "good" session to transition through
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        This launches a task alongside (most) tests below to confirm that
        the sequence of runstate changes that occur is exactly as
        anticipated.

        The task is stored in ``self.runstate_watcher``.
        """
        async def _watcher():
            for state in states:
                new_state = await self.proto.runstate_changed()
                self.assertEqual(
                    new_state,
                    state,
                    msg=f"Expected state '{state.name}'",
                )

        self.runstate_watcher = create_task(_watcher())
        # Kick the loop and force the task to block on the event.
        await asyncio.sleep(0)
class State(TestBase):
    """
    Runstate tests that make no connection attempt.
    """

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """
        Test calling disconnect() while already disconnected.
        """
        expected = (Runstate.DISCONNECTING, Runstate.IDLE)
        await self._watch_runstates(*expected)
        await self.proto.disconnect()
class Connect(TestBase):
    """
    Tests primarily related to calling Connect().

    Subclasses can re-run the same tests against another entry point
    by overriding ``_bad_connection`` and ``_hanging_connection``.
    """

    async def _bad_connection(self, family: str):
        """Call connect() on the fixed bad target for ``family``."""
        assert family in ('INET', 'UNIX')

        if family == 'UNIX':
            await self.proto.connect('/dev/null')
        elif family == 'INET':
            await self.proto.connect(('127.0.0.1', 0))

    async def _hanging_connection(self):
        """Call connect() on the address of a jammed socket."""
        with jammed_socket() as jammed_addr:
            await self.proto.connect(jammed_addr)

    async def _bad_connection_test(self, family: str):
        """
        Check that a bad connection raises ConnectError wrapping an
        OSError, while the runstate walks BAD_CONNECTION_STATES.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as caught:
            await self._bad_connection(family)

        err = caught.exception
        self.assertIsInstance(err.exc, OSError)
        self.assertEqual(
            err.error_message,
            "Failed to establish connection"
        )

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # Note that accept() cannot be cancelled outright, as it isn't a task.
        # However, we can wrap it in a task and cancel *that*.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        attempt = run_as_task(
            self._hanging_connection(),
            allow_cancellation=True,
        )

        first_state = await self.proto.runstate_changed()
        self.assertEqual(first_state, Runstate.CONNECTING)

        # This is insider baseball, but the connection attempt has
        # yielded *just* before the actual connection attempt, so kick
        # the loop to make sure it's truly wedged.
        await asyncio.sleep(0)

        attempt.cancel()
        await attempt

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        attempt = run_as_task(self._hanging_connection())

        # More insider baseball: to improve the speed of this test while
        # guaranteeing that the connection even gets a chance to start,
        # verify that the connection hangs *first*, then await the
        # result of the task with a nearly-zero timeout.

        first_state = await self.proto.runstate_changed()
        self.assertEqual(first_state, Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(attempt, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        attempt = run_as_task(
            self._hanging_connection(),
            allow_cancellation=True,
        )

        first_state = await self.proto.runstate_changed()
        self.assertEqual(first_state, Runstate.CONNECTING)

        # A second attempt while the first is still pending must be
        # refused with a StateError.
        with self.assertRaises(StateError) as caught:
            await self._bad_connection('UNIX')

        err = caught.exception
        self.assertEqual(
            err.error_message,
            "NullProtocol is currently connecting."
        )
        self.assertEqual(err.state, Runstate.CONNECTING)
        self.assertEqual(err.required, Runstate.IDLE)

        attempt.cancel()
        await attempt

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        attempt = run_as_task(
            self._hanging_connection(),
            allow_cancellation=True,
        )

        # Kick the loop to coerce the state change
        await asyncio.sleep(0)
        assert self.proto.runstate == Runstate.CONNECTING

        # We already missed the transition to CONNECTING
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        attempt.cancel()
        await attempt
class Accept(Connect):
    """
    All of the same tests as Connect, but using the accept() interface.

    Only the two connection helpers are overridden; the test methods
    themselves are inherited.
    """

    async def _bad_connection(self, family: str):
        """Call start_server_and_accept() on the bad target for ``family``."""
        assert family in ('INET', 'UNIX')

        if family == 'UNIX':
            await self.proto.start_server_and_accept('/dev/null')
        elif family == 'INET':
            await self.proto.start_server_and_accept(('example.com', 1))

    async def _hanging_connection(self):
        """
        Call start_server_and_accept() on a socket path inside a fresh
        temporary directory; no client is pointed at it here.
        """
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock_name = type(self.proto).__name__ + ".sock"
            sock = os.path.join(tmpdir, sock_name)
            await self.proto.start_server_and_accept(sock)
class FakeSession(TestBase):
    """
    Tests that run NullProtocol with ``fake_session`` enabled.

    Every test is expected to walk GOOD_CONNECTION_STATES; the final
    disconnect() is issued by ``_asyncTearDown``.
    """

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        await super()._asyncTearDown()

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Test the full state lifecycle (via connect) with a no-op session."""
        await self.proto.connect('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Test the full state lifecycle (via accept) with a no-op session."""
        await self.proto.start_server_and_accept('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Test receiving a fake/null message."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        logname = self.proto.logger.name
        with self.assertLogs(logname, level='DEBUG') as captured:
            trigger = self.proto.trigger_input
            trigger.set()
            trigger.clear()
            await asyncio.sleep(0)  # Kick reader.

        expected = [f"DEBUG:{logname}:<-- None"]
        self.assertEqual(captured.output, expected)

    @TestBase.async_test
    async def testFakeSend(self):
        """Test sending a fake/null message."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        logname = self.proto.logger.name
        with self.assertLogs(logname, level='DEBUG') as captured:
            # Cheat: Send a Null message to nobody.
            await self.proto.send_msg()
            # Kick writer; awaiting on a queue.put isn't sufficient to yield.
            await asyncio.sleep(0)

        expected = [f"DEBUG:{logname}:--> None"]
        self.assertEqual(captured.output, expected)

    async def _prod_session_api(
            self,
            current_state: Runstate,
            error_message: str,
            accept: bool = True
    ):
        """
        Call accept (or connect, when ``accept`` is false) and check
        that it is refused with a StateError carrying the given
        message, the given current state, and IDLE as required state.
        """
        with self.assertRaises(StateError) as caught:
            if not accept:
                await self.proto.connect('/not/a/real/path')
            else:
                await self.proto.start_server_and_accept('/not/a/real/path')

        err = caught.exception
        self.assertEqual(err.error_message, error_message)
        self.assertEqual(err.state, current_state)
        self.assertEqual(err.required, Runstate.IDLE)

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """Test that accept() cannot be called when Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """Test that connect() cannot be called when Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=False,
        )

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=False,
        )
class SimpleSession(TestBase):
    """
    Tests that pair the NullProtocol client from TestBase with a
    LineProtocol server.
    """

    def setUp(self):
        super().setUp()
        self.server = LineProtocol(type(self).__name__ + '-server')

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        # An EOFError from the server side's disconnect is tolerated.
        try:
            await self.server.disconnect()
        except EOFError:
            pass
        await super()._asyncTearDown()

    @TestBase.async_test
    async def testSmoke(self):
        """Start the server on a temporary socket and connect to it."""
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock_name = type(self.proto).__name__ + ".sock"
            sock = os.path.join(tmpdir, sock_name)

            # NOTE(review): this task is never awaited here; presumably
            # the variable only keeps a reference to it -- confirm.
            server_task = create_task(
                self.server.start_server_and_accept(sock)
            )

            # give the server a chance to start listening [...]
            await asyncio.sleep(0)
            await self.proto.connect(sock)
|