add a callback method that can be used to track get/put progress in
SFTPClient. suggested by Phil Schwartz.
This commit is contained in:
Robey Pointer 2008-07-06 16:08:15 -07:00
parent e8748645a3
commit fbfd8126c8
2 changed files with 22 additions and 4 deletions

View File

@ -522,7 +522,7 @@ class SFTPClient (BaseSFTP):
""" """
return self._cwd return self._cwd
def put(self, localpath, remotepath): def put(self, localpath, remotepath, callback=None):
""" """
Copy a local file (C{localpath}) to the SFTP server as C{remotepath}. Copy a local file (C{localpath}) to the SFTP server as C{remotepath}.
Any exception raised by operations will be passed through. This Any exception raised by operations will be passed through. This
@ -534,12 +534,16 @@ class SFTPClient (BaseSFTP):
@type localpath: str @type localpath: str
@param remotepath: the destination path on the SFTP server @param remotepath: the destination path on the SFTP server
@type remotepath: str @type remotepath: str
@param callback: optional callback function that accepts the bytes
transferred so far and the total bytes to be transferred
@type callback: function(int, int)
@return: an object containing attributes about the given file @return: an object containing attributes about the given file
(since 1.7.4) (since 1.7.4)
@rtype: SFTPAttributes @rtype: SFTPAttributes
@since: 1.4 @since: 1.4
""" """
file_size = os.stat(localpath).st_size
fl = file(localpath, 'rb') fl = file(localpath, 'rb')
fr = self.file(remotepath, 'wb') fr = self.file(remotepath, 'wb')
fr.set_pipelined(True) fr.set_pipelined(True)
@ -550,6 +554,8 @@ class SFTPClient (BaseSFTP):
break break
fr.write(data) fr.write(data)
size += len(data) size += len(data)
if callback is not None:
callback(size, file_size)
fl.close() fl.close()
fr.close() fr.close()
s = self.stat(remotepath) s = self.stat(remotepath)
@ -557,7 +563,7 @@ class SFTPClient (BaseSFTP):
raise IOError('size mismatch in put! %d != %d' % (s.st_size, size)) raise IOError('size mismatch in put! %d != %d' % (s.st_size, size))
return s return s
def get(self, remotepath, localpath): def get(self, remotepath, localpath, callback=None):
""" """
Copy a remote file (C{remotepath}) from the SFTP server to the local Copy a remote file (C{remotepath}) from the SFTP server to the local
host as C{localpath}. Any exception raised by operations will be host as C{localpath}. Any exception raised by operations will be
@ -567,10 +573,14 @@ class SFTPClient (BaseSFTP):
@type remotepath: str @type remotepath: str
@param localpath: the destination path on the local host @param localpath: the destination path on the local host
@type localpath: str @type localpath: str
@param callback: optional callback function that accepts the bytes
transferred so far and the total bytes to be transferred
@type callback: function(int, int)
@since: 1.4 @since: 1.4
""" """
fr = self.file(remotepath, 'rb') fr = self.file(remotepath, 'rb')
file_size = self.stat(remotepath).st_size
fr.prefetch() fr.prefetch()
fl = file(localpath, 'wb') fl = file(localpath, 'wb')
size = 0 size = 0
@ -580,6 +590,8 @@ class SFTPClient (BaseSFTP):
break break
fl.write(data) fl.write(data)
size += len(data) size += len(data)
if callback is not None:
callback(size, file_size)
fl.close() fl.close()
fr.close() fr.close()
s = os.stat(localpath) s = os.stat(localpath)

View File

@ -560,19 +560,25 @@ class SFTPTest (unittest.TestCase):
f = open(localname, 'wb') f = open(localname, 'wb')
f.write(text) f.write(text)
f.close() f.close()
sftp.put(localname, FOLDER + '/bunny.txt') saved_progress = []
def progress_callback(x, y):
saved_progress.append((x, y))
sftp.put(localname, FOLDER + '/bunny.txt', progress_callback)
f = sftp.open(FOLDER + '/bunny.txt', 'r') f = sftp.open(FOLDER + '/bunny.txt', 'r')
self.assertEquals(text, f.read(128)) self.assertEquals(text, f.read(128))
f.close() f.close()
self.assertEquals((41, 41), saved_progress[-1])
os.unlink(localname) os.unlink(localname)
localname = os.tempnam() localname = os.tempnam()
sftp.get(FOLDER + '/bunny.txt', localname) saved_progress = []
sftp.get(FOLDER + '/bunny.txt', localname, progress_callback)
f = open(localname, 'rb') f = open(localname, 'rb')
self.assertEquals(text, f.read(128)) self.assertEquals(text, f.read(128))
f.close() f.close()
self.assertEquals((41, 41), saved_progress[-1])
os.unlink(localname) os.unlink(localname)
sftp.unlink(FOLDER + '/bunny.txt') sftp.unlink(FOLDER + '/bunny.txt')