[ARVADOS] created: 5fa3808d7587cc7a72acebef991233008f108a0b

git at public.curoverse.com git at public.curoverse.com
Fri Oct 24 17:25:32 EDT 2014


        at  5fa3808d7587cc7a72acebef991233008f108a0b (commit)


commit 5fa3808d7587cc7a72acebef991233008f108a0b
Author: Brett Smith <brett at curoverse.com>
Date:   Fri Oct 24 11:17:11 2014 -0400

    3603: PySDK Collection objects support file-like APIs.
    
    This commit adds an open() method to CollectionReader and
    CollectionWriter.  They mimic the built-in open(), returning objects
    that implement as much of the Python file API as I can reasonably
    manage (except I made readlines() a generator, to be forward-looking).

diff --git a/sdk/python/arvados/arvfile.py b/sdk/python/arvados/arvfile.py
new file mode 100644
index 0000000..ab695e9
--- /dev/null
+++ b/sdk/python/arvados/arvfile.py
@@ -0,0 +1,32 @@
+import functools
+
+class ArvadosFileBase(object):
+    def __init__(self, name, mode):
+        self.name = name
+        self.mode = mode
+        self.closed = False
+
+    @staticmethod
+    def _before_close(orig_func):
+        @functools.wraps(orig_func)
+        def wrapper(self, *args, **kwargs):
+            if self.closed:
+                raise ValueError("I/O operation on closed stream file")
+            return orig_func(self, *args, **kwargs)
+        return wrapper
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        try:
+            self.close()
+        except Exception:
+            if exc_type is None:
+                raise
+
+    def close(self):
+        self.closed = True
+
+    def isatty(self):
+        return False
diff --git a/sdk/python/arvados/collection.py b/sdk/python/arvados/collection.py
index dac75e0..97fedb7 100644
--- a/sdk/python/arvados/collection.py
+++ b/sdk/python/arvados/collection.py
@@ -5,6 +5,7 @@ import re
 from collections import deque
 from stat import *
 
+from .arvfile import ArvadosFileBase
 from keep import *
 from stream import *
 import config
@@ -167,6 +168,23 @@ class CollectionReader(CollectionBase):
         self._manifest_text = ''.join([StreamReader(stream, keep=self._my_keep()).manifest_text() for stream in self._streams])
         #print "result", self._manifest_text
 
+    def open(self, stream_name, file_name=None):
+        self._populate()
+        stream_name, file_name = util.splitstream(stream_name, file_name)
+        keep_client = self._my_keep()
+        for stream_s in self._streams:
+            stream = StreamReader(stream_s, keep_client,
+                                  num_retries=self.num_retries)
+            if stream.name() == stream_name:
+                break
+        else:
+            raise ValueError("stream '{}' not found in Collection".
+                             format(stream_name))
+        try:
+            return stream.files()[file_name]
+        except KeyError:
+            raise ValueError("file '{}' not found in Collection stream '{}'".
+                             format(stream_name, file_name))
 
     def all_streams(self):
         self._populate()
@@ -187,6 +205,25 @@ class CollectionReader(CollectionBase):
             return self._manifest_text
 
 
+class _WriterFile(ArvadosFileBase):
+    def __init__(self, coll_writer, name):
+        super(_WriterFile, self).__init__(name, 'wb')
+        self.dest = coll_writer
+
+    def close(self):
+        super(_WriterFile, self).close()
+        self.dest.finish_current_file()
+
+    @ArvadosFileBase._before_close
+    def write(self, data):
+        self.dest.write(data)
+
+    @ArvadosFileBase._before_close
+    def writelines(self, seq):
+        for data in seq:
+            self.write(data)
+
+
 class CollectionWriter(CollectionBase):
     KEEP_BLOCK_SIZE = 2**26
 
@@ -223,6 +260,7 @@ class CollectionWriter(CollectionBase):
         self._queued_file = None
         self._queued_dirents = deque()
         self._queued_trees = deque()
+        self._last_open = None
 
     def __exit__(self, exc_type, exc_value, traceback):
         if exc_type is None:
@@ -330,6 +368,18 @@ class CollectionWriter(CollectionBase):
         while self._data_buffer_len >= self.KEEP_BLOCK_SIZE:
             self.flush_data()
 
