[project @ Arch-1:robey@lag.net--2005-master-shake%paramiko--dev--1--patch-70]

add SFTPFile.prefetch() to allow pre-fetching a file that will be downloaded in full -- quick testing showed this could speed up downloads 3x or more
This commit is contained in:
Robey Pointer 2005-10-24 06:19:56 +00:00
parent c986f92dc5
commit f65edffbfb
3 changed files with 136 additions and 22 deletions

View File

@ -22,6 +22,7 @@ Client-mode SFTP support.
import errno import errno
import os import os
import weakref
from paramiko.sftp import * from paramiko.sftp import *
from paramiko.sftp_attr import SFTPAttributes from paramiko.sftp_attr import SFTPAttributes
from paramiko.sftp_file import SFTPFile from paramiko.sftp_file import SFTPFile
@ -57,8 +58,8 @@ class SFTPClient (BaseSFTP):
self.ultra_debug = False self.ultra_debug = False
self.request_number = 1 self.request_number = 1
self._cwd = None self._cwd = None
# FIXME: after we require python 2.4, use set :) # request # -> SFTPFile
self._expecting = [] self._expecting = weakref.WeakValueDictionary()
if type(sock) is Channel: if type(sock) is Channel:
# override default logger # override default logger
transport = self.sock.get_transport() transport = self.sock.get_transport()
@ -491,6 +492,7 @@ class SFTPClient (BaseSFTP):
@since: 1.4 @since: 1.4
""" """
fr = self.file(remotepath, 'rb') fr = self.file(remotepath, 'rb')
fr.prefetch()
fl = file(localpath, 'wb') fl = file(localpath, 'wb')
size = 0 size = 0
while True: while True:
@ -510,10 +512,10 @@ class SFTPClient (BaseSFTP):
def _request(self, t, *arg): def _request(self, t, *arg):
num = self._async_request(t, *arg) num = self._async_request(type(None), t, *arg)
return self._read_response(num) return self._read_response(num)
def _async_request(self, t, *arg): def _async_request(self, fileobj, t, *arg):
msg = Message() msg = Message()
msg.add_int(self.request_number) msg.add_int(self.request_number)
for item in arg: for item in arg:
@ -529,8 +531,8 @@ class SFTPClient (BaseSFTP):
raise Exception('unknown type for %r type %r' % (item, type(item))) raise Exception('unknown type for %r type %r' % (item, type(item)))
self._send_packet(t, str(msg)) self._send_packet(t, str(msg))
num = self.request_number num = self.request_number
self._expecting[num] = fileobj
self.request_number += 1 self.request_number += 1
self._expecting.append(num)
return num return num
def _read_response(self, waitfor=None): def _read_response(self, waitfor=None):
@ -539,18 +541,26 @@ class SFTPClient (BaseSFTP):
msg = Message(data) msg = Message(data)
num = msg.get_int() num = msg.get_int()
if num not in self._expecting: if num not in self._expecting:
raise SFTPError('Expected response from %r, got response #%d' % # might be response for a file that was closed before responses came back
(self._expected, num)) self._log(DEBUG, 'Unexpected response #%d' % (num,))
self._expecting.remove(num) continue
if t == CMD_STATUS: fileobj = self._expecting[num]
self._convert_status(msg) del self._expecting[num]
if (waitfor is None) or (num == waitfor): if num == waitfor:
break # synchronous
return t, msg if t == CMD_STATUS:
self._convert_status(msg)
return t, msg
if fileobj is not type(None):
fileobj._async_response(t, msg)
if waitfor is None:
# just doing a single check
return
def _finish_responses(self): def _finish_responses(self, fileobj):
while len(self._expecting) > 0: while fileobj in self._expecting.values():
self._read_response() self._read_response()
fileobj._check_exception()
def _convert_status(self, msg): def _convert_status(self, msg):
""" """

View File

