我仍然没有找到任何可以直接使用的代码,所以写了下面这个版本。我仍然希望能找到一个实现这类功能的开源模块。
"""Remote control a machine given an ssh client on Linux"""
from pwd import getpwuid
from socket import socket, AF_UNIX, SOCK_STREAM, error
from os import getuid, write, close, unlink, read
import select, errno
from subprocess import PIPE, Popen, check_call
from time import time, sleep
from tempfile import mkstemp
def whoami():
    """Return the login name of the user that owns this process."""
    # Look up the password-database entry for our real UID, then pick
    # the account name out of it.
    passwd_entry = getpwuid(getuid())
    return passwd_entry.pw_name
def arg_escape(text):
    """Backslash-escape parentheses so a shell does not interpret them.

    Only '(' and ')' are escaped; every other character is passed
    through unchanged.
    """
    pieces = []
    for char in text:
        if char in '()':
            pieces.append('\\' + char)
        else:
            pieces.append(char)
    return ''.join(pieces)
def try_open_socket(socket_name):
    """Report whether a Unix stream socket at ``socket_name`` accepts us.

    Args:
        socket_name: filesystem path of the Unix domain socket to probe.

    Returns:
        True if a connection could be established, False if connecting
        raised a socket error (missing path, nobody listening, ...).
    """
    sock = socket(AF_UNIX, SOCK_STREAM)
    try:
        sock.connect(socket_name)
    except error:
        return False
    else:
        return True
    finally:
        # This is only a probe: always release the descriptor, otherwise
        # every call (callers poll this in a loop) leaks one socket.
        sock.close()
class ProcessTimeoutError(Exception):
    """Raised when a process did not finish within its allotted time."""
class ConnectionTimeoutError(Exception):
    """Raised when a connection could not be made in the allotted time."""
    pass
class CalledProcessError(Exception):
    """Raised when a process expected to exit cleanly returned non-zero."""
    pass
def local_run_with_timeout(command, timeout=60, check=True):
    """Run a command with a timeout after which it will be SIGKILLed.

    Based on the standard library subprocess.Popen._communicate_with_poll.

    Args:
        command: argument list handed to Popen (shell=False).
        timeout: maximum run time in seconds.
        check: when true, a non-zero exit status raises CalledProcessError.

    Returns:
        A tuple (returncode, stdout, stderr); the two outputs are the raw
        byte strings read from the child's pipes.

    Raises:
        ProcessTimeoutError: the command was still producing output or
            holding its pipes open after ``timeout`` seconds. Arguments are
            (command, timeout, stdout so far, stderr so far).
        CalledProcessError: ``check`` is set and the exit status is not 0.
            Arguments are (returncode, elapsed seconds, command, stdout,
            stderr).
    """
    process = Popen(command, shell=False, stdout=PIPE, stderr=PIPE)
    poller = select.poll()
    start = time()
    # Remember the descriptors up front: fileno() is unusable once a
    # stream has been closed, but we still need the keys afterwards.
    out_fd = process.stdout.fileno()
    err_fd = process.stderr.fileno()
    open_streams = {out_fd: process.stdout, err_fd: process.stderr}
    # Collect chunks in lists and join once; repeated string
    # concatenation is quadratic for chatty commands.
    chunks = {out_fd: [], err_fd: []}
    for fileno in open_streams:
        poller.register(fileno, select.POLLIN | select.POLLPRI)

    def captured(fileno):
        """Return everything read so far from the given descriptor."""
        return b''.join(chunks[fileno])

    while open_streams:
        elapsed = time() - start
        if elapsed > timeout:
            process.kill()
            # Close our pipe ends and reap the child so that a timeout
            # leaves neither leaked descriptors nor a zombie behind.
            for stream in open_streams.values():
                stream.close()
            process.wait()
            raise ProcessTimeoutError(command, timeout,
                                      captured(out_fd), captured(err_fd))
        try:
            # poll() expects milliseconds, the timeout is in seconds.
            # Never pass 0 here, that would turn the loop into a busy spin.
            ready = poller.poll(max(1, int((timeout - elapsed) * 1000)))
        except select.error as exc:
            if exc.args[0] == errno.EINTR:
                continue
            raise
        for fileno, mode in ready:
            if mode & (select.POLLIN | select.POLLPRI):
                data = read(fileno, 4096)
                if data:
                    chunks[fileno].append(data)
                    continue
            # Either EOF (empty read) or a hangup/error event: this
            # stream is finished.
            poller.unregister(fileno)
            open_streams.pop(fileno).close()
    process.wait()
    if check and process.returncode != 0:
        # Report the total run time, not the time of the last poll round.
        raise CalledProcessError(process.returncode, time() - start, command,
                                 captured(out_fd), captured(err_fd))
    return (process.returncode, captured(out_fd), captured(err_fd))