+    def open(self, stream_name, file_name=None):
+        stream_name, file_name = util.splitstream(stream_name, file_name)
+        if self._last_open and not self._last_open.closed:
+            raise errors.AssertionError(
+                "can't open '{}' when '{}' is still open".format(
+                    file_name, self._last_open.name))
+        if stream_name != self.current_stream_name():
+            self.start_new_stream(stream_name)
+        self.set_current_file_name(file_name)
+        self._last_open = _WriterFile(self, file_name)
+        return self._last_open
+
     def flush_data(self):
         data_buffer = ''.join(self._data_buffer)
         if data_buffer:
diff --git a/sdk/python/arvados/stream.py b/sdk/python/arvados/stream.py
index 2c893b6..52cc6a8 100644
--- a/sdk/python/arvados/stream.py
+++ b/sdk/python/arvados/stream.py
@@ -5,6 +5,7 @@ import os
 import re
 import zlib
 
+from .arvfile import ArvadosFileBase
 from arvados.retry import retry_method
 from keep import *
 import config
@@ -90,24 +91,40 @@ def locators_and_ranges(data_locators, range_start, range_size, debug=False):
     return resp
 
 
-class StreamFileReader(object):
+class StreamFileReader(ArvadosFileBase):
+    class NameAttribute(str):
+        # The Python file API provides a plain .name attribute.
+        # Older SDK provided a name() method.
+        # This class provides both, for maximum compatibility.
+        def __call__(self):
+            return self
+
+
     def __init__(self, stream, segments, name):
+        super(StreamFileReader, self).__init__(self.NameAttribute(name), 'rb')
         self._stream = stream
         self.segments = segments
-        self._name = name
+        self._max_segsize = max(seg[1] for seg in segments)
         self._filepos = 0L
         self.num_retries = stream.num_retries
+        self._readline_cache = (-1, '')
 
-    def name(self):
-        return self._name
+    def __iter__(self):
+        return self.readlines()
 
     def decompressed_name(self):
-        return re.sub('\.(bz2|gz)$', '', self._name)
+        return re.sub('\.(bz2|gz)$', '', self.name)
 
     def stream_name(self):
         return self._stream.name()
 
-    def seek(self, pos):
+    @ArvadosFileBase._before_close
+    def seek(self, pos, rel=os.SEEK_SET):
+        """Note that the default is SEEK_SET, not Python's usual SEEK_CUR."""
+        if rel == os.SEEK_CUR:
+            pos += self._filepos
+        elif rel == os.SEEK_END:
+            pos += self.size()
         self._filepos = min(max(pos, 0L), self.size())
 
     def tell(self):
@@ -117,6 +134,7 @@ class StreamFileReader(object):
         n = self.segments[-1]
         return n[OFFSET] + n[BLOCKSIZE]
 
+    @ArvadosFileBase._before_close
     @retry_method
     def read(self, size, num_retries=None):
         """Read up to 'size' bytes from the stream, starting at the current file position"""
@@ -133,6 +151,7 @@ class StreamFileReader(object):
         self._filepos += len(data)
         return data
 
+    @ArvadosFileBase._before_close
     @retry_method
     def readfrom(self, start, size, num_retries=None):
         """Read up to 'size' bytes from the stream, starting at 'start'"""
@@ -145,6 +164,7 @@ class StreamFileReader(object):
                                               num_retries=num_retries))
         return ''.join(data)
 
+    @ArvadosFileBase._before_close
     @retry_method
     def readall(self, size=2**20, num_retries=None):
         while True:
@@ -153,6 +173,35 @@ class StreamFileReader(object):
                 break
             yield data
 
+    @ArvadosFileBase._before_close
+    @retry_method
+    def readline(self, num_retries=None):
+        if self.tell() == self._readline_cache[0]:
+            data = [self._readline_cache[1]]
+        else:
+            data = [self.read(self._max_segsize, num_retries=num_retries)]
+        while data[-1] and ('\n' not in data[-1]):
+            data.append(self.read(self._max_segsize, num_retries=num_retries))
+        data = ''.join(data)
+        try:
+            nextline_index = data.index('\n') + 1
+        except ValueError:
+            nextline_index = len(data)
+        line = data[:nextline_index]
+        rest = data[nextline_index:]
+        self._readline_cache = (self.tell(), rest)
+        return line
+
+    @ArvadosFileBase._before_close
+    @retry_method
+    def readlines(self, num_retries=None):
+        while True:
+            data = self.readline(num_retries=num_retries)
+            if not data:
+                break
+            yield data
+
+    @ArvadosFileBase._before_close
     @retry_method
     def decompress(self, decompress, size, num_retries=None):
         for segment in self.readall(size, num_retries):
