Handle/fix handling of EINTR errors in a few places.

(cherry picked from commit 351bdb72e539c373985e108c89f61839f3acdd2a)

Conflicts:
	paramiko/agent.py
	paramiko/client.py
	paramiko/transport.py
This commit is contained in:
Douglas Turk 2012-09-03 14:48:00 +10:00 committed by Jeff Forcier
parent 7ead8d9c70
commit 681a465f32
6 changed files with 52 additions and 8 deletions

View File

@ -35,6 +35,7 @@ from paramiko.message import Message
from paramiko.pkey import PKey from paramiko.pkey import PKey
from paramiko.channel import Channel from paramiko.channel import Channel
from paramiko.common import io_sleep from paramiko.common import io_sleep
from paramiko.util import retry_on_signal
SSH2_AGENTC_REQUEST_IDENTITIES, SSH2_AGENT_IDENTITIES_ANSWER, \ SSH2_AGENTC_REQUEST_IDENTITIES, SSH2_AGENT_IDENTITIES_ANSWER, \
SSH2_AGENTC_SIGN_REQUEST, SSH2_AGENT_SIGN_RESPONSE = range(11, 15) SSH2_AGENTC_SIGN_REQUEST, SSH2_AGENT_SIGN_RESPONSE = range(11, 15)
@ -202,7 +203,7 @@ class AgentClientProxy(object):
if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'): if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'):
conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try: try:
conn.connect(os.environ['SSH_AUTH_SOCK']) retry_on_signal(lambda: conn.connect(os.environ['SSH_AUTH_SOCK']))
except: except:
# probably a dangling env var: the ssh agent is gone # probably a dangling env var: the ssh agent is gone
return return

View File

@ -34,6 +34,7 @@ from paramiko.resource import ResourceManager
from paramiko.rsakey import RSAKey from paramiko.rsakey import RSAKey
from paramiko.ssh_exception import SSHException, BadHostKeyException from paramiko.ssh_exception import SSHException, BadHostKeyException
from paramiko.transport import Transport from paramiko.transport import Transport
from paramiko.util import retry_on_signal
SSH_PORT = 22 SSH_PORT = 22
@ -293,7 +294,7 @@ class SSHClient (object):
sock.settimeout(timeout) sock.settimeout(timeout)
except: except:
pass pass
sock.connect(addr) retry_on_signal(lambda: sock.connect(addr))
t = self._transport = Transport(sock) t = self._transport = Transport(sock)
t.use_compression(compress=compress) t.use_compression(compress=compress)
if self._log_channel is not None: if self._log_channel is not None:

View File

@ -241,23 +241,23 @@ class Packetizer (object):
def write_all(self, out): def write_all(self, out):
self.__keepalive_last = time.time() self.__keepalive_last = time.time()
while len(out) > 0: while len(out) > 0:
got_timeout = False retry_write = False
try: try:
n = self.__socket.send(out) n = self.__socket.send(out)
except socket.timeout: except socket.timeout:
got_timeout = True retry_write = True
except socket.error, e: except socket.error, e:
if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN): if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN):
got_timeout = True retry_write = True
elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR): elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR):
# syscall interrupted; try again # syscall interrupted; try again
pass retry_write = True
else: else:
n = -1 n = -1
except Exception: except Exception:
# could be: (32, 'Broken pipe') # could be: (32, 'Broken pipe')
n = -1 n = -1
if got_timeout: if retry_write:
n = 0 n = 0
if self.__closed: if self.__closed:
n = -1 n = -1
@ -469,6 +469,12 @@ class Packetizer (object):
break break
except socket.timeout: except socket.timeout:
pass pass
except EnvironmentError, e:
if ((type(e.args) is tuple) and (len(e.args) > 0) and
(e.args[0] == errno.EINTR)):
pass
else:
raise
if self.__closed: if self.__closed:
raise EOFError() raise EOFError()
now = time.time() now = time.time()

View File

@ -45,6 +45,7 @@ from paramiko.rsakey import RSAKey
from paramiko.server import ServerInterface from paramiko.server import ServerInterface
from paramiko.sftp_client import SFTPClient from paramiko.sftp_client import SFTPClient
from paramiko.ssh_exception import SSHException, BadAuthenticationType, ChannelException from paramiko.ssh_exception import SSHException, BadAuthenticationType, ChannelException
from paramiko.util import retry_on_signal
from Crypto import Random from Crypto import Random
from Crypto.Cipher import Blowfish, AES, DES3, ARC4 from Crypto.Cipher import Blowfish, AES, DES3, ARC4
@ -289,7 +290,7 @@ class Transport (threading.Thread):
addr = sockaddr addr = sockaddr
sock = socket.socket(af, socket.SOCK_STREAM) sock = socket.socket(af, socket.SOCK_STREAM)
try: try:
sock.connect((hostname, port)) retry_on_signal(lambda: sock.connect((hostname, port)))
except socket.error, e: except socket.error, e:
reason = str(e) reason = str(e)
else: else:

View File

@ -24,6 +24,7 @@ from __future__ import generators
import array import array
from binascii import hexlify, unhexlify from binascii import hexlify, unhexlify
import errno
import sys import sys
import struct import struct
import traceback import traceback
@ -270,6 +271,14 @@ def get_logger(name):
l.addFilter(_pfilter) l.addFilter(_pfilter)
return l return l
def retry_on_signal(function):
"""Retries function until it doesn't raise an EINTR error"""
while True:
try:
return function()
except EnvironmentError, e:
if e.errno != errno.EINTR:
raise
class Counter (object): class Counter (object):
"""Stateful counter for CTR mode crypto""" """Stateful counter for CTR mode crypto"""

View File

@ -22,6 +22,7 @@ Some unit tests for utility functions.
from binascii import hexlify from binascii import hexlify
import cStringIO import cStringIO
import errno
import os import os
import unittest import unittest
from Crypto.Hash import SHA from Crypto.Hash import SHA
@ -177,3 +178,28 @@ Host *
ssh.util.lookup_ssh_host_config(host, config), ssh.util.lookup_ssh_host_config(host, config),
{'hostname': host, 'port': '22'} {'hostname': host, 'port': '22'}
) )
def test_8_eintr_retry(self):
self.assertEquals('foo', ssh.util.retry_on_signal(lambda: 'foo'))
# Variables that are set by raises_intr
intr_errors_remaining = [3]
call_count = [0]
def raises_intr():
call_count[0] += 1
if intr_errors_remaining[0] > 0:
intr_errors_remaining[0] -= 1
raise IOError(errno.EINTR, 'file', 'interrupted system call')
self.assertTrue(ssh.util.retry_on_signal(raises_intr) is None)
self.assertEquals(0, intr_errors_remaining[0])
self.assertEquals(4, call_count[0])
def raises_ioerror_not_eintr():
raise IOError(errno.ENOENT, 'file', 'file not found')
self.assertRaises(IOError,
lambda: ssh.util.retry_on_signal(raises_ioerror_not_eintr))
def raises_other_exception():
raise AssertionError('foo')
self.assertRaises(AssertionError,
lambda: ssh.util.retry_on_signal(raises_other_exception))