class Endpoint(object):
    """Perform operations on a remote machine over a shared SSH connection.

    Commands are sent through an ssh control socket that is created on
    demand by start() and reused by later calls.
    """

    def __init__(self, host, user=None, process_timeout=10,
                 connection_timeout=20):
        """Remember the target and the default timeouts (in seconds).

        Args:
            host: host name to connect to.
            user: optional login name; when given the ssh target becomes
                ``user@host``.
            process_timeout: default time limit for remote commands.
            connection_timeout: time limit for the control socket to
                become usable.
        """
        self.host = host
        self.user = user
        self.target = ((self.user + '@') if self.user else '') + self.host
        self.process_timeout = process_timeout
        self.connection_timeout = connection_timeout

    def start(self):
        """Start the SSH connection and return the unix socket name.

        Requires http://software.clapper.org/daemonize/ to be installed;
        the original notes also list http://sourceforge.net/projects/sshpass/
        as a requirement, although it is not invoked from this method.

        Raises:
            ConnectionTimeoutError: the control socket did not accept a
                connection within ``connection_timeout`` seconds.
        """
        socket_name = '/tmp/' + whoami() + '-ssh-' + self.target
        if not try_open_socket(socket_name):
            # No master is listening yet: launch one in the background.
            # (check_call here is the module-level subprocess function.)
            check_call(['daemonize', '/usr/bin/ssh',
                        '-N', '-M', '-S', socket_name, self.target])
            start = time()
            # Poll once per second until the master accepts connections.
            while not try_open_socket(socket_name):
                delta = time() - start
                if delta > self.connection_timeout:
                    raise ConnectionTimeoutError(delta, self.target)
                sleep(1)
        return socket_name

    def call(self, command, timeout=None, check=False):
        """Run ``command`` on the remote machine.

        Args:
            command: a list of arguments, or a string that is split on
                whitespace. Parentheses in each argument are escaped.
            timeout: time limit in seconds; a false value (None, 0)
                selects ``process_timeout``.
            check: passed through to local_run_with_timeout.

        Returns:
            Whatever local_run_with_timeout returns for the ssh
            invocation, i.e. (returncode, stdout, stderr).
        """
        if not timeout:
            timeout = self.process_timeout
        socket_name = self.start()
        if isinstance(command, str):
            command = command.split()
        command_string = ' '.join(arg_escape(arg) for arg in command)
        return local_run_with_timeout(
            ['/usr/bin/ssh', '-S', socket_name,
             self.target, command_string], timeout=timeout, check=check)

    def check_call(self, command, timeout=None):
        """Run ``command`` remotely and insist on a zero exit status.

        Returns:
            The tuple (stdout, stderr). Failures surface as the exception
            raised by call() with check=True.
        """
        _, stdout, stderr = self.call(command, timeout=timeout, check=True)
        return stdout, stderr

    def isdir(self, directory):
        """Return true if ``directory`` exists remotely and is a directory.

        Uses the exit status of ``test -d`` instead of searching the output
        of ``stat``: that text depends on the remote locale and stat
        flavour, and a regular file literally named ``directory`` matched
        it as well. Like os.path.isdir, ``test -d`` follows symlinks.
        """
        return self.call(['test', '-d', directory])[0] == 0

    def write_file(self, content, filename):
        """Store ``content`` in ``filename`` on the remote machine.

        The data is written to a local temporary file which is copied
        with scp through the shared control socket and then removed.
        """
        handle, name = mkstemp()
        try:
            try:
                # os.write may write fewer bytes than requested, so keep
                # going until everything is on disk.
                remaining = content
                while remaining:
                    written = write(handle, remaining)
                    remaining = remaining[written:]
            finally:
                # Close the descriptor even if writing failed.
                close(handle)
            socket_name = self.start()
            local_run_with_timeout(
                ['/usr/bin/scp', '-o', 'ControlPath=' + socket_name,
                 '-o', 'ControlMaster=auto', name,
                 self.target + ':' + filename],
                timeout=self.process_timeout, check=True)
        finally:
            unlink(name)
def test_check_call():
    """Exercise check_call and call against localhost."""
    endpoint = Endpoint('localhost')
    # A successful command: the root listing should mention "dev".
    root_listing = endpoint.check_call('ls /')[0]
    assert 'dev' in root_listing
    # A failing command: call() reports the non-zero status.
    exit_status = endpoint.call('false')[0]
    assert exit_status != 0
def test_basic_timeout():
    """A command outliving its timeout is interrupted promptly."""
    import pytest # "easy_install pytest" FTW
    began = time()
    with pytest.raises(ProcessTimeoutError):
        Endpoint('localhost').call('sleep 5', timeout=0.2)
    # The 5 second sleep must not have been allowed to run to the end.
    elapsed = time() - began
    assert elapsed <= 3
def test_timeout_output():
    """A command that keeps printing output still hits the timeout."""
    import pytest # "easy_install pytest" FTW
    endpoint = Endpoint('localhost')
    with pytest.raises(ProcessTimeoutError):
        endpoint.call('find /', timeout=0.2)
def test_non_zero_exit():
    """check_call raises CalledProcessError on a non-zero exit status."""
    import pytest # "easy_install pytest" FTW
    endpoint = Endpoint('localhost')
    with pytest.raises(CalledProcessError):
        endpoint.check_call('false')
def test_fs():
    """Exercise the filesystem helpers against localhost."""
    endpoint = Endpoint('localhost')
    # An existing directory, then a name (the current timestamp) that
    # should not exist.
    assert endpoint.isdir('/usr')
    missing_name = str(time())
    assert not endpoint.isdir(missing_name)
    # Upload a file and confirm its content arrived.
    endpoint.write_file('hello world', '/tmp/test')
    endpoint.check_call(['grep', 'hello', '/tmp/test'])