@@ -160,20 +209,22 @@ class StreamFileReader(object):
             if data and data != '':
                 yield data
 
+    @ArvadosFileBase._before_close
     @retry_method
     def readall_decompressed(self, size=2**20, num_retries=None):
         self.seek(0)
-        if re.search('\.bz2$', self._name):
+        if re.search('\.bz2$', self.name):
             dc = bz2.BZ2Decompressor()
             return self.decompress(dc.decompress, size,
                                    num_retries=num_retries)
-        elif re.search('\.gz$', self._name):
+        elif re.search('\.gz$', self.name):
             dc = zlib.decompressobj(16+zlib.MAX_WBITS)
             return self.decompress(lambda segment: dc.decompress(dc.unconsumed_tail + segment),
                                    size, num_retries=num_retries)
         else:
             return self.readall(size, num_retries=num_retries)
 
+    @ArvadosFileBase._before_close
     @retry_method
     def readlines(self, decompress=True, num_retries=None):
         read_func = self.readall_decompressed if decompress else self.readall
diff --git a/sdk/python/arvados/util.py b/sdk/python/arvados/util.py
index 2609f11..9950528 100644
--- a/sdk/python/arvados/util.py
+++ b/sdk/python/arvados/util.py
@@ -348,3 +348,19 @@ def list_all(fn, num_retries=0, **kwargs):
         items_available = c['items_available']
         offset = c['offset'] + len(c['items'])
     return items
+
+def splitstream(s, filename=None):
+    """splitstream(s, filename=None) -> streamname, filename
+
+    Normalize a /-separated stream path.
+    If filename is None, extract it from the end of s.
+    If no stream name is available, assume '.'.
+    """
+    if filename is not None:
+        streamname = s or '.'
+    else:
+        try:
+            streamname, filename = s.rsplit('/', 1)
+        except ValueError:  # No / in string
+            streamname, filename = '.', s
+    return streamname, filename
diff --git a/sdk/python/tests/arvados_testutil.py b/sdk/python/tests/arvados_testutil.py
index 0dbf9bc..9655f25 100644
--- a/sdk/python/tests/arvados_testutil.py
+++ b/sdk/python/tests/arvados_testutil.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 
 import errno
+import hashlib
 import httplib
 import httplib2
 import mock
@@ -24,6 +25,21 @@ def mock_responses(body, *codes, **headers):
     return mock.patch('httplib2.Http.request', side_effect=(
             (fake_httplib2_response(code, **headers), body) for code in codes))
 
+class MockStreamReader(object):
+    def __init__(self, name='.', *data):
+        self._name = name
+        self._data = ''.join(data)
+        self._data_locators = ['{}+{}'.format(hashlib.md5(d).hexdigest(),
+                                              len(d)) for d in data]
+        self.num_retries = 0
+
+    def name(self):
+        return self._name
+
+    def readfrom(self, start, size, num_retries=None):
+        return self._data[start:start + size]
+
+
 class ArvadosBaseTestCase(unittest.TestCase):
     # This class provides common utility functions for our tests.
 
diff --git a/sdk/python/tests/test_collections.py b/sdk/python/tests/test_collections.py
index 98a72f6..2d35df8 100644
--- a/sdk/python/tests/test_collections.py
+++ b/sdk/python/tests/test_collections.py
@@ -5,6 +5,7 @@
 import arvados
 import bz2
 import copy
+import hashlib
 import mock
 import os
 import pprint
