diff --git a/paramiko/agent.py b/paramiko/agent.py index 7115f17..5d04dce 100644 --- a/paramiko/agent.py +++ b/paramiko/agent.py @@ -35,6 +35,7 @@ from paramiko.message import Message from paramiko.pkey import PKey from paramiko.channel import Channel from paramiko.common import io_sleep +from paramiko.util import retry_on_signal SSH2_AGENTC_REQUEST_IDENTITIES, SSH2_AGENT_IDENTITIES_ANSWER, \ SSH2_AGENTC_SIGN_REQUEST, SSH2_AGENT_SIGN_RESPONSE = range(11, 15) @@ -202,7 +203,7 @@ class AgentClientProxy(object): if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'): conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) try: - conn.connect(os.environ['SSH_AUTH_SOCK']) + retry_on_signal(lambda: conn.connect(os.environ['SSH_AUTH_SOCK'])) except: # probably a dangling env var: the ssh agent is gone return diff --git a/paramiko/client.py b/paramiko/client.py index 557cbb7..3ccb52b 100644 --- a/paramiko/client.py +++ b/paramiko/client.py @@ -34,6 +34,7 @@ from paramiko.resource import ResourceManager from paramiko.rsakey import RSAKey from paramiko.ssh_exception import SSHException, BadHostKeyException from paramiko.transport import Transport +from paramiko.util import retry_on_signal SSH_PORT = 22 @@ -293,7 +294,7 @@ class SSHClient (object): sock.settimeout(timeout) except: pass - sock.connect(addr) + retry_on_signal(lambda: sock.connect(addr)) t = self._transport = Transport(sock) t.use_compression(compress=compress) if self._log_channel is not None: diff --git a/paramiko/packet.py b/paramiko/packet.py index 2f6d692..9782061 100644 --- a/paramiko/packet.py +++ b/paramiko/packet.py @@ -241,23 +241,23 @@ class Packetizer (object): def write_all(self, out): self.__keepalive_last = time.time() while len(out) > 0: - got_timeout = False + retry_write = False try: n = self.__socket.send(out) except socket.timeout: - got_timeout = True + retry_write = True except socket.error, e: if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN): - got_timeout = True + retry_write = True elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR): # syscall interrupted; try again - pass + retry_write = True else: n = -1 except Exception: # could be: (32, 'Broken pipe') n = -1 - if got_timeout: + if retry_write: n = 0 if self.__closed: n = -1 @@ -469,6 +469,12 @@ class Packetizer (object): break except socket.timeout: pass + except EnvironmentError, e: + if ((type(e.args) is tuple) and (len(e.args) > 0) and + (e.args[0] == errno.EINTR)): + pass + else: + raise if self.__closed: raise EOFError() now = time.time() diff --git a/paramiko/transport.py b/paramiko/transport.py index 8174a4c..dd389a8 100644 --- a/paramiko/transport.py +++ b/paramiko/transport.py @@ -45,6 +45,7 @@ from paramiko.rsakey import RSAKey from paramiko.server import ServerInterface from paramiko.sftp_client import SFTPClient from paramiko.ssh_exception import SSHException, BadAuthenticationType, ChannelException +from paramiko.util import retry_on_signal from Crypto import Random from Crypto.Cipher import Blowfish, AES, DES3, ARC4 @@ -289,7 +290,7 @@ class Transport (threading.Thread): addr = sockaddr sock = socket.socket(af, socket.SOCK_STREAM) try: - sock.connect((hostname, port)) + retry_on_signal(lambda: sock.connect((hostname, port))) except socket.error, e: reason = str(e) else: diff --git a/paramiko/util.py b/paramiko/util.py index 0d6a534..f4bfbec 100644 --- a/paramiko/util.py +++ b/paramiko/util.py @@ -24,6 +24,7 @@ from __future__ import generators import array from binascii import hexlify, unhexlify +import errno import sys import struct import traceback @@ -270,6 +271,14 @@ def get_logger(name): l.addFilter(_pfilter) return l +def retry_on_signal(function): + """Retries function until it doesn't raise an EINTR error""" + while True: + try: + return function() + except EnvironmentError, e: + if e.errno != errno.EINTR: + raise class Counter (object): """Stateful counter for CTR mode crypto""" diff --git a/tests/test_util.py b/tests/test_util.py index ed0607f..59a3d99 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -22,6 +22,7 @@ Some unit tests for utility functions. from binascii import hexlify import cStringIO +import errno import os import unittest from Crypto.Hash import SHA @@ -177,3 +178,28 @@ Host * ssh.util.lookup_ssh_host_config(host, config), {'hostname': host, 'port': '22'} ) + + def test_8_eintr_retry(self): + self.assertEquals('foo', ssh.util.retry_on_signal(lambda: 'foo')) + + # Variables that are set by raises_intr + intr_errors_remaining = [3] + call_count = [0] + def raises_intr(): + call_count[0] += 1 + if intr_errors_remaining[0] > 0: + intr_errors_remaining[0] -= 1 + raise IOError(errno.EINTR, 'file', 'interrupted system call') + self.assertTrue(ssh.util.retry_on_signal(raises_intr) is None) + self.assertEquals(0, intr_errors_remaining[0]) + self.assertEquals(4, call_count[0]) + + def raises_ioerror_not_eintr(): + raise IOError(errno.ENOENT, 'file', 'file not found') + self.assertRaises(IOError, + lambda: ssh.util.retry_on_signal(raises_ioerror_not_eintr)) + + def raises_other_exception(): + raise AssertionError('foo') + self.assertRaises(AssertionError, + lambda: ssh.util.retry_on_signal(raises_other_exception))