when a file is open for append, don't stat to get the file position unless the user asks for it explicitly
This commit is contained in:
Robey Pointer 2006-08-22 19:55:38 -07:00
parent 738e81033a
commit 145ceab54c
3 changed files with 61 additions and 31 deletions

View File

@ -23,13 +23,6 @@ BufferedFile.
from cStringIO import StringIO from cStringIO import StringIO
_FLAG_READ = 0x1
_FLAG_WRITE = 0x2
_FLAG_APPEND = 0x4
_FLAG_BINARY = 0x10
_FLAG_BUFFERED = 0x20
_FLAG_LINE_BUFFERED = 0x40
_FLAG_UNIVERSAL_NEWLINE = 0x80
class BufferedFile (object): class BufferedFile (object):
@ -44,6 +37,14 @@ class BufferedFile (object):
SEEK_CUR = 1 SEEK_CUR = 1
SEEK_END = 2 SEEK_END = 2
FLAG_READ = 0x1
FLAG_WRITE = 0x2
FLAG_APPEND = 0x4
FLAG_BINARY = 0x10
FLAG_BUFFERED = 0x20
FLAG_LINE_BUFFERED = 0x40
FLAG_UNIVERSAL_NEWLINE = 0x80
def __init__(self): def __init__(self):
self.newlines = None self.newlines = None
self._flags = 0 self._flags = 0
@ -123,7 +124,7 @@ class BufferedFile (object):
""" """
if self._closed: if self._closed:
raise IOError('File is closed') raise IOError('File is closed')
if not (self._flags & _FLAG_READ): if not (self._flags & self.FLAG_READ):
raise IOError('File not open for reading') raise IOError('File not open for reading')
if (size is None) or (size < 0): if (size is None) or (size < 0):
# go for broke # go for broke
@ -148,7 +149,7 @@ class BufferedFile (object):
return result return result
while len(self._rbuffer) < size: while len(self._rbuffer) < size:
read_size = size - len(self._rbuffer) read_size = size - len(self._rbuffer)
if self._flags & _FLAG_BUFFERED: if self._flags & self.FLAG_BUFFERED:
read_size = max(self._bufsize, read_size) read_size = max(self._bufsize, read_size)
try: try:
new_data = self._read(read_size) new_data = self._read(read_size)
@ -184,11 +185,11 @@ class BufferedFile (object):
# it's almost silly how complex this function is. # it's almost silly how complex this function is.
if self._closed: if self._closed:
raise IOError('File is closed') raise IOError('File is closed')
if not (self._flags & _FLAG_READ): if not (self._flags & self.FLAG_READ):
raise IOError('File not open for reading') raise IOError('File not open for reading')
line = self._rbuffer line = self._rbuffer
while True: while True:
if self._at_trailing_cr and (self._flags & _FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0): if self._at_trailing_cr and (self._flags & self.FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0):
# edge case: the newline may be '\r\n' and we may have read # edge case: the newline may be '\r\n' and we may have read
# only the first '\r' last time. # only the first '\r' last time.
if line[0] == '\n': if line[0] == '\n':
@ -209,7 +210,7 @@ class BufferedFile (object):
n = size - len(line) n = size - len(line)
else: else:
n = self._bufsize n = self._bufsize
if ('\n' in line) or ((self._flags & _FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)): if ('\n' in line) or ((self._flags & self.FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)):
break break
try: try:
new_data = self._read(n) new_data = self._read(n)
@ -223,7 +224,7 @@ class BufferedFile (object):
self._realpos += len(new_data) self._realpos += len(new_data)
# find the newline # find the newline
pos = line.find('\n') pos = line.find('\n')
if self._flags & _FLAG_UNIVERSAL_NEWLINE: if self._flags & self.FLAG_UNIVERSAL_NEWLINE:
rpos = line.find('\r') rpos = line.find('\r')
if (rpos >= 0) and ((rpos < pos) or (pos < 0)): if (rpos >= 0) and ((rpos < pos) or (pos < 0)):
pos = rpos pos = rpos
@ -295,6 +296,8 @@ class BufferedFile (object):
@return: file position (in bytes). @return: file position (in bytes).
@rtype: int @rtype: int
""" """
if self._flags & self.FLAG_APPEND:
return self._get_size()
return self._pos return self._pos
def write(self, data): def write(self, data):
@ -309,13 +312,13 @@ class BufferedFile (object):
""" """
if self._closed: if self._closed:
raise IOError('File is closed') raise IOError('File is closed')
if not (self._flags & _FLAG_WRITE): if not (self._flags & self.FLAG_WRITE):
raise IOError('File not open for writing') raise IOError('File not open for writing')
if not (self._flags & _FLAG_BUFFERED): if not (self._flags & self.FLAG_BUFFERED):
self._write_all(data) self._write_all(data)
return return
self._wbuffer.write(data) self._wbuffer.write(data)
if self._flags & _FLAG_LINE_BUFFERED: if self._flags & self.FLAG_LINE_BUFFERED:
# only scan the new data for linefeed, to avoid wasting time. # only scan the new data for linefeed, to avoid wasting time.
last_newline_pos = data.rfind('\n') last_newline_pos = data.rfind('\n')
if last_newline_pos >= 0: if last_newline_pos >= 0:
@ -397,22 +400,21 @@ class BufferedFile (object):
# apparently, line buffering only affects writes. reads are only # apparently, line buffering only affects writes. reads are only
# buffered if you call readline (directly or indirectly: iterating # buffered if you call readline (directly or indirectly: iterating
# over a file will indirectly call readline). # over a file will indirectly call readline).
self._flags |= _FLAG_BUFFERED | _FLAG_LINE_BUFFERED self._flags |= self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED
elif bufsize > 1: elif bufsize > 1:
self._bufsize = bufsize self._bufsize = bufsize
self._flags |= _FLAG_BUFFERED self._flags |= self.FLAG_BUFFERED
if ('r' in mode) or ('+' in mode): if ('r' in mode) or ('+' in mode):
self._flags |= _FLAG_READ self._flags |= self.FLAG_READ
if ('w' in mode) or ('+' in mode): if ('w' in mode) or ('+' in mode):
self._flags |= _FLAG_WRITE self._flags |= self.FLAG_WRITE
if ('a' in mode): if ('a' in mode):
self._flags |= _FLAG_WRITE | _FLAG_APPEND self._flags |= self.FLAG_WRITE | self.FLAG_APPEND
self._size = self._get_size() self._pos = self._realpos = -1
self._pos = self._realpos = self._size
if ('b' in mode): if ('b' in mode):
self._flags |= _FLAG_BINARY self._flags |= self.FLAG_BINARY
if ('U' in mode): if ('U' in mode):
self._flags |= _FLAG_UNIVERSAL_NEWLINE self._flags |= self.FLAG_UNIVERSAL_NEWLINE
# built-in file objects have this attribute to store which kinds of # built-in file objects have this attribute to store which kinds of
# line terminations they've seen: # line terminations they've seen:
# <http://www.python.org/doc/current/lib/built-in-funcs.html> # <http://www.python.org/doc/current/lib/built-in-funcs.html>
@ -424,9 +426,9 @@ class BufferedFile (object):
while len(data) > 0: while len(data) > 0:
count = self._write(data) count = self._write(data)
data = data[count:] data = data[count:]
if self._flags & _FLAG_APPEND: if self._flags & self.FLAG_APPEND:
self._size += count # even if we used to know our seek position, we don't now.
self._pos = self._realpos = self._size self._pos = self._realpos = -1
else: else:
self._pos += count self._pos += count
self._realpos += count self._realpos += count
@ -436,7 +438,7 @@ class BufferedFile (object):
# silliness about tracking what kinds of newlines we've seen. # silliness about tracking what kinds of newlines we've seen.
# i don't understand why it can be None, a string, or a tuple, instead # i don't understand why it can be None, a string, or a tuple, instead
# of just always being a tuple, but we'll emulate that behavior anyway. # of just always being a tuple, but we'll emulate that behavior anyway.
if not (self._flags & _FLAG_UNIVERSAL_NEWLINE): if not (self._flags & self.FLAG_UNIVERSAL_NEWLINE):
return return
if self.newlines is None: if self.newlines is None:
self.newlines = newline self.newlines = newline