@ -20,6 +20,7 @@
L{SFTPFile} L{SFTPFile}
""" """
import threading
from paramiko.common import * from paramiko.common import *
from paramiko.sftp import * from paramiko.sftp import *
from paramiko.file import BufferedFile from paramiko.file import BufferedFile
@ -41,6 +42,8 @@ class SFTPFile (BufferedFile):
self.handle = handle self.handle = handle
BufferedFile._set_mode(self, mode, bufsize) BufferedFile._set_mode(self, mode, bufsize)
self.pipelined = False self.pipelined = False
self._prefetching = False
self._saved_exception = None
def __del__(self): def __del__(self):
self.close() self.close()
@ -56,7 +59,7 @@ class SFTPFile (BufferedFile):
if self._closed: if self._closed:
return return
if self.pipelined: if self.pipelined:
self.sftp._finish_responses() self.sftp._finish_responses(self)
BufferedFile.close(self) BufferedFile.close(self)
try: try:
self.sftp._request(CMD_CLOSE, self.handle) self.sftp._request(CMD_CLOSE, self.handle)
@ -67,8 +70,29 @@ class SFTPFile (BufferedFile):
# may have outlived the Transport connection # may have outlived the Transport connection
pass pass
def _read_prefetch(self, size):
while (self._prefetch_so_far <= self._realpos) and \
(self._prefetch_so_far < self._prefetch_size) and not self._closed:
self.sftp._read_response()
self._check_exception()
k = self._prefetch_data.keys()
k.sort()
while (len(k) > 0) and (k[0] + len(self._prefetch_data[k[0]]) <= self._realpos):
# done with that block
del self._prefetch_data[k[0]]
k.pop(0)
if len(k) == 0:
self._prefetching = False
return ''
assert k[0] <= self._realpos
buf_offset = self._realpos - k[0]
buf_length = len(self._prefetch_data[k[0]]) - buf_offset
return self._prefetch_data[k[0]][buf_offset : buf_offset + buf_length]
def _read(self, size): def _read(self, size):
size = min(size, self.MAX_REQUEST_SIZE) size = min(size, self.MAX_REQUEST_SIZE)
if self._prefetching:
return self._read_prefetch(size)
t, msg = self.sftp._request(CMD_READ, self.handle, long(self._realpos), int(size)) t, msg = self.sftp._request(CMD_READ, self.handle, long(self._realpos), int(size))
if t != CMD_DATA: if t != CMD_DATA:
raise SFTPError('Expected data') raise SFTPError('Expected data')
@ -77,9 +101,10 @@ class SFTPFile (BufferedFile):
def _write(self, data): def _write(self, data):
# may write less than requested if it would exceed max packet size # may write less than requested if it would exceed max packet size
chunk = min(len(data), self.MAX_REQUEST_SIZE) chunk = min(len(data), self.MAX_REQUEST_SIZE)
self.sftp._async_request(CMD_WRITE, self.handle, long(self._realpos), str(data[:chunk])) req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos),
str(data[:chunk]))
if not self.pipelined or self.sftp.sock.recv_ready(): if not self.pipelined or self.sftp.sock.recv_ready():
t, msg = self.sftp._read_response() t, msg = self.sftp._read_response(req)
if t != CMD_STATUS: if t != CMD_STATUS:
raise SFTPError('Expected status') raise SFTPError('Expected status')
self.sftp._convert_status(msg) self.sftp._convert_status(msg)
@ -218,6 +243,33 @@ class SFTPFile (BufferedFile):
""" """
self.pipelined = pipelined self.pipelined = pipelined
def prefetch(self):
"""
Pre-fetch the remaining contents of this file in anticipation of
future L{read} calls. If reading the entire file, pre-fetching can
dramatically improve the download speed by avoiding roundtrip latency.
The file's contents are incrementally buffered in a background thread.
@since: 1.5.1
"""
size = self.stat().st_size
# queue up async reads for the rest of the file
self._prefetching = True
self._prefetch_so_far = self._realpos
self._prefetch_size = size
self._prefetch_data = {}
t = threading.Thread(target=self._prefetch)
t.setDaemon(True)
t.start()
def _prefetch(self):
n = self._realpos
size = self._prefetch_size
while n < size:
chunk = min(self.MAX_REQUEST_SIZE, size - n)
self.sftp._async_request(self, CMD_READ, self.handle, long(n), int(chunk))
n += chunk
### internals... ### internals...
@ -227,3 +279,24 @@ class SFTPFile (BufferedFile):
return self.stat().st_size return self.stat().st_size
except: except:
return 0 return 0
def _async_response(self, t, msg):
if t == CMD_STATUS:
# save exception and re-raise it on next file operation
try:
self.sftp._convert_status(msg)
except Exception, x:
self._saved_exception = x
return
if t != CMD_DATA:
raise SFTPError('Expected data')
data = msg.get_string()
self._prefetch_data[self._prefetch_so_far] = data
self._prefetch_so_far += len(data)
def _check_exception(self):
"if there's a saved exception, raise & clear it"
if self._saved_exception is not None:
x = self._saved_exception
self._saved_exception = None
raise x

View File

@ -23,11 +23,15 @@ a real actual sftp server is contacted, and a new folder is created there to
do test file operations in (so no existing files will be harmed). do test file operations in (so no existing files will be harmed).
""" """
import sys, os import logging
import os
import random import random
import logging, threading import sys
import threading
import time
import unittest
import paramiko, unittest import paramiko
from stub_sftp import StubServer, StubSFTPServer from stub_sftp import StubServer, StubSFTPServer
from loop import LoopSocket from loop import LoopSocket
@ -432,12 +436,13 @@ class SFTPTest (unittest.TestCase):
def test_E_big_file(self): def test_E_big_file(self):
""" """
write a 1MB file, with no linefeeds, using line buffering. write a 1MB file with no buffering.
""" """
global g_big_file_test global g_big_file_test
if not g_big_file_test: if not g_big_file_test:
return return
kblob = (1024 * 'x') kblob = (1024 * 'x')
start = time.time()
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w') f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
for n in range(1024): for n in range(1024):
@ -448,6 +453,18 @@ class SFTPTest (unittest.TestCase):
sys.stderr.write(' ') sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
start = time.time()
f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
for n in range(1024):
data = f.read(1024)
self.assertEqual(data, kblob)
f.close()
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
finally: finally:
sftp.remove('%s/hongry.txt' % FOLDER) sftp.remove('%s/hongry.txt' % FOLDER)
@ -459,6 +476,7 @@ class SFTPTest (unittest.TestCase):
if not g_big_file_test: if not g_big_file_test:
return return
kblob = (1024 * 'x') kblob = (1024 * 'x')
start = time.time()
try: try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w') f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
f.set_pipelined(True) f.set_pipelined(True)
@ -470,6 +488,19 @@ class SFTPTest (unittest.TestCase):
sys.stderr.write(' ') sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024) self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
start = time.time()
f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
f.prefetch()
for n in range(1024):
data = f.read(1024)
self.assertEqual(data, kblob)
f.close()
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
finally: finally:
sftp.remove('%s/hongry.txt' % FOLDER) sftp.remove('%s/hongry.txt' % FOLDER)