"""Tests for the itertools recipes exposed by ``more_itertools.recipes``."""
from doctest import DocTestSuite
from itertools import combinations
from unittest import TestCase

from six.moves import range

import more_itertools as mi


def load_tests(loader, tests, ignore):
    """Add the doctests from ``more_itertools.recipes`` to the test suite.

    Follows the ``unittest`` ``load_tests`` protocol: ``tests`` is the suite
    discovered so far and is returned extended; ``loader`` and ``ignore``
    (the discovery pattern) are not used.
    """
    tests.addTests(DocTestSuite('more_itertools.recipes'))
    return tests


class AccumulateTests(TestCase):
    """Tests for ``accumulate()``"""

    def test_empty(self):
        """Test that an empty input returns an empty output"""
        self.assertEqual(list(mi.accumulate([])), [])

    def test_default(self):
        """Test accumulate with the default function (addition)"""
        self.assertEqual(list(mi.accumulate([1, 2, 3])), [1, 3, 6])

    def test_bogus_function(self):
        """Test accumulate with an invalid function"""
        # The accumulating function is called with two arguments, so a
        # one-argument callable raises TypeError once iteration starts.
        with self.assertRaises(TypeError):
            list(mi.accumulate([1, 2, 3], func=lambda x: x))

    def test_custom_function(self):
        """Test accumulate with a custom function"""
        self.assertEqual(
            list(mi.accumulate((1, 2, 3, 2, 1), func=max)), [1, 2, 3, 3, 3]
        )


class TakeTests(TestCase):
    """Tests for ``take()``"""

    def test_simple_take(self):
        """Test basic usage"""
        t = mi.take(5, range(10))
        self.assertEqual(t, [0, 1, 2, 3, 4])

    def test_null_take(self):
        """Check the null case"""
        t = mi.take(0, range(10))
        self.assertEqual(t, [])

    def test_negative_take(self):
        """Make sure taking negative items results in a ValueError"""
        self.assertRaises(ValueError, lambda: mi.take(-3, range(10)))

    def test_take_too_much(self):
        """Taking more than an iterator has remaining should return what the
        iterator has remaining.

        """
        t = mi.take(10, range(5))
        self.assertEqual(t, [0, 1, 2, 3, 4])


class TabulateTests(TestCase):
    """Tests for ``tabulate()``"""

    def test_simple_tabulate(self):
        """Test the happy path"""
        t = mi.tabulate(lambda x: x)
        f = tuple([next(t) for _ in range(3)])
        self.assertEqual(f, (0, 1, 2))

    def test_count(self):
        """Ensure tabulate accepts specific count"""
        t = mi.tabulate(lambda x: 2 * x, -1)
        f = (next(t), next(t), next(t))
        self.assertEqual(f, (-2, 0, 2))


