From 40f9a4c04915cf5d88d4faf80d084b827ab65cf2 Mon Sep 17 00:00:00 2001 From: Nikolaus Schulz Date: Mon, 28 Dec 2009 00:40:28 +0100 Subject: Use safe methods to open the archive mbox and an existing mbox file --- archivemail.py | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 57 insertions(+), 8 deletions(-) diff --git a/archivemail.py b/archivemail.py index 5e93765..45b160d 100755 --- a/archivemail.py +++ b/archivemail.py @@ -331,13 +331,12 @@ class Mbox(mailbox.UnixMailbox): """ assert(path) self._locked = False - try: - self.original_atime = os.path.getatime(path) - self.original_mtime = os.path.getmtime(path) - self.starting_size = os.path.getsize(path) - self.mbox_file = open(path, "r+") - except IOError, msg: - unexpected_error(msg) + fd = safe_open_existing(path) + st = os.fstat(fd) + self.original_atime = st.st_atime + self.original_mtime = st.st_mtime + self.starting_size = st.st_size + self.mbox_file = os.fdopen(fd, "r+") self.mbox_file_name = path mailbox.UnixMailbox.__init__(self, self.mbox_file) @@ -604,7 +603,8 @@ class ArchiveMbox(TempMbox): if not options.no_compress: final_name = final_name + ".gz" vprint("writing back '%s' to '%s'" % (self.mbox_file_name, final_name)) - final_archive = open(final_name, "a") + fd = safe_open(final_name) + final_archive = os.fdopen(fd, "a") shutil.copyfileobj(self.mbox_file, final_archive) final_archive.close() self.remove() @@ -1664,6 +1664,55 @@ def get_filename(msg): return msg.fp._file.name raise +def safe_open_create(filename): + """Create and open a file in a NFSv2-safe way, and return a r/w file descriptor. + The new file is created with mode 600.""" + # This is essentially a simplified version of the dotlocking function. + vprint("Creating file '%s'" % filename) + dir, basename = os.path.split(filename) + # We rely on tempfile.mkstemp to create files safely and with 600 mode. + fd, pre_name = tempfile.mkstemp(prefix=basename+".pre-", dir=dir) + try: + try: + os.link(pre_name, filename) + except OSError, e: + if os.fstat(fd).st_nlink == 2: + pass + else: + raise + finally: + os.unlink(pre_name) + return fd + +def safe_open_existing(filename): + """Safely open an existing file, and return a r/w file descriptor.""" + lst = os.lstat(filename) + if stat.S_ISLNK(lst.st_mode): + unexpected_error("file '%s' is a symlink." % filename) + fd = os.open(filename, os.O_RDWR) + fst = os.fstat(fd) + if fst.st_nlink != 1: + unexpected_error("file '%s' has %d hard links." % \ + (filename, fst.st_nlink)) + if stat.S_ISDIR(fst.st_mode): + unexpected_error("file '%s' is a directory." % filename) + for i in stat.ST_DEV, stat.ST_INO, stat.ST_UID, stat.ST_GID, stat.ST_MODE, stat.ST_NLINK: + if fst[i] != lst[i]: + unexpected_error("file status changed unexpectedly") + return fd + +def safe_open(filename): + """Safely open a file, creating it if it doesn't exist, and return a + r/w file descriptor.""" + # This borrows from postfix code. + vprint("Opening archive...") + try: + fd = safe_open_existing(filename) + except OSError, e: + if e.errno != errno.ENOENT: raise + fd = safe_open_create(filename) + return fd + # this is where it all happens, folks if __name__ == '__main__': main() -- cgit v1.2.3