From 93f40dc7ddc91931b4a25439921f04368f36943b Mon Sep 17 00:00:00 2001 From: jpic Date: Fri, 12 Nov 2021 03:56:48 +0100 Subject: [PATCH] Add expect feature --- shlax/subprocess.py | 64 +++++++++++++++++++++------------------------ tests/test_proc.py | 21 ++++++++++++++- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/shlax/subprocess.py b/shlax/subprocess.py index f51fbd4..8fc3555 100644 --- a/shlax/subprocess.py +++ b/shlax/subprocess.py @@ -7,10 +7,10 @@ import sys from .colors import colors -class SubprocessProtocol(asyncio.SubprocessProtocol): - def __init__(self, proc): +class SubprocessProtocol(asyncio.subprocess.SubprocessStreamProtocol): + def __init__(self, proc, *args, **kwargs): self.proc = proc - self.output = bytearray() + super().__init__(*args, **kwargs) def pipe_data_received(self, fd, data): if fd == 1: @@ -18,8 +18,12 @@ class SubprocessProtocol(asyncio.SubprocessProtocol): elif fd == 2: self.proc.stderr(data) - def process_exited(self): - self.proc.exit_future.set_result(True) + if self.proc.expect_index < len(self.proc.expects): + expected = self.proc.expects[self.proc.expect_index] + if re.match(expected['regexp'], data): + self.stdin.write(expected['sendline']) + self.stdin.flush() + self.proc.expect_index += 1 class Subprocess: @@ -45,18 +49,21 @@ class Subprocess: def __init__( self, - *args, + cmd, quiet=None, prefix=None, regexps=None, + expects=None, write=None, flush=None, ): - self.args = args + self.cmd = cmd self.quiet = quiet if quiet is not None else False self.prefix = prefix self.write = write or sys.stdout.buffer.write self.flush = flush or sys.stdout.flush + self.expects = expects or [] + self.expect_index = 0 self.started = False self.waited = False self.out_raw = bytearray() @@ -76,32 +83,27 @@ class Subprocess: self.output( self.colors.bgray.encode() + b'+ ' - + ' '.join([ - arg.replace('\n', '\\n') - for arg in self.args - ]).encode() + + self.cmd.encode() + self.colors.reset.encode(), highlight=False ) - # Get a reference to the event loop as we plan to use - # low-level APIs. + # The following is a copy of what asyncio.create_subprocess_shell does + # except we inject our own SubprocessStreamProtocol subclass: it might + # need an update as new python releases come out. loop = asyncio.get_running_loop() - - self.exit_future = asyncio.Future(loop=loop) - - if len(self.args) == 1 and ' ' in self.args[0]: - args = ['sh', '-euc', self.args[0]] - else: - args = self.args - - # Create the subprocess controlled by DateProtocol; - # redirect the standard output into a pipe. - self.transport, self.protocol = await loop.subprocess_exec( - lambda: SubprocessProtocol(self), - *args, - stdin=None, + self.transport, self.protocol = await loop.subprocess_shell( + lambda: SubprocessProtocol( + self, + limit=asyncio.subprocess.streams._DEFAULT_LIMIT, + loop=loop, + ), + self.cmd, + stdin=asyncio.subprocess.PIPE if self.expects else sys.stdin, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) + self.proc = asyncio.subprocess.Process(self.transport, self.protocol, loop) self.started = True @@ -110,13 +112,7 @@ class Subprocess: await self.start() if not self.waited: - # Wait for the subprocess exit using the process_exited() - # method of the protocol. - await self.exit_future - - # Close the stdout pipe. - self.transport.close() - + await self.proc.communicate() self.waited = True return self diff --git a/tests/test_proc.py b/tests/test_proc.py index 0042d97..b2bbb63 100644 --- a/tests/test_proc.py +++ b/tests/test_proc.py @@ -8,7 +8,6 @@ from shlax import Proc @pytest.mark.parametrize( 'args', ( - ['sh', '-c', 'echo hi'], ['echo hi'], ['sh -c "echo hi"'], ) @@ -195,3 +194,23 @@ async def test_highlight_if_not_colored(): } ).wait() proc.write.assert_called_with(b'h\x1b[31mi\n') + + +@pytest.mark.asyncio +async def test_expect(): + proc = Proc( + 'echo "x?"; read x; echo x=$x; echo "z?"; read z; echo z=$z', + expects=[ + dict( + regexp=b'x?', + sendline=b'y\n', + ), + dict( + regexp=b'z?', + sendline=b'w\n', + ) + ], + quiet=True, + ) + await proc.wait() + assert proc.out == 'x?\nx=y\nz?\nz=w'