class TailTests(TestCase):
    """Tests for ``tail()``"""

    def test_greater(self):
        """Length of iterable is greater than requested tail"""
        self.assertEqual(list(mi.tail(3, 'ABCDEFG')), ['E', 'F', 'G'])

    def test_equal(self):
        """Length of iterable is equal to the requested tail"""
        self.assertEqual(
            list(mi.tail(7, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G']
        )

    def test_less(self):
        """Length of iterable is less than requested tail"""
        self.assertEqual(
            list(mi.tail(8, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G']
        )


class ConsumeTests(TestCase):
    """Tests for ``consume()``"""

    def test_sanity(self):
        """Test basic functionality"""
        r = (x for x in range(10))
        mi.consume(r, 3)
        self.assertEqual(3, next(r))

    def test_null_consume(self):
        """Check the null case"""
        r = (x for x in range(10))
        mi.consume(r, 0)
        self.assertEqual(0, next(r))

    def test_negative_consume(self):
        """Check that negative consumption throws an error"""
        r = (x for x in range(10))
        self.assertRaises(ValueError, lambda: mi.consume(r, -1))

    def test_total_consume(self):
        """Check that iterator is totally consumed by default"""
        r = (x for x in range(10))
        mi.consume(r)
        self.assertRaises(StopIteration, lambda: next(r))


class NthTests(TestCase):
    """Tests for ``nth()``"""

    def test_basic(self):
        """Make sure the nth item is returned"""
        values = range(10)
        for i, v in enumerate(values):
            self.assertEqual(mi.nth(values, i), v)

    def test_default(self):
        """Ensure a default value is returned when nth item not found"""
        values = range(3)
        self.assertEqual(mi.nth(values, 100, "zebra"), "zebra")

    def test_negative_item_raises(self):
        """Ensure asking for a negative item raises an exception"""
        self.assertRaises(ValueError, lambda: mi.nth(range(10), -3))


class AllEqualTests(TestCase):
    """Tests for ``all_equal()``"""

    def test_true(self):
        """Everything is equal"""
        self.assertTrue(mi.all_equal('aaaaaa'))
        self.assertTrue(mi.all_equal([0, 0, 0, 0]))

    def test_false(self):
        """Not everything is equal"""
        self.assertFalse(mi.all_equal('aaaaab'))
        self.assertFalse(mi.all_equal([0, 0, 0, 1]))

    def test_tricky(self):
        """Not everything is identical, but everything is equal"""
        items = [1, complex(1, 0), 1.0]
        self.assertTrue(mi.all_equal(items))

    def test_empty(self):
        """Return True if the iterable is empty"""
        self.assertTrue(mi.all_equal(''))
        self.assertTrue(mi.all_equal([]))

    def test_one(self):
        """Return True if the iterable is singular"""
        self.assertTrue(mi.all_equal('0'))
        self.assertTrue(mi.all_equal([0]))


class QuantifyTests(TestCase):
    """Tests for ``quantify()``"""

    def test_happy_path(self):
        """Make sure True count is returned"""
        q = [True, False, True]
        self.assertEqual(mi.quantify(q), 2)

    def test_custom_predicate(self):
        """Ensure non-default predicates return as expected"""
        q = range(10)
        self.assertEqual(mi.quantify(q, lambda x: x % 2 == 0), 5)


class PadnoneTests(TestCase):
    """Tests for ``padnone()``"""

    def test_happy_path(self):
        """wrapper iterator should return None indefinitely"""
        r = range(2)
        p = mi.padnone(r)
        self.assertEqual([0, 1, None, None], [next(p) for _ in range(4)])


class NcyclesTests(TestCase):
    """Tests for ``ncycles()``"""

    def test_happy_path(self):
        """cycle a sequence three times"""
        r = ["a", "b", "c"]
        n = mi.ncycles(r, 3)
        self.assertEqual(
            ["a", "b", "c", "a", "b", "c", "a", "b", "c"],
            list(n)
        )

    def test_null_case(self):
        """asking for 0 cycles should return an empty iterator"""
        n = mi.ncycles(range(100), 0)
        self.assertRaises(StopIteration, lambda: next(n))

    def test_pathalogical_case(self):
        """asking for negative cycles (a pathological case) should return an
        empty iterator"""
        n = mi.ncycles(range(100), -10)
        self.assertRaises(StopIteration, lambda: next(n))


class DotproductTests(TestCase):
    """Tests for ``dotproduct()``"""

    def test_happy_path(self):
        """simple dotproduct example"""
        self.assertEqual(400, mi.dotproduct([10, 10], [20, 20]))


class FlattenTests(TestCase):
    """Tests for ``flatten()``"""

    def test_basic_usage(self):
        """ensure list of lists is flattened one level"""
        f = [[0, 1, 2], [3, 4, 5]]
        self.assertEqual(list(range(6)), list(mi.flatten(f)))

    def test_single_level(self):
        """ensure list of lists is flattened only one level"""
        f = [[0, [1, 2]], [[3, 4], 5]]
        self.assertEqual([0, [1, 2], [3, 4], 5], list(mi.flatten(f)))


class RepeatfuncTests(TestCase):
    """Tests for ``repeatfunc()``"""

    def test_simple_repeat(self):
        """test simple repeated functions"""
        r = mi.repeatfunc(lambda: 5)
        self.assertEqual([5, 5, 5, 5, 5], [next(r) for _ in range(5)])

    def test_finite_repeat(self):
        """ensure limited repeat when times is provided"""
        r = mi.repeatfunc(lambda: 5, times=5)
        self.assertEqual([5, 5, 5, 5, 5], list(r))

    def test_added_arguments(self):
        """ensure arguments are applied to the function"""
        r = mi.repeatfunc(lambda x: x, 2, 3)
        self.assertEqual([3, 3], list(r))

    def test_null_times(self):
        """repeat 0 should return an empty iterator"""
        r = mi.repeatfunc(range, 0, 3)
        self.assertRaises(StopIteration, lambda: next(r))


class PairwiseTests(TestCase):
    """Tests for ``pairwise()``"""

    def test_base_case(self):
        """ensure an iterable will return pairwise"""
        p = mi.pairwise([1, 2, 3])
        self.assertEqual([(1, 2), (2, 3)], list(p))

    def test_short_case(self):
        """ensure an empty iterator if there's not enough values to pair"""
        p = mi.pairwise("a")
        self.assertRaises(StopIteration, lambda: next(p))


class GrouperTests(TestCase):
    """Tests for ``grouper()``"""

    def test_even(self):
        """Test when group size divides evenly into the length of
        the iterable.

        """
        self.assertEqual(
            list(mi.grouper(3, 'ABCDEF')), [('A', 'B', 'C'), ('D', 'E', 'F')]
        )

    def test_odd(self):
        """Test when group size does not divide evenly into the length of the
        iterable.

        """
        self.assertEqual(
            list(mi.grouper(3, 'ABCDE')), [('A', 'B', 'C'), ('D', 'E', None)]
        )

    def test_fill_value(self):
        """Test that the fill value is used to pad the final group"""
        self.assertEqual(
            list(mi.grouper(3, 'ABCDE', 'x')),
            [('A', 'B', 'C'), ('D', 'E', 'x')]
        )


class RoundrobinTests(TestCase):
    """Tests for ``roundrobin()``"""

    def test_even_groups(self):
        """Ensure ordered output from evenly populated iterables"""
        self.assertEqual(
            list(mi.roundrobin('ABC', [1, 2, 3], range(3))),
            ['A', 1, 0, 'B', 2, 1, 'C', 3, 2]
        )

    def test_uneven_groups(self):
        """Ensure ordered output from unevenly populated iterables"""
        self.assertEqual(
            list(mi.roundrobin('ABCD', [1, 2], range(0))),
            ['A', 1, 'B', 2, 'C', 'D']
        )


class PartitionTests(TestCase):
    """Tests for ``partition()``"""

    def test_bool(self):
        """Test when pred() returns a boolean"""
        lesser, greater = mi.partition(lambda x: x > 5, range(10))
        self.assertEqual(list(lesser), [0, 1, 2, 3, 4, 5])
        self.assertEqual(list(greater), [6, 7, 8, 9])

    def test_arbitrary(self):
        """Test when pred() returns an integer"""
        # x % 3 is 0 (falsy) exactly for the multiples of 3.
        divisibles, remainders = mi.partition(lambda x: x % 3, range(10))
        self.assertEqual(list(divisibles), [0, 3, 6, 9])
        self.assertEqual(list(remainders), [1, 2, 4, 5, 7, 8])


class PowersetTests(TestCase):
    """Tests for ``powerset()``"""

    def test_combinatorics(self):
        """Ensure a proper enumeration"""
        p = mi.powerset([1, 2, 3])
        self.assertEqual(
            list(p),
            [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
        )


class UniqueEverseenTests(TestCase):
    """Tests for ``unique_everseen()``"""

    def test_everseen(self):
        """ensure duplicate elements are ignored"""
        u = mi.unique_everseen('AAAABBBBCCDAABBB')
        self.assertEqual(
            ['A', 'B', 'C', 'D'],
            list(u)
        )

    def test_custom_key(self):
        """ensure the custom key comparison works"""
        u = mi.unique_everseen('aAbACCc', key=str.lower)
        self.assertEqual(list('abC'), list(u))

    def test_unhashable(self):
        """ensure things work for unhashable items"""
        iterable = ['a', [1, 2, 3], [1, 2, 3], 'a']
        u = mi.unique_everseen(iterable)
        self.assertEqual(list(u), ['a', [1, 2, 3]])

    def test_unhashable_key(self):
        """ensure things work for unhashable items with a custom key"""
        iterable = ['a', [1, 2, 3], [1, 2, 3], 'a']
        u = mi.unique_everseen(iterable, key=lambda x: x)
        self.assertEqual(list(u), ['a', [1, 2, 3]])


class UniqueJustseenTests(TestCase):
    """Tests for ``unique_justseen()``"""

    def test_justseen(self):
        """ensure only last item is remembered"""
        u = mi.unique_justseen('AAAABBBCCDABB')
        self.assertEqual(list('ABCDAB'), list(u))

    def test_custom_key(self):
        """ensure the custom key comparison works"""
        u = mi.unique_justseen('AABCcAD', str.lower)
        self.assertEqual(list('ABCAD'), list(u))


class IterExceptTests(TestCase):
    """Tests for ``iter_except()``"""

    def test_exact_exception(self):
        """ensure the exact specified exception is caught"""
        items = [1, 2, 3]
        i = mi.iter_except(items.pop, IndexError)
        self.assertEqual(list(i), [3, 2, 1])

    def test_generic_exception(self):
        """ensure the generic exception can be caught"""
        items = [1, 2]
        i = mi.iter_except(items.pop, Exception)
        self.assertEqual(list(i), [2, 1])

    def test_uncaught_exception_is_raised(self):
        """ensure a non-specified exception is raised"""
        items = [1, 2, 3]
        i = mi.iter_except(items.pop, KeyError)
        self.assertRaises(IndexError, lambda: list(i))

    def test_first(self):
        """ensure first is run before the function"""
        items = [1, 2, 3]

        def first():
            return 25

        i = mi.iter_except(items.pop, IndexError, first)
        self.assertEqual(list(i), [25, 3, 2, 1])


class FirstTrueTests(TestCase):
    """Tests for ``first_true()``"""

    def test_something_true(self):
        """Test with no keywords"""
        self.assertEqual(mi.first_true(range(10)), 1)

    def test_nothing_true(self):
        """Test default return value."""
        self.assertIsNone(mi.first_true([0, 0, 0]))

    def test_default(self):
        """Test with a default keyword"""
        self.assertEqual(mi.first_true([0, 0, 0], default='!'), '!')

    def test_pred(self):
        """Test with a custom predicate"""
        self.assertEqual(
            mi.first_true([2, 4, 6], pred=lambda x: x % 3 == 0), 6
        )


class RandomProductTests(TestCase):
    """Tests for ``random_product()``

    Since random.choice() has different results with the same seed across
    python versions 2.x and 3.x, these tests use highly probable events to
    create predictable outcomes across platforms instead of fixing a seed.
    """

    def test_simple_lists(self):
        """Ensure that one item is chosen from each list in each pair.
        Also ensure that each item from each list eventually appears in
        the chosen combinations.

        Each item is missed by a single sampling with probability 2/3, so
        the odds that any of the 6 items never appears in 100 samplings are
        at most 6 * (2/3)^100, roughly 1 in 6.8e16.
        """
        nums = [1, 2, 3]
        lets = ['a', 'b', 'c']
        n, m = zip(*[mi.random_product(nums, lets) for _ in range(100)])
        n, m = set(n), set(m)
        self.assertEqual(n, set(nums))
        self.assertEqual(m, set(lets))
        self.assertEqual(len(n), len(nums))
        self.assertEqual(len(m), len(lets))

    def test_list_with_repeat(self):
        """ensure multiple items are chosen, and that they appear to be chosen
        from one list then the next, in proper order.

        """
        nums = [1, 2, 3]
        lets = ['a', 'b', 'c']
        r = list(mi.random_product(nums, lets, repeat=100))
        self.assertEqual(2 * 100, len(r))
        # Even positions come from ``nums`` and odd positions from ``lets``.
        n, m = set(r[::2]), set(r[1::2])
        self.assertEqual(n, set(nums))
        self.assertEqual(m, set(lets))
        self.assertEqual(len(n), len(nums))
        self.assertEqual(len(m), len(lets))


class RandomPermutationTests(TestCase):
    """Tests for ``random_permutation()``"""

    def test_full_permutation(self):
        """ensure every item from the iterable is returned in a new ordering

        15 elements have a 1 in 15! (about 1.3e12) chance of coming back in
        their original order, so a spurious failure of the ordering check is
        negligible and no seed is fixed.

        """
        i = range(15)
        r = mi.random_permutation(i)
        # Sorting catches duplicated or missing items, which a set
        # comparison would let through.
        self.assertEqual(sorted(r), list(i))
        # A ``range`` never compares equal to a tuple or list, so both sides
        # are normalised to tuples; comparing ``i == r`` directly could never
        # detect an unpermuted result.
        self.assertNotEqual(
            tuple(r), tuple(i), "Values were not permuted"
        )

    def test_partial_permutation(self):
        """ensure all returned items are from the iterable, that the returned
        permutation is of the desired length, and that all items eventually
        get returned.

        Sampling 100 permutations of length 5 from a set of 15 leaves a
        (2/3)^100 chance that a given item will not be chosen. Multiplied by
        15 items, there is about a 1 in 2.7e16 chance that at least 1 item
        will not show up in the resulting output, which is negligible.

        """
        items = range(15)
        item_set = set(items)
        all_items = set()
        for _ in range(100):
            permutation = mi.random_permutation(items, 5)
            self.assertEqual(len(permutation), 5)
            permutation_set = set(permutation)
            self.assertLessEqual(permutation_set, item_set)
            all_items |= permutation_set
        self.assertEqual(all_items, item_set)


class RandomCombinationTests(TestCase):
    """Tests for ``random_combination()``"""

    def test_pseudorandomness(self):
        """ensure different subsets of the iterable get returned over many
        samplings of random combinations"""
        items = range(15)
        all_items = set()
        for _ in range(50):
            combination = mi.random_combination(items, 5)
            all_items |= set(combination)
        self.assertEqual(all_items, set(items))

    def test_no_replacement(self):
        """ensure that elements are sampled without replacement"""
        items = range(15)
        for _ in range(50):
            combination = mi.random_combination(items, len(items))
            self.assertEqual(len(combination), len(set(combination)))
        # Asking for more elements than exist is impossible without
        # replacement.
        self.assertRaises(
            ValueError, lambda: mi.random_combination(items, len(items) + 1)
        )


class RandomCombinationWithReplacementTests(TestCase):
    """Tests for ``random_combination_with_replacement()``"""

    def test_replacement(self):
        """ensure that elements are sampled with replacement"""
        items = range(5)
        combo = mi.random_combination_with_replacement(items, len(items) * 2)
        self.assertEqual(2 * len(items), len(combo))
        # Drawing 10 values from only 5 distinct items must repeat at least
        # one of them (pigeonhole principle).
        if len(set(combo)) == len(combo):
            raise AssertionError("Combination contained no duplicates")

    def test_pseudorandomness(self):
        """ensure different subsets of the iterable get returned over many
        samplings of random combinations"""
        items = range(15)
        all_items = set()
        for _ in range(50):
            combination = mi.random_combination_with_replacement(items, 5)
            all_items |= set(combination)
        self.assertEqual(all_items, set(items))


class NthCombinationTests(TestCase):
    """Tests for ``nth_combination()``"""

    def test_basic(self):
        """Every index matches the corresponding itertools.combinations
        result"""
        iterable = 'abcdefg'
        r = 4
        for index, expected in enumerate(combinations(iterable, r)):
            actual = mi.nth_combination(iterable, r, index)
            self.assertEqual(actual, expected)

    def test_long(self):
        """A large index into a large pool gives the known combination"""
        actual = mi.nth_combination(range(180), 4, 2000000)
        expected = (2, 12, 35, 126)
        self.assertEqual(actual, expected)

    def test_invalid_r(self):
        """An r below 0 or above the pool size raises ValueError"""
        for r in (-1, 3):
            with self.assertRaises(ValueError):
                mi.nth_combination([], r, 0)

    def test_invalid_index(self):
        """An index outside the 35 combinations of 7 items taken 3 at a time
        raises IndexError"""
        with self.assertRaises(IndexError):
            mi.nth_combination('abcdefg', 3, -36)


class PrependTests(TestCase):
    """Tests for ``prepend()``"""

    def test_basic(self):
        """The value comes first, followed by the iterator's items"""
        value = 'a'
        iterator = iter('bcdefg')
        actual = list(mi.prepend(value, iterator))
        expected = list('abcdefg')
        self.assertEqual(actual, expected)

    def test_multiple(self):
        """An iterable value is prepended as one item, not unpacked"""
        value = 'ab'
        iterator = iter('cdefg')
        actual = tuple(mi.prepend(value, iterator))
        expected = ('ab',) + tuple('cdefg')
        self.assertEqual(actual, expected)