[project @ Arch-1:robey@lag.net--2005-master-shake%paramiko--dev--1--patch-70]

add SFTPFile.prefetch() to allow pre-fetching a file that will be downloaded in full -- quick testing showed this could speed up downloads 3x or more
This commit is contained in:
Robey Pointer 2005-10-24 06:19:56 +00:00
parent c986f92dc5
commit f65edffbfb
3 changed files with 136 additions and 22 deletions

View File

@ -22,6 +22,7 @@ Client-mode SFTP support.
import errno
import os
import weakref
from paramiko.sftp import *
from paramiko.sftp_attr import SFTPAttributes
from paramiko.sftp_file import SFTPFile
@ -57,8 +58,8 @@ class SFTPClient (BaseSFTP):
self.ultra_debug = False
self.request_number = 1
self._cwd = None
# FIXME: after we require python 2.4, use set :)
self._expecting = []
# request # -> SFTPFile
self._expecting = weakref.WeakValueDictionary()
if type(sock) is Channel:
# override default logger
transport = self.sock.get_transport()
@ -491,6 +492,7 @@ class SFTPClient (BaseSFTP):
@since: 1.4
"""
fr = self.file(remotepath, 'rb')
fr.prefetch()
fl = file(localpath, 'wb')
size = 0
while True:
@ -510,10 +512,10 @@ class SFTPClient (BaseSFTP):
def _request(self, t, *arg):
num = self._async_request(t, *arg)
num = self._async_request(type(None), t, *arg)
return self._read_response(num)
def _async_request(self, t, *arg):
def _async_request(self, fileobj, t, *arg):
msg = Message()
msg.add_int(self.request_number)
for item in arg:
@ -529,8 +531,8 @@ class SFTPClient (BaseSFTP):
raise Exception('unknown type for %r type %r' % (item, type(item)))
self._send_packet(t, str(msg))
num = self.request_number
self._expecting[num] = fileobj
self.request_number += 1
self._expecting.append(num)
return num
def _read_response(self, waitfor=None):
@ -539,18 +541,26 @@ class SFTPClient (BaseSFTP):
msg = Message(data)
num = msg.get_int()
if num not in self._expecting:
raise SFTPError('Expected response from %r, got response #%d' %
(self._expected, num))
self._expecting.remove(num)
if t == CMD_STATUS:
self._convert_status(msg)
if (waitfor is None) or (num == waitfor):
break
return t, msg
# might be response for a file that was closed before responses came back
self._log(DEBUG, 'Unexpected response #%d' % (num,))
continue
fileobj = self._expecting[num]
del self._expecting[num]
if num == waitfor:
# synchronous
if t == CMD_STATUS:
self._convert_status(msg)
return t, msg
if fileobj is not type(None):
fileobj._async_response(t, msg)
if waitfor is None:
# just doing a single check
return
def _finish_responses(self):
while len(self._expecting) > 0:
def _finish_responses(self, fileobj):
"""
Block until every outstanding async request registered for C{fileobj}
has been answered, then re-raise any exception that an async response
handler saved on the file object.
"""
# _expecting maps request number -> SFTPFile; drain responses until no
# entry for this file remains
while fileobj in self._expecting.values():
self._read_response()
fileobj._check_exception()
def _convert_status(self, msg):
"""

View File

