Add expect feature
This commit is contained in:
parent
217777f60e
commit
93f40dc7dd
@ -7,10 +7,10 @@ import sys
|
|||||||
from .colors import colors
|
from .colors import colors
|
||||||
|
|
||||||
|
|
||||||
class SubprocessProtocol(asyncio.SubprocessProtocol):
|
class SubprocessProtocol(asyncio.subprocess.SubprocessStreamProtocol):
|
||||||
def __init__(self, proc):
|
def __init__(self, proc, *args, **kwargs):
|
||||||
self.proc = proc
|
self.proc = proc
|
||||||
self.output = bytearray()
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def pipe_data_received(self, fd, data):
|
def pipe_data_received(self, fd, data):
|
||||||
if fd == 1:
|
if fd == 1:
|
||||||
@ -18,8 +18,12 @@ class SubprocessProtocol(asyncio.SubprocessProtocol):
|
|||||||
elif fd == 2:
|
elif fd == 2:
|
||||||
self.proc.stderr(data)
|
self.proc.stderr(data)
|
||||||
|
|
||||||
def process_exited(self):
|
if self.proc.expect_index < len(self.proc.expects):
|
||||||
self.proc.exit_future.set_result(True)
|
expected = self.proc.expects[self.proc.expect_index]
|
||||||
|
if re.match(expected['regexp'], data):
|
||||||
|
self.stdin.write(expected['sendline'])
|
||||||
|
self.stdin.flush()
|
||||||
|
self.proc.expect_index += 1
|
||||||
|
|
||||||
|
|
||||||
class Subprocess:
|
class Subprocess:
|
||||||
@ -45,18 +49,21 @@ class Subprocess:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*args,
|
cmd,
|
||||||
quiet=None,
|
quiet=None,
|
||||||
prefix=None,
|
prefix=None,
|
||||||
regexps=None,
|
regexps=None,
|
||||||
|
expects=None,
|
||||||
write=None,
|
write=None,
|
||||||
flush=None,
|
flush=None,
|
||||||
):
|
):
|
||||||
self.args = args
|
self.cmd = cmd
|
||||||
self.quiet = quiet if quiet is not None else False
|
self.quiet = quiet if quiet is not None else False
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
self.write = write or sys.stdout.buffer.write
|
self.write = write or sys.stdout.buffer.write
|
||||||
self.flush = flush or sys.stdout.flush
|
self.flush = flush or sys.stdout.flush
|
||||||
|
self.expects = expects or []
|
||||||
|
self.expect_index = 0
|
||||||
self.started = False
|
self.started = False
|
||||||
self.waited = False
|
self.waited = False
|
||||||
self.out_raw = bytearray()
|
self.out_raw = bytearray()
|
||||||
@ -76,32 +83,27 @@ class Subprocess:
|
|||||||
self.output(
|
self.output(
|
||||||
self.colors.bgray.encode()
|
self.colors.bgray.encode()
|
||||||
+ b'+ '
|
+ b'+ '
|
||||||
+ ' '.join([
|
+ self.cmd.encode()
|
||||||
arg.replace('\n', '\\n')
|
|
||||||
for arg in self.args
|
|
||||||
]).encode()
|
|
||||||
+ self.colors.reset.encode(),
|
+ self.colors.reset.encode(),
|
||||||
highlight=False
|
highlight=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get a reference to the event loop as we plan to use
|
# The following is a copy of what asyncio.create_subprocess_shell does
|
||||||
# low-level APIs.
|
# except we inject our own SubprocessStreamProtocol subclass: it might
|
||||||
|
# need an update as new python releases come out.
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
self.transport, self.protocol = await loop.subprocess_shell(
|
||||||
self.exit_future = asyncio.Future(loop=loop)
|
lambda: SubprocessProtocol(
|
||||||
|
self,
|
||||||
if len(self.args) == 1 and ' ' in self.args[0]:
|
limit=asyncio.subprocess.streams._DEFAULT_LIMIT,
|
||||||
args = ['sh', '-euc', self.args[0]]
|
loop=loop,
|
||||||
else:
|
),
|
||||||
args = self.args
|
self.cmd,
|
||||||
|
stdin=asyncio.subprocess.PIPE if self.expects else sys.stdin,
|
||||||
# Create the subprocess controlled by DateProtocol;
|
stdout=asyncio.subprocess.PIPE,
|
||||||
# redirect the standard output into a pipe.
|
stderr=asyncio.subprocess.PIPE,
|
||||||
self.transport, self.protocol = await loop.subprocess_exec(
|
|
||||||
lambda: SubprocessProtocol(self),
|
|
||||||
*args,
|
|
||||||
stdin=None,
|
|
||||||
)
|
)
|
||||||
|
self.proc = asyncio.subprocess.Process(self.transport, self.protocol, loop)
|
||||||
|
|
||||||
self.started = True
|
self.started = True
|
||||||
|
|
||||||
@ -110,13 +112,7 @@ class Subprocess:
|
|||||||
await self.start()
|
await self.start()
|
||||||
|
|
||||||
if not self.waited:
|
if not self.waited:
|
||||||
# Wait for the subprocess exit using the process_exited()
|
await self.proc.communicate()
|
||||||
# method of the protocol.
|
|
||||||
await self.exit_future
|
|
||||||
|
|
||||||
# Close the stdout pipe.
|
|
||||||
self.transport.close()
|
|
||||||
|
|
||||||
self.waited = True
|
self.waited = True
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|||||||
@ -8,7 +8,6 @@ from shlax import Proc
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'args',
|
'args',
|
||||||
(
|
(
|
||||||
['sh', '-c', 'echo hi'],
|
|
||||||
['echo hi'],
|
['echo hi'],
|
||||||
['sh -c "echo hi"'],
|
['sh -c "echo hi"'],
|
||||||
)
|
)
|
||||||
@ -195,3 +194,23 @@ async def test_highlight_if_not_colored():
|
|||||||
}
|
}
|
||||||
).wait()
|
).wait()
|
||||||
proc.write.assert_called_with(b'h\x1b[31mi\n')
|
proc.write.assert_called_with(b'h\x1b[31mi\n')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_expect():
|
||||||
|
proc = Proc(
|
||||||
|
'echo "x?"; read x; echo x=$x; echo "z?"; read z; echo z=$z',
|
||||||
|
expects=[
|
||||||
|
dict(
|
||||||
|
regexp=b'x?',
|
||||||
|
sendline=b'y\n',
|
||||||
|
),
|
||||||
|
dict(
|
||||||
|
regexp=b'z?',
|
||||||
|
sendline=b'w\n',
|
||||||
|
)
|
||||||
|
],
|
||||||
|
quiet=True,
|
||||||
|
)
|
||||||
|
await proc.wait()
|
||||||
|
assert proc.out == 'x?\nx=y\nz?\nz=w'
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user