Add expect feature

This commit is contained in:
jpic 2021-11-12 03:56:48 +01:00
parent 217777f60e
commit 93f40dc7dd
2 changed files with 50 additions and 35 deletions

View File

@ -7,10 +7,10 @@ import sys
from .colors import colors from .colors import colors
class SubprocessProtocol(asyncio.SubprocessProtocol): class SubprocessProtocol(asyncio.subprocess.SubprocessStreamProtocol):
def __init__(self, proc): def __init__(self, proc, *args, **kwargs):
self.proc = proc self.proc = proc
self.output = bytearray() super().__init__(*args, **kwargs)
def pipe_data_received(self, fd, data): def pipe_data_received(self, fd, data):
if fd == 1: if fd == 1:
@ -18,8 +18,12 @@ class SubprocessProtocol(asyncio.SubprocessProtocol):
elif fd == 2: elif fd == 2:
self.proc.stderr(data) self.proc.stderr(data)
def process_exited(self): if self.proc.expect_index < len(self.proc.expects):
self.proc.exit_future.set_result(True) expected = self.proc.expects[self.proc.expect_index]
if re.match(expected['regexp'], data):
self.stdin.write(expected['sendline'])
self.stdin.flush()
self.proc.expect_index += 1
class Subprocess: class Subprocess:
@ -45,18 +49,21 @@ class Subprocess:
def __init__( def __init__(
self, self,
*args, cmd,
quiet=None, quiet=None,
prefix=None, prefix=None,
regexps=None, regexps=None,
expects=None,
write=None, write=None,
flush=None, flush=None,
): ):
self.args = args self.cmd = cmd
self.quiet = quiet if quiet is not None else False self.quiet = quiet if quiet is not None else False
self.prefix = prefix self.prefix = prefix
self.write = write or sys.stdout.buffer.write self.write = write or sys.stdout.buffer.write
self.flush = flush or sys.stdout.flush self.flush = flush or sys.stdout.flush
self.expects = expects or []
self.expect_index = 0
self.started = False self.started = False
self.waited = False self.waited = False
self.out_raw = bytearray() self.out_raw = bytearray()
@ -76,32 +83,27 @@ class Subprocess:
self.output( self.output(
self.colors.bgray.encode() self.colors.bgray.encode()
+ b'+ ' + b'+ '
+ ' '.join([ + self.cmd.encode()
arg.replace('\n', '\\n')
for arg in self.args
]).encode()
+ self.colors.reset.encode(), + self.colors.reset.encode(),
highlight=False highlight=False
) )
# Get a reference to the event loop as we plan to use # The following is a copy of what asyncio.create_subprocess_shell does
# low-level APIs. # except we inject our own SubprocessStreamProtocol subclass: it might
# need an update as new python releases come out.
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
self.transport, self.protocol = await loop.subprocess_shell(
self.exit_future = asyncio.Future(loop=loop) lambda: SubprocessProtocol(
self,
if len(self.args) == 1 and ' ' in self.args[0]: limit=asyncio.subprocess.streams._DEFAULT_LIMIT,
args = ['sh', '-euc', self.args[0]] loop=loop,
else: ),
args = self.args self.cmd,
stdin=asyncio.subprocess.PIPE if self.expects else sys.stdin,
# Create the subprocess controlled by DateProtocol; stdout=asyncio.subprocess.PIPE,
# redirect the standard output into a pipe. stderr=asyncio.subprocess.PIPE,
self.transport, self.protocol = await loop.subprocess_exec(
lambda: SubprocessProtocol(self),
*args,
stdin=None,
) )
self.proc = asyncio.subprocess.Process(self.transport, self.protocol, loop)
self.started = True self.started = True
@ -110,13 +112,7 @@ class Subprocess:
await self.start() await self.start()
if not self.waited: if not self.waited:
# Wait for the subprocess exit using the process_exited() await self.proc.communicate()
# method of the protocol.
await self.exit_future
# Close the stdout pipe.
self.transport.close()
self.waited = True self.waited = True
return self return self

View File

@ -8,7 +8,6 @@ from shlax import Proc
@pytest.mark.parametrize( @pytest.mark.parametrize(
'args', 'args',
( (
['sh', '-c', 'echo hi'],
['echo hi'], ['echo hi'],
['sh -c "echo hi"'], ['sh -c "echo hi"'],
) )
@ -195,3 +194,23 @@ async def test_highlight_if_not_colored():
} }
).wait() ).wait()
proc.write.assert_called_with(b'h\x1b[31mi\n') proc.write.assert_called_with(b'h\x1b[31mi\n')
@pytest.mark.asyncio
async def test_expect():
proc = Proc(
'echo "x?"; read x; echo x=$x; echo "z?"; read z; echo z=$z',
expects=[
dict(
regexp=b'x?',
sendline=b'y\n',
),
dict(
regexp=b'z?',
sendline=b'w\n',
)
],
quiet=True,
)
await proc.wait()
assert proc.out == 'x?\nx=y\nz?\nz=w'