diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py index fd76207..41a14a2 100644 --- a/paramiko/sftp_client.py +++ b/paramiko/sftp_client.py @@ -22,6 +22,7 @@ Client-mode SFTP support. import errno import os +import weakref from paramiko.sftp import * from paramiko.sftp_attr import SFTPAttributes from paramiko.sftp_file import SFTPFile @@ -57,8 +58,8 @@ class SFTPClient (BaseSFTP): self.ultra_debug = False self.request_number = 1 self._cwd = None - # FIXME: after we require python 2.4, use set :) - self._expecting = [] + # request # -> SFTPFile + self._expecting = weakref.WeakValueDictionary() if type(sock) is Channel: # override default logger transport = self.sock.get_transport() @@ -491,6 +492,7 @@ class SFTPClient (BaseSFTP): @since: 1.4 """ fr = self.file(remotepath, 'rb') + fr.prefetch() fl = file(localpath, 'wb') size = 0 while True: @@ -510,10 +512,10 @@ class SFTPClient (BaseSFTP): def _request(self, t, *arg): - num = self._async_request(t, *arg) + num = self._async_request(type(None), t, *arg) return self._read_response(num) - def _async_request(self, t, *arg): + def _async_request(self, fileobj, t, *arg): msg = Message() msg.add_int(self.request_number) for item in arg: @@ -529,8 +531,8 @@ class SFTPClient (BaseSFTP): raise Exception('unknown type for %r type %r' % (item, type(item))) self._send_packet(t, str(msg)) num = self.request_number + self._expecting[num] = fileobj self.request_number += 1 - self._expecting.append(num) return num def _read_response(self, waitfor=None): @@ -539,18 +541,26 @@ class SFTPClient (BaseSFTP): msg = Message(data) num = msg.get_int() if num not in self._expecting: - raise SFTPError('Expected response from %r, got response #%d' % - (self._expected, num)) - self._expecting.remove(num) - if t == CMD_STATUS: - self._convert_status(msg) - if (waitfor is None) or (num == waitfor): - break - return t, msg + # might be response for a file that was closed before responses came back + self._log(DEBUG, 'Unexpected response #%d' % (num,)) + continue + fileobj = self._expecting[num] + del self._expecting[num] + if num == waitfor: + # synchronous + if t == CMD_STATUS: + self._convert_status(msg) + return t, msg + if fileobj is not type(None): + fileobj._async_response(t, msg) + if waitfor is None: + # just doing a single check + return - def _finish_responses(self): - while len(self._expecting) > 0: + def _finish_responses(self, fileobj): + while fileobj in self._expecting.values(): self._read_response() + fileobj._check_exception() def _convert_status(self, msg): """ diff --git a/paramiko/sftp_file.py b/paramiko/sftp_file.py index 3770f2d..4ecf9c4 100644 --- a/paramiko/sftp_file.py +++ b/paramiko/sftp_file.py @@ -20,6 +20,7 @@ L{SFTPFile} """ +import threading from paramiko.common import * from paramiko.sftp import * from paramiko.file import BufferedFile @@ -41,6 +42,8 @@ class SFTPFile (BufferedFile): self.handle = handle BufferedFile._set_mode(self, mode, bufsize) self.pipelined = False + self._prefetching = False + self._saved_exception = None def __del__(self): self.close() @@ -56,7 +59,7 @@ class SFTPFile (BufferedFile): if self._closed: return if self.pipelined: - self.sftp._finish_responses() + self.sftp._finish_responses(self) BufferedFile.close(self) try: self.sftp._request(CMD_CLOSE, self.handle) @@ -67,8 +70,29 @@ class SFTPFile (BufferedFile): # may have outlived the Transport connection pass + def _read_prefetch(self, size): + while (self._prefetch_so_far <= self._realpos) and \ + (self._prefetch_so_far < self._prefetch_size) and not self._closed: + self.sftp._read_response() + self._check_exception() + k = self._prefetch_data.keys() + k.sort() + while (len(k) > 0) and (k[0] + len(self._prefetch_data[k[0]]) <= self._realpos): + # done with that block + del self._prefetch_data[k[0]] + k.pop(0) + if len(k) == 0: + self._prefetching = False + return '' + assert k[0] <= self._realpos + buf_offset = self._realpos - k[0] + buf_length = len(self._prefetch_data[k[0]]) - buf_offset + return self._prefetch_data[k[0]][buf_offset : buf_offset + buf_length] + def _read(self, size): size = min(size, self.MAX_REQUEST_SIZE) + if self._prefetching: + return self._read_prefetch(size) t, msg = self.sftp._request(CMD_READ, self.handle, long(self._realpos), int(size)) if t != CMD_DATA: raise SFTPError('Expected data') @@ -77,9 +101,10 @@ class SFTPFile (BufferedFile): def _write(self, data): # may write less than requested if it would exceed max packet size chunk = min(len(data), self.MAX_REQUEST_SIZE) - self.sftp._async_request(CMD_WRITE, self.handle, long(self._realpos), str(data[:chunk])) + req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), + str(data[:chunk])) if not self.pipelined or self.sftp.sock.recv_ready(): - t, msg = self.sftp._read_response() + t, msg = self.sftp._read_response(req) if t != CMD_STATUS: raise SFTPError('Expected status') self.sftp._convert_status(msg) @@ -217,6 +242,33 @@ class SFTPFile (BufferedFile): @since: 1.5 """ self.pipelined = pipelined + + def prefetch(self): + """ + Pre-fetch the remaining contents of this file in anticipation of + future L{read} calls. If reading the entire file, pre-fetching can + dramatically improve the download speed by avoiding roundtrip latency. + The file's contents are incrementally buffered in a background thread. + + @since: 1.5.1 + """ + size = self.stat().st_size + # queue up async reads for the rest of the file + self._prefetching = True + self._prefetch_so_far = self._realpos + self._prefetch_size = size + self._prefetch_data = {} + t = threading.Thread(target=self._prefetch) + t.setDaemon(True) + t.start() + + def _prefetch(self): + n = self._realpos + size = self._prefetch_size + while n < size: + chunk = min(self.MAX_REQUEST_SIZE, size - n) + self.sftp._async_request(self, CMD_READ, self.handle, long(n), int(chunk)) + n += chunk ### internals... @@ -227,3 +279,24 @@ class SFTPFile (BufferedFile): return self.stat().st_size except: return 0 + + def _async_response(self, t, msg): + if t == CMD_STATUS: + # save exception and re-raise it on next file operation + try: + self.sftp._convert_status(msg) + except Exception, x: + self._saved_exception = x + return + if t != CMD_DATA: + raise SFTPError('Expected data') + data = msg.get_string() + self._prefetch_data[self._prefetch_so_far] = data + self._prefetch_so_far += len(data) + + def _check_exception(self): + "if there's a saved exception, raise & clear it" + if self._saved_exception is not None: + x = self._saved_exception + self._saved_exception = None + raise x diff --git a/tests/test_sftp.py b/tests/test_sftp.py index 4981936..4c5065e 100755 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -23,11 +23,15 @@ a real actual sftp server is contacted, and a new folder is created there to do test file operations in (so no existing files will be harmed). """ -import sys, os +import logging +import os import random -import logging, threading +import sys +import threading +import time +import unittest -import paramiko, unittest +import paramiko from stub_sftp import StubServer, StubSFTPServer from loop import LoopSocket @@ -432,12 +436,13 @@ class SFTPTest (unittest.TestCase): def test_E_big_file(self): """ - write a 1MB file, with no linefeeds, using line buffering. + write a 1MB file with no buffering. """ global g_big_file_test if not g_big_file_test: return kblob = (1024 * 'x') + start = time.time() try: f = sftp.open('%s/hongry.txt' % FOLDER, 'w') for n in range(1024): @@ -448,6 +453,18 @@ class SFTPTest (unittest.TestCase): sys.stderr.write(' ') self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) + end = time.time() + sys.stderr.write('%ds ' % round(end - start)) + + start = time.time() + f = sftp.open('%s/hongry.txt' % FOLDER, 'r') + for n in range(1024): + data = f.read(1024) + self.assertEqual(data, kblob) + f.close() + + end = time.time() + sys.stderr.write('%ds ' % round(end - start)) finally: sftp.remove('%s/hongry.txt' % FOLDER) @@ -459,6 +476,7 @@ class SFTPTest (unittest.TestCase): if not g_big_file_test: return kblob = (1024 * 'x') + start = time.time() try: f = sftp.open('%s/hongry.txt' % FOLDER, 'w') f.set_pipelined(True) @@ -470,6 +488,19 @@ class SFTPTest (unittest.TestCase): sys.stderr.write(' ') self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) + end = time.time() + sys.stderr.write('%ds ' % round(end - start)) + + start = time.time() + f = sftp.open('%s/hongry.txt' % FOLDER, 'r') + f.prefetch() + for n in range(1024): + data = f.read(1024) + self.assertEqual(data, kblob) + f.close() + + end = time.time() + sys.stderr.write('%ds ' % round(end - start)) finally: sftp.remove('%s/hongry.txt' % FOLDER)