Add expect feature

This commit is contained in:
jpic 2021-11-12 03:56:48 +01:00
parent 217777f60e
commit 93f40dc7dd
2 changed files with 50 additions and 35 deletions

View File

@ -7,10 +7,10 @@ import sys
from .colors import colors
class SubprocessProtocol(asyncio.SubprocessProtocol):
def __init__(self, proc):
class SubprocessProtocol(asyncio.subprocess.SubprocessStreamProtocol):
def __init__(self, proc, *args, **kwargs):
self.proc = proc
self.output = bytearray()
super().__init__(*args, **kwargs)
def pipe_data_received(self, fd, data):
if fd == 1:
@ -18,8 +18,12 @@ class SubprocessProtocol(asyncio.SubprocessProtocol):
elif fd == 2:
self.proc.stderr(data)
def process_exited(self):
self.proc.exit_future.set_result(True)
if self.proc.expect_index < len(self.proc.expects):
expected = self.proc.expects[self.proc.expect_index]
if re.match(expected['regexp'], data):
self.stdin.write(expected['sendline'])
self.stdin.flush()
self.proc.expect_index += 1
class Subprocess:
@ -45,18 +49,21 @@ class Subprocess:
def __init__(
self,
*args,
cmd,
quiet=None,
prefix=None,
regexps=None,
expects=None,
write=None,
flush=None,
):
self.args = args
self.cmd = cmd
self.quiet = quiet if quiet is not None else False
self.prefix = prefix
self.write = write or sys.stdout.buffer.write
self.flush = flush or sys.stdout.flush
self.expects = expects or []
self.expect_index = 0
self.started = False
self.waited = False
self.out_raw = bytearray()
@ -76,32 +83,27 @@ class Subprocess:
self.output(
self.colors.bgray.encode()
+ b'+ '
+ ' '.join([
arg.replace('\n', '\\n')
for arg in self.args
]).encode()
+ self.cmd.encode()
+ self.colors.reset.encode(),
highlight=False
)
# Get a reference to the event loop as we plan to use
# low-level APIs.
# The following is a copy of what asyncio.create_subprocess_shell does
# except we inject our own SubprocessStreamProtocol subclass: it might
# need an update as new python releases come out.
loop = asyncio.get_running_loop()
self.exit_future = asyncio.Future(loop=loop)
if len(self.args) == 1 and ' ' in self.args[0]:
args = ['sh', '-euc', self.args[0]]
else:
args = self.args
# Create the subprocess controlled by DateProtocol;
# redirect the standard output into a pipe.
self.transport, self.protocol = await loop.subprocess_exec(
lambda: SubprocessProtocol(self),
*args,
stdin=None,
self.transport, self.protocol = await loop.subprocess_shell(
lambda: SubprocessProtocol(
self,
limit=asyncio.subprocess.streams._DEFAULT_LIMIT,
loop=loop,
),
self.cmd,
stdin=asyncio.subprocess.PIPE if self.expects else sys.stdin,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
self.proc = asyncio.subprocess.Process(self.transport, self.protocol, loop)
self.started = True
@ -110,13 +112,7 @@ class Subprocess:
await self.start()
if not self.waited:
# Wait for the subprocess exit using the process_exited()
# method of the protocol.
await self.exit_future
# Close the stdout pipe.
self.transport.close()
await self.proc.communicate()
self.waited = True
return self

View File

@ -8,7 +8,6 @@ from shlax import Proc
@pytest.mark.parametrize(
'args',
(
['sh', '-c', 'echo hi'],
['echo hi'],
['sh -c "echo hi"'],
)
@ -195,3 +194,23 @@ async def test_highlight_if_not_colored():
}
).wait()
proc.write.assert_called_with(b'h\x1b[31mi\n')
@pytest.mark.asyncio
async def test_expect():
proc = Proc(
'echo "x?"; read x; echo x=$x; echo "z?"; read z; echo z=$z',
expects=[
dict(
regexp=b'x?',
sendline=b'y\n',
),
dict(
regexp=b'z?',
sendline=b'w\n',
)
],
quiet=True,
)
await proc.wait()
assert proc.out == 'x?\nx=y\nz?\nz=w'