From fbfd8126c86af8d10e96dda7a345a33afd41f091 Mon Sep 17 00:00:00 2001 From: Robey Pointer Date: Sun, 6 Jul 2008 16:08:15 -0700 Subject: [PATCH] [project @ robey@lag.net-20080706230815-v2ybqxm237zw0wa0] add a callback method that can be used to track get/put progress in SFTPClient. suggested by Phil Schwartz. --- paramiko/sftp_client.py | 16 ++++++++++++++-- tests/test_sftp.py | 10 ++++++++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/paramiko/sftp_client.py b/paramiko/sftp_client.py index 338d19b..16728d1 100644 --- a/paramiko/sftp_client.py +++ b/paramiko/sftp_client.py @@ -522,7 +522,7 @@ class SFTPClient (BaseSFTP): """ return self._cwd - def put(self, localpath, remotepath): + def put(self, localpath, remotepath, callback=None): """ Copy a local file (C{localpath}) to the SFTP server as C{remotepath}. Any exception raised by operations will be passed through. This @@ -534,12 +534,16 @@ class SFTPClient (BaseSFTP): @type localpath: str @param remotepath: the destination path on the SFTP server @type remotepath: str + @param callback: optional callback function that accepts the bytes + transferred so far and the total bytes to be transferred + @type callback: function(int, int) @return: an object containing attributes about the given file (since 1.7.4) @rtype: SFTPAttributes @since: 1.4 """ + file_size = os.stat(localpath).st_size fl = file(localpath, 'rb') fr = self.file(remotepath, 'wb') fr.set_pipelined(True) @@ -550,6 +554,8 @@ class SFTPClient (BaseSFTP): break fr.write(data) size += len(data) + if callback is not None: + callback(size, file_size) fl.close() fr.close() s = self.stat(remotepath) @@ -557,7 +563,7 @@ class SFTPClient (BaseSFTP): raise IOError('size mismatch in put! %d != %d' % (s.st_size, size)) return s - def get(self, remotepath, localpath): + def get(self, remotepath, localpath, callback=None): """ Copy a remote file (C{remotepath}) from the SFTP server to the local host as C{localpath}. Any exception raised by operations will be @@ -567,10 +573,14 @@ class SFTPClient (BaseSFTP): @type remotepath: str @param localpath: the destination path on the local host @type localpath: str + @param callback: optional callback function that accepts the bytes + transferred so far and the total bytes to be transferred + @type callback: function(int, int) @since: 1.4 """ fr = self.file(remotepath, 'rb') + file_size = self.stat(remotepath).st_size fr.prefetch() fl = file(localpath, 'wb') size = 0 @@ -580,6 +590,8 @@ class SFTPClient (BaseSFTP): break fl.write(data) size += len(data) + if callback is not None: + callback(size, file_size) fl.close() fr.close() s = os.stat(localpath) diff --git a/tests/test_sftp.py b/tests/test_sftp.py index ab5b818..edc0599 100755 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -560,19 +560,25 @@ class SFTPTest (unittest.TestCase): f = open(localname, 'wb') f.write(text) f.close() - sftp.put(localname, FOLDER + '/bunny.txt') + saved_progress = [] + def progress_callback(x, y): + saved_progress.append((x, y)) + sftp.put(localname, FOLDER + '/bunny.txt', progress_callback) f = sftp.open(FOLDER + '/bunny.txt', 'r') self.assertEquals(text, f.read(128)) f.close() + self.assertEquals((41, 41), saved_progress[-1]) os.unlink(localname) localname = os.tempnam() - sftp.get(FOLDER + '/bunny.txt', localname) + saved_progress = [] + sftp.get(FOLDER + '/bunny.txt', localname, progress_callback) f = open(localname, 'rb') self.assertEquals(text, f.read(128)) f.close() + self.assertEquals((41, 41), saved_progress[-1]) os.unlink(localname) sftp.unlink(FOLDER + '/bunny.txt')