From 145ceab54c454f6872e34c426d6f9c0aadccbd85 Mon Sep 17 00:00:00 2001 From: Robey Pointer Date: Tue, 22 Aug 2006 19:55:38 -0700 Subject: [PATCH] [project @ robey@lag.net-20060823025538-3f8a4d761d7d4118] when a file is open for append, don't stat to get the file position unless the user asks for it explicitly --- paramiko/file.py | 60 ++++++++++++++++++++++--------------------- paramiko/sftp_file.py | 11 ++++++-- tests/test_sftp.py | 21 +++++++++++++++ 3 files changed, 61 insertions(+), 31 deletions(-) diff --git a/paramiko/file.py b/paramiko/file.py index 1971ce0..98a28f6 100644 --- a/paramiko/file.py +++ b/paramiko/file.py @@ -23,13 +23,6 @@ BufferedFile. from cStringIO import StringIO -_FLAG_READ = 0x1 -_FLAG_WRITE = 0x2 -_FLAG_APPEND = 0x4 -_FLAG_BINARY = 0x10 -_FLAG_BUFFERED = 0x20 -_FLAG_LINE_BUFFERED = 0x40 -_FLAG_UNIVERSAL_NEWLINE = 0x80 class BufferedFile (object): @@ -44,6 +37,14 @@ class BufferedFile (object): SEEK_CUR = 1 SEEK_END = 2 + FLAG_READ = 0x1 + FLAG_WRITE = 0x2 + FLAG_APPEND = 0x4 + FLAG_BINARY = 0x10 + FLAG_BUFFERED = 0x20 + FLAG_LINE_BUFFERED = 0x40 + FLAG_UNIVERSAL_NEWLINE = 0x80 + def __init__(self): self.newlines = None self._flags = 0 @@ -123,7 +124,7 @@ class BufferedFile (object): """ if self._closed: raise IOError('File is closed') - if not (self._flags & _FLAG_READ): + if not (self._flags & self.FLAG_READ): raise IOError('File not open for reading') if (size is None) or (size < 0): # go for broke @@ -148,7 +149,7 @@ class BufferedFile (object): return result while len(self._rbuffer) < size: read_size = size - len(self._rbuffer) - if self._flags & _FLAG_BUFFERED: + if self._flags & self.FLAG_BUFFERED: read_size = max(self._bufsize, read_size) try: new_data = self._read(read_size) @@ -184,11 +185,11 @@ class BufferedFile (object): # it's almost silly how complex this function is. if self._closed: raise IOError('File is closed') - if not (self._flags & _FLAG_READ): + if not (self._flags & self.FLAG_READ): raise IOError('File not open for reading') line = self._rbuffer while True: - if self._at_trailing_cr and (self._flags & _FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0): + if self._at_trailing_cr and (self._flags & self.FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0): # edge case: the newline may be '\r\n' and we may have read # only the first '\r' last time. if line[0] == '\n': @@ -209,7 +210,7 @@ class BufferedFile (object): n = size - len(line) else: n = self._bufsize - if ('\n' in line) or ((self._flags & _FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)): + if ('\n' in line) or ((self._flags & self.FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)): break try: new_data = self._read(n) @@ -223,7 +224,7 @@ class BufferedFile (object): self._realpos += len(new_data) # find the newline pos = line.find('\n') - if self._flags & _FLAG_UNIVERSAL_NEWLINE: + if self._flags & self.FLAG_UNIVERSAL_NEWLINE: rpos = line.find('\r') if (rpos >= 0) and ((rpos < pos) or (pos < 0)): pos = rpos @@ -295,6 +296,8 @@ class BufferedFile (object): @return: file position (in bytes). @rtype: int """ + if self._flags & self.FLAG_APPEND: + return self._get_size() return self._pos def write(self, data): @@ -309,13 +312,13 @@ class BufferedFile (object): """ if self._closed: raise IOError('File is closed') - if not (self._flags & _FLAG_WRITE): + if not (self._flags & self.FLAG_WRITE): raise IOError('File not open for writing') - if not (self._flags & _FLAG_BUFFERED): + if not (self._flags & self.FLAG_BUFFERED): self._write_all(data) return self._wbuffer.write(data) - if self._flags & _FLAG_LINE_BUFFERED: + if self._flags & self.FLAG_LINE_BUFFERED: # only scan the new data for linefeed, to avoid wasting time. last_newline_pos = data.rfind('\n') if last_newline_pos >= 0: @@ -397,22 +400,21 @@ class BufferedFile (object): # apparently, line buffering only affects writes. reads are only # buffered if you call readline (directly or indirectly: iterating # over a file will indirectly call readline). - self._flags |= _FLAG_BUFFERED | _FLAG_LINE_BUFFERED + self._flags |= self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED elif bufsize > 1: self._bufsize = bufsize - self._flags |= _FLAG_BUFFERED + self._flags |= self.FLAG_BUFFERED if ('r' in mode) or ('+' in mode): - self._flags |= _FLAG_READ + self._flags |= self.FLAG_READ if ('w' in mode) or ('+' in mode): - self._flags |= _FLAG_WRITE + self._flags |= self.FLAG_WRITE if ('a' in mode): - self._flags |= _FLAG_WRITE | _FLAG_APPEND - self._size = self._get_size() - self._pos = self._realpos = self._size + self._flags |= self.FLAG_WRITE | self.FLAG_APPEND + self._pos = self._realpos = -1 if ('b' in mode): - self._flags |= _FLAG_BINARY + self._flags |= self.FLAG_BINARY if ('U' in mode): - self._flags |= _FLAG_UNIVERSAL_NEWLINE + self._flags |= self.FLAG_UNIVERSAL_NEWLINE # built-in file objects have this attribute to store which kinds of # line terminations they've seen: # @@ -424,9 +426,9 @@ class BufferedFile (object): while len(data) > 0: count = self._write(data) data = data[count:] - if self._flags & _FLAG_APPEND: - self._size += count - self._pos = self._realpos = self._size + if self._flags & self.FLAG_APPEND: + # even if we used to know our seek position, we don't now. + self._pos = self._realpos = -1 else: self._pos += count self._realpos += count @@ -436,7 +438,7 @@ class BufferedFile (object): # silliness about tracking what kinds of newlines we've seen. # i don't understand why it can be None, a string, or a tuple, instead # of just always being a tuple, but we'll emulate that behavior anyway. - if not (self._flags & _FLAG_UNIVERSAL_NEWLINE): + if not (self._flags & self.FLAG_UNIVERSAL_NEWLINE): return if self.newlines is None: self.newlines = newline diff --git a/paramiko/sftp_file.py b/paramiko/sftp_file.py index 1e5478b..e29d557 100644 --- a/paramiko/sftp_file.py +++ b/paramiko/sftp_file.py @@ -159,8 +159,11 @@ class SFTPFile (BufferedFile): def _write(self, data): # may write less than requested if it would exceed max packet size chunk = min(len(data), self.MAX_REQUEST_SIZE) - req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), - str(data[:chunk])) + if self._flags & self.FLAG_APPEND: + pos = 0 + else: + pos = self._realpos + req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(pos), str(data[:chunk])) if not self.pipelined or self.sftp.sock.recv_ready(): t, msg = self.sftp._read_response(req) if t != CMD_STATUS: @@ -204,6 +207,10 @@ class SFTPFile (BufferedFile): def seek(self, offset, whence=0): self.flush() + if (self._flags & self.FLAG_APPEND) and (self._realpos == -1) and (whence != self.SEEK_END): + # this is still legal for O_RDWR ('a+'), but we need to figure out + # where we are -- we lost track of it during writes. + self._realpos = self._pos = self._get_size() if whence == self.SEEK_SET: self._realpos = self._pos = offset elif whence == self.SEEK_CUR: diff --git a/tests/test_sftp.py b/tests/test_sftp.py index 1e2785d..db21013 100755 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -648,3 +648,24 @@ class SFTPTest (unittest.TestCase): f.close() finally: sftp.unlink(FOLDER + '/zero') + + def test_M_seek_append(self): + """ + verify that seek does't affect writes during append. + """ + f = sftp.open(FOLDER + '/append.txt', 'a') + try: + f.write('first line\nsecond line\n') + f.seek(11, f.SEEK_SET) + f.write('third line\n') + f.close() + + f = sftp.open(FOLDER + '/append.txt', 'r') + self.assertEqual(f.stat().st_size, 34) + self.assertEqual(f.readline(), 'first line\n') + self.assertEqual(f.readline(), 'second line\n') + self.assertEqual(f.readline(), 'third line\n') + f.close() + finally: + sftp.remove(FOLDER + '/append.txt') +