View File

@ -159,8 +159,11 @@ class SFTPFile (BufferedFile):
def _write(self, data): def _write(self, data):
# may write less than requested if it would exceed max packet size # may write less than requested if it would exceed max packet size
chunk = min(len(data), self.MAX_REQUEST_SIZE) chunk = min(len(data), self.MAX_REQUEST_SIZE)
req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), if self._flags & self.FLAG_APPEND:
str(data[:chunk])) pos = 0
else:
pos = self._realpos
req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(pos), str(data[:chunk]))
if not self.pipelined or self.sftp.sock.recv_ready(): if not self.pipelined or self.sftp.sock.recv_ready():
t, msg = self.sftp._read_response(req) t, msg = self.sftp._read_response(req)
if t != CMD_STATUS: if t != CMD_STATUS:
@ -204,6 +207,10 @@ class SFTPFile (BufferedFile):
def seek(self, offset, whence=0): def seek(self, offset, whence=0):
self.flush() self.flush()
if (self._flags & self.FLAG_APPEND) and (self._realpos == -1) and (whence != self.SEEK_END):
# this is still legal for O_RDWR ('a+'), but we need to figure out
# where we are -- we lost track of it during writes.
self._realpos = self._pos = self._get_size()
if whence == self.SEEK_SET: if whence == self.SEEK_SET:
self._realpos = self._pos = offset self._realpos = self._pos = offset
elif whence == self.SEEK_CUR: elif whence == self.SEEK_CUR:

View File

@ -648,3 +648,24 @@ class SFTPTest (unittest.TestCase):
f.close() f.close()
finally: finally:
sftp.unlink(FOLDER + '/zero') sftp.unlink(FOLDER + '/zero')
def test_M_seek_append(self):
    """
    verify that seek doesn't affect writes during append.

    Opens a file in append mode, writes two lines, seeks back to offset 11
    (the start of the second line) and writes a third line.  The third
    line must land at the end of the file rather than at the seek position,
    so the result is a 34-byte file holding the three lines in order.
    """
    path = FOLDER + '/append.txt'
    f = sftp.open(path, 'a')
    try:
        try:
            f.write('first line\nsecond line\n')
            # seeking is legal in append mode, but must not move the
            # position that subsequent writes go to.
            f.seek(11, f.SEEK_SET)
            f.write('third line\n')
        finally:
            # close even if a write or seek fails, so the handle isn't leaked.
            f.close()

        f = sftp.open(path, 'r')
        try:
            # 11 + 12 + 11 bytes: nothing was overwritten by the seek.
            self.assertEqual(f.stat().st_size, 34)
            self.assertEqual(f.readline(), 'first line\n')
            self.assertEqual(f.readline(), 'second line\n')
            self.assertEqual(f.readline(), 'third line\n')
        finally:
            # close before the file is removed, even when an assertion fails.
            f.close()
    finally:
        sftp.remove(path)