@@ -347,17 +348,9 @@ class ArvadosCollectionsTest(run_test_server.TestCaseWithServers,
         self.assertEqual(arvados.locators_and_ranges(blocks, 11, 15), [['b', 15, 1, 14],
                                                                        ['c', 5, 0, 1]])
 
-    class MockStreamReader(object):
-        def __init__(self, content):
-            self.content = content
-            self.num_retries = 0
-
-        def readfrom(self, start, size, num_retries=0):
-            return self.content[start:start+size]
-
     def test_file_stream(self):
         content = 'abcdefghijklmnopqrstuvwxyz0123456789'
-        msr = self.MockStreamReader(content)
+        msr = tutil.MockStreamReader('.', content)
         segments = [[0, 10, 0],
                     [10, 15, 10],
                     [25, 5, 25]]
@@ -713,6 +706,44 @@ class CollectionReaderTestCase(unittest.TestCase, CollectionTestMixin):
             self.assertEqual('foo',
                              ''.join(f.read(9) for f in reader.all_files()))
 
+    def check_open_file(self, coll_file, stream_name, file_name, file_size):
+        self.assertFalse(coll_file.closed, "returned file is not open")
+        self.assertEqual(stream_name, coll_file.stream_name())
+        self.assertEqual(file_name, coll_file.name())
+        self.assertEqual(file_size, coll_file.size())
+
+    def test_open_collection_file_one_argument(self):
+        client = self.api_client_mock(200)
+        reader = arvados.CollectionReader(self.DEFAULT_UUID, api_client=client)
+        cfile = reader.open('./foo')
+        self.check_open_file(cfile, '.', 'foo', 3)
+
+    def test_open_collection_file_two_arguments(self):
+        client = self.api_client_mock(200)
+        reader = arvados.CollectionReader(self.DEFAULT_UUID, api_client=client)
+        cfile = reader.open('.', 'foo')
+        self.check_open_file(cfile, '.', 'foo', 3)
+
+    def test_open_deep_file(self):
+        coll_name = 'collection_with_files_in_subdir'
+        client = self.api_client_mock(200)
+        self.mock_get_collection(client, 200, coll_name)
+        reader = arvados.CollectionReader(
+            self.API_COLLECTIONS[coll_name]['uuid'], api_client=client)
+        cfile = reader.open('./subdir2/subdir3/file2_in_subdir3.txt')
+        self.check_open_file(cfile, './subdir2/subdir3', 'file2_in_subdir3.txt',
+                             32)
+
+    def test_open_nonexistent_stream(self):
+        client = self.api_client_mock(200)
+        reader = arvados.CollectionReader(self.DEFAULT_UUID, api_client=client)
+        self.assertRaises(ValueError, reader.open, './nonexistent', 'foo')
+
+    def test_open_nonexistent_file(self):
+        client = self.api_client_mock(200)
+        reader = arvados.CollectionReader(self.DEFAULT_UUID, api_client=client)
+        self.assertRaises(ValueError, reader.open, '.', 'nonexistent')
+
 
 @tutil.skip_sleep
 class CollectionWriterTestCase(unittest.TestCase, CollectionTestMixin):
@@ -751,6 +782,63 @@ class CollectionWriterTestCase(unittest.TestCase, CollectionTestMixin):
             writer.flush_data()
         self.assertEqual(self.DEFAULT_MANIFEST, writer.manifest_text())
 
+    def test_one_open(self):
+        client = self.api_client_mock()
+        writer = arvados.CollectionWriter(client)
+        with writer.open('out') as out_file:
+            self.assertEqual('.', writer.current_stream_name())
+            self.assertEqual('out', writer.current_file_name())
+            out_file.write('test data')
+            data_loc = hashlib.md5('test data').hexdigest() + '+9'
+        self.assertTrue(out_file.closed, "writer file not closed after context")
+        self.assertRaises(ValueError, out_file.write, 'extra text')
+        with self.mock_keep(data_loc, 200) as keep_mock:
+            self.assertEqual(". {} 0:9:out\n".format(data_loc),
+                             writer.manifest_text())
+
+    def test_open_writelines(self):
+        client = self.api_client_mock()
+        writer = arvados.CollectionWriter(client)
+        with writer.open('six') as out_file:
+            out_file.writelines(['12', '34', '56'])
+            data_loc = hashlib.md5('123456').hexdigest() + '+6'
+        with self.mock_keep(data_loc, 200) as keep_mock:
+            self.assertEqual(". {} 0:6:six\n".format(data_loc),
+                             writer.manifest_text())
+
+    def test_two_opens_same_stream(self):
+        client = self.api_client_mock()
+        writer = arvados.CollectionWriter(client)
+        with writer.open('.', '1') as out_file:
+            out_file.write('1st')
+        with writer.open('.', '2') as out_file:
+            out_file.write('2nd')
+        data_loc = hashlib.md5('1st2nd').hexdigest() + '+6'
+        with self.mock_keep(data_loc, 200) as keep_mock:
+            self.assertEqual(". {} 0:3:1 3:3:2\n".format(data_loc),
+                             writer.manifest_text())
+
+    def test_two_opens_two_streams(self):
+        client = self.api_client_mock()
+        writer = arvados.CollectionWriter(client)
+        with writer.open('file') as out_file:
+            out_file.write('file')
+            data_loc1 = hashlib.md5('file').hexdigest() + '+4'
+        with self.mock_keep(data_loc1, 200) as keep_mock:
+            with writer.open('./dir', 'indir') as out_file:
+                out_file.write('indir')
+                data_loc2 = hashlib.md5('indir').hexdigest() + '+5'
+        with self.mock_keep(data_loc2, 200) as keep_mock:
+            expected = ". {} 0:4:file\n./dir {} 0:5:indir\n".format(
+                data_loc1, data_loc2)
+            self.assertEqual(expected, writer.manifest_text())
+
+    def test_dup_open_fails(self):
+        client = self.api_client_mock()
+        writer = arvados.CollectionWriter(client)
+        file1 = writer.open('one')
+        self.assertRaises(arvados.errors.AssertionError, writer.open, 'two')
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/sdk/python/tests/test_stream.py b/sdk/python/tests/test_stream.py
index 3970d67..cb7e352 100644
--- a/sdk/python/tests/test_stream.py
+++ b/sdk/python/tests/test_stream.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 
+import os
 import mock
 import unittest
 
@@ -9,6 +10,112 @@ from arvados import StreamReader, StreamFileReader
 import arvados_testutil as tutil
 import run_test_server
 
+class StreamFileReaderTestCase(unittest.TestCase):
+    def make_count_reader(self):
+        stream = tutil.MockStreamReader('.', '01234', '34567', '67890')
+        return StreamFileReader(stream, [[1, 3, 0], [6, 3, 3], [11, 3, 6]],
+                                'count.txt')
+
+    def test_read_returns_first_block(self):
+        sfile = self.make_count_reader()
+        self.assertEqual('123', sfile.read(10))
+
+    def test_small_read(self):
+        sfile = self.make_count_reader()
+        self.assertEqual('12', sfile.read(2))
+
+    def test_successive_reads(self):
+        sfile = self.make_count_reader()
+        for expect in ['123', '456', '789', '']:
+            self.assertEqual(expect, sfile.read(10))
+
+    def test_readfrom_spans_blocks(self):
+        sfile = self.make_count_reader()
+        self.assertEqual('6789', sfile.readfrom(5, 12))
+
+    def test_small_readfrom_spanning_blocks(self):
+        sfile = self.make_count_reader()
+        self.assertEqual('2345', sfile.readfrom(1, 4))
+
+    def test_readall(self):
+        sfile = self.make_count_reader()
+        self.assertEqual('123456789', ''.join(sfile.readall()))
+
+    def test_one_arg_seek(self):
+        # Our default has been SEEK_SET since time immemorial.
+        self.test_absolute_seek([])
+
+    def test_absolute_seek(self, args=[os.SEEK_SET]):
+        sfile = self.make_count_reader()
+        sfile.seek(6, *args)
+        self.assertEqual('78', sfile.read(2))
+        sfile.seek(4, *args)
+        self.assertEqual('56', sfile.read(2))
+
+    def test_relative_seek(self):
+        sfile = self.make_count_reader()
+        self.assertEqual('12', sfile.read(2))
+        sfile.seek(2, os.SEEK_CUR)
+        self.assertEqual('56', sfile.read(2))
+
+    def test_end_seek(self):
+        sfile = self.make_count_reader()
+        sfile.seek(-6, os.SEEK_END)
+        self.assertEqual('45', sfile.read(2))
+
+    def test_seek_min_zero(self):
+        sfile = self.make_count_reader()
+        sfile.seek(-2, os.SEEK_SET)
+        self.assertEqual(0, sfile.tell())
+
+    def test_seek_max_size(self):
+        sfile = self.make_count_reader()
+        sfile.seek(2, os.SEEK_END)
+        self.assertEqual(9, sfile.tell())
+
+    def test_no_read_after_close(self):
+        sfile = self.make_count_reader()
+        sfile.close()
+        self.assertRaises(ValueError, sfile.read, 2)
+
+    def test_context(self):
+        with self.make_count_reader() as sfile:
+            self.assertFalse(sfile.closed, "reader is closed inside context")
+            self.assertEqual('12', sfile.read(2))
+        self.assertTrue(sfile.closed, "reader is open after context")
+
+    def make_newlines_reader(self):
+        stream = tutil.MockStreamReader('.', 'one\ntwo\n\nth', 'ree\nfour\n\n')
+        return StreamFileReader(stream, [[0, 11, 0], [11, 10, 11]], 'count.txt')
+
+    def check_lines(self, actual):
+        self.assertEqual(['one\n', 'two\n', '\n', 'three\n', 'four\n', '\n'],
+                         actual)
+
+    def test_readline(self):
+        reader = self.make_newlines_reader()
+        actual = []
+        while True:
+            data = reader.readline()
+            if not data:
+                break
+            actual.append(data)
+        self.check_lines(actual)
+
+    def test_readlines(self):
+        self.check_lines(list(self.make_newlines_reader().readlines()))
+
+    def test_iteration(self):
+        self.check_lines(list(iter(self.make_newlines_reader())))
+
+    def test_name_attribute(self):
+        # Test both .name and .name() (for backward compatibility)
+        stream = tutil.MockStreamReader()
+        sfile = StreamFileReader(stream, [[0, 0, 0]], 'nametest')
+        self.assertEqual('nametest', sfile.name)
+        self.assertEqual('nametest', sfile.name())
+
+
 class StreamRetryTestMixin(object):
     # Define reader_for(coll_name, **kwargs)
     # and read_for_test(reader, size, **kwargs).

commit 2fcc213423b3fff0cf5a2372fd8f4e988077db79
Author: Brett Smith <brett at curoverse.com>
Date:   Fri Oct 24 11:16:05 2014 -0400

    3603: Clean up PySDK imports.
    
    Sort; remove unused imports.

diff --git a/sdk/python/arvados/collection.py b/sdk/python/arvados/collection.py
index 20de716..dac75e0 100644
--- a/sdk/python/arvados/collection.py
+++ b/sdk/python/arvados/collection.py
@@ -1,22 +1,6 @@
-import gflags
-import httplib
-import httplib2
 import logging
 import os
-import pprint
-import sys
-import types
-import subprocess
-import json
-import UserDict
 import re
-import hashlib
-import string
-import bz2
-import zlib
-import fcntl
-import time
-import threading
 
 from collections import deque
 from stat import *
diff --git a/sdk/python/arvados/stream.py b/sdk/python/arvados/stream.py
index 04b6b81..2c893b6 100644
--- a/sdk/python/arvados/stream.py
+++ b/sdk/python/arvados/stream.py
@@ -1,22 +1,9 @@
-import gflags
-import httplib
-import httplib2
+import bz2
+import collections
+import hashlib
 import os
-import pprint
-import sys
-import types
-import subprocess
-import json
-import UserDict
 import re
-import hashlib
-import string
-import bz2
 import zlib
-import fcntl
-import time
-import threading
-import collections
 
 from arvados.retry import retry_method
 from keep import *
@@ -195,7 +182,7 @@ class StreamFileReader(object):
             data += newdata
             sol = 0
             while True:
-                eol = string.find(data, "\n", sol)
+                eol = data.find("\n", sol)
                 if eol < 0:
                     break
                 yield data[sol:eol+1]

commit 824b37d6f42ebae72b5e6fa2392d58993498bbee
Author: Brett Smith <brett at curoverse.com>
Date:   Fri Oct 24 11:23:09 2014 -0400

    3603: Fix context methods for PySDK Collection objects.

diff --git a/sdk/python/arvados/collection.py b/sdk/python/arvados/collection.py
index 782e85c..20de716 100644
--- a/sdk/python/arvados/collection.py
+++ b/sdk/python/arvados/collection.py
@@ -93,9 +93,9 @@ def normalize(collection):
 
 class CollectionBase(object):
     def __enter__(self):
-        pass
+        return self
 
-    def __exit__(self):
+    def __exit__(self, exc_type, exc_value, traceback):
         pass
 
     def _my_keep(self):
@@ -240,8 +240,9 @@ class CollectionWriter(CollectionBase):
         self._queued_dirents = deque()
         self._queued_trees = deque()
 
-    def __exit__(self):
-        self.finish()
+    def __exit__(self, exc_type, exc_value, traceback):
+        if exc_type is None:
+            self.finish()
 
     def do_queued_work(self):
         # The work queue consists of three pieces:

-----------------------------------------------------------------------


hooks/post-receive
-- 




More information about the arvados-commits mailing list