@ -20,6 +20,7 @@
L{SFTPFile}
"""
import threading
from paramiko.common import *
from paramiko.sftp import *
from paramiko.file import BufferedFile
@ -41,6 +42,8 @@ class SFTPFile (BufferedFile):
self.handle = handle
BufferedFile._set_mode(self, mode, bufsize)
self.pipelined = False
self._prefetching = False
self._saved_exception = None
def __del__(self):
# make sure the remote handle gets closed (and any pipelined writes
# flushed) even if the caller never called close() explicitly
self.close()
@ -56,7 +59,7 @@ class SFTPFile (BufferedFile):
if self._closed:
return
if self.pipelined:
self.sftp._finish_responses()
self.sftp._finish_responses(self)
BufferedFile.close(self)
try:
self.sftp._request(CMD_CLOSE, self.handle)
@ -67,8 +70,29 @@ class SFTPFile (BufferedFile):
# may have outlived the Transport connection
pass
def _read_prefetch(self, size):
"""
Satisfy a read from the data already buffered (or still arriving) from
the background prefetch, instead of issuing a new synchronous READ.
Returns C{''} once the prefetched region is exhausted (and turns
prefetching off so normal reads resume).
"""
# wait until the prefetch responses have covered the current file
# position, or we've seen the whole prefetched range, or the file closed
while (self._prefetch_so_far <= self._realpos) and \
(self._prefetch_so_far < self._prefetch_size) and not self._closed:
self.sftp._read_response()
self._check_exception()
# chunks are keyed by their starting file offset; walk them in order and
# discard any chunk that ends at or before the current position
k = self._prefetch_data.keys()
k.sort()
while (len(k) > 0) and (k[0] + len(self._prefetch_data[k[0]]) <= self._realpos):
# done with that block
del self._prefetch_data[k[0]]
k.pop(0)
if len(k) == 0:
# ran past the prefetched data; fall back to normal synchronous reads
self._prefetching = False
return ''
assert k[0] <= self._realpos
buf_offset = self._realpos - k[0]
buf_length = len(self._prefetch_data[k[0]]) - buf_offset
# NOTE(review): the returned slice is never capped by C{size} -- a chunk
# larger than the request will be returned whole; confirm BufferedFile
# tolerates an over-long _read result, or clamp with min(size, ...).
return self._prefetch_data[k[0]][buf_offset : buf_offset + buf_length]
def _read(self, size):
size = min(size, self.MAX_REQUEST_SIZE)
if self._prefetching:
return self._read_prefetch(size)
t, msg = self.sftp._request(CMD_READ, self.handle, long(self._realpos), int(size))
if t != CMD_DATA:
raise SFTPError('Expected data')
@ -77,9 +101,10 @@ class SFTPFile (BufferedFile):
def _write(self, data):
# may write less than requested if it would exceed max packet size
chunk = min(len(data), self.MAX_REQUEST_SIZE)
self.sftp._async_request(CMD_WRITE, self.handle, long(self._realpos), str(data[:chunk]))
req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos),
str(data[:chunk]))
if not self.pipelined or self.sftp.sock.recv_ready():
t, msg = self.sftp._read_response()
t, msg = self.sftp._read_response(req)
if t != CMD_STATUS:
raise SFTPError('Expected status')
self.sftp._convert_status(msg)
@ -217,6 +242,33 @@ class SFTPFile (BufferedFile):
@since: 1.5
"""
self.pipelined = pipelined
def prefetch(self):
"""
Pre-fetch the remaining contents of this file in anticipation of
future L{read} calls. If reading the entire file, pre-fetching can
dramatically improve the download speed by avoiding roundtrip latency.
The file's contents are incrementally buffered in a background thread.

@since: 1.5.1
"""
# remaining span to fetch runs from the current position to EOF
size = self.stat().st_size
# queue up async reads for the rest of the file
self._prefetching = True
self._prefetch_so_far = self._realpos
self._prefetch_size = size
self._prefetch_data = {}
# issue the requests from a daemon thread so prefetch() returns
# immediately; replies are delivered to _async_response()
t = threading.Thread(target=self._prefetch)
t.setDaemon(True)
t.start()
def _prefetch(self):
# runs in the daemon thread started by prefetch(): issue one async
# CMD_READ per MAX_REQUEST_SIZE-sized chunk covering the rest of the
# file, registering C{self} so replies are routed to _async_response().
# NOTE(review): _async_response() files each reply at the offset tracked
# by _prefetch_so_far, which presumes the server answers these reads in
# the order they were issued -- confirm that assumption holds.
n = self._realpos
size = self._prefetch_size
while n < size:
chunk = min(self.MAX_REQUEST_SIZE, size - n)
self.sftp._async_request(self, CMD_READ, self.handle, long(n), int(chunk))
n += chunk
### internals...
@ -227,3 +279,24 @@ class SFTPFile (BufferedFile):
return self.stat().st_size
except:
return 0
def _async_response(self, t, msg):
# called from SFTPClient._read_response() when a reply arrives for one
# of the async requests registered against this file
if t == CMD_STATUS:
# save exception and re-raise it on next file operation
try:
self.sftp._convert_status(msg)
except Exception, x:
self._saved_exception = x
return
if t != CMD_DATA:
raise SFTPError('Expected data')
data = msg.get_string()
# file data: record it at the running prefetch offset (assumes replies
# arrive in the same order the reads were issued -- TODO confirm)
self._prefetch_data[self._prefetch_so_far] = data
self._prefetch_so_far += len(data)
def _check_exception(self):
"if there's a saved exception, raise & clear it"
# exceptions hit during async responses are deferred by
# _async_response() and surfaced here on the next file operation
if self._saved_exception is not None:
x = self._saved_exception
self._saved_exception = None
raise x

View File

@ -23,11 +23,15 @@ a real actual sftp server is contacted, and a new folder is created there to
do test file operations in (so no existing files will be harmed).
"""
import sys, os
import logging
import os
import random
import logging, threading
import sys
import threading
import time
import unittest
import paramiko, unittest
import paramiko
from stub_sftp import StubServer, StubSFTPServer
from loop import LoopSocket
@ -432,12 +436,13 @@ class SFTPTest (unittest.TestCase):
def test_E_big_file(self):
"""
write a 1MB file, with no linefeeds, using line buffering.
write a 1MB file with no buffering.
"""
global g_big_file_test
if not g_big_file_test:
return
kblob = (1024 * 'x')
start = time.time()
try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
for n in range(1024):
@ -448,6 +453,18 @@ class SFTPTest (unittest.TestCase):
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
start = time.time()
f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
for n in range(1024):
data = f.read(1024)
self.assertEqual(data, kblob)
f.close()
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
finally:
sftp.remove('%s/hongry.txt' % FOLDER)
@ -459,6 +476,7 @@ class SFTPTest (unittest.TestCase):
if not g_big_file_test:
return
kblob = (1024 * 'x')
start = time.time()
try:
f = sftp.open('%s/hongry.txt' % FOLDER, 'w')
f.set_pipelined(True)
@ -470,6 +488,19 @@ class SFTPTest (unittest.TestCase):
sys.stderr.write(' ')
self.assertEqual(sftp.stat('%s/hongry.txt' % FOLDER).st_size, 1024 * 1024)
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
start = time.time()
f = sftp.open('%s/hongry.txt' % FOLDER, 'r')
f.prefetch()
for n in range(1024):
data = f.read(1024)
self.assertEqual(data, kblob)
f.close()
end = time.time()
sys.stderr.write('%ds ' % round(end - start))
finally:
sftp.remove('%s/hongry.txt' % FOLDER)