sample.pass.cpp 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. //===----------------------------------------------------------------------===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. // UNSUPPORTED: c++98, c++03, c++11, c++14
  9. // <algorithm>
  10. // template <class PopulationIterator, class SampleIterator, class Distance,
  11. // class UniformRandomNumberGenerator>
  12. // SampleIterator sample(PopulationIterator first, PopulationIterator last,
  13. // SampleIterator out, Distance n,
  14. // UniformRandomNumberGenerator &&g);
  15. #include <algorithm>
  16. #include <random>
  17. #include <type_traits>
  18. #include <cassert>
  19. #include <cstddef>
  20. #include "test_iterators.h"
  21. #include "test_macros.h"
  22. struct ReservoirSampleExpectations {
  23. enum { os = 4 };
  24. static int oa1[os];
  25. static int oa2[os];
  26. };
  27. int ReservoirSampleExpectations::oa1[] = {10, 5, 9, 4};
  28. int ReservoirSampleExpectations::oa2[] = {5, 2, 10, 4};
  29. struct SelectionSampleExpectations {
  30. enum { os = 4 };
  31. static int oa1[os];
  32. static int oa2[os];
  33. };
  34. int SelectionSampleExpectations::oa1[] = {1, 4, 6, 7};
  35. int SelectionSampleExpectations::oa2[] = {1, 2, 6, 8};
  36. template <class IteratorCategory> struct TestExpectations
  37. : public SelectionSampleExpectations {};
  38. template <>
  39. struct TestExpectations<std::input_iterator_tag>
  40. : public ReservoirSampleExpectations {};
  41. template <template<class...> class PopulationIteratorType, class PopulationItem,
  42. template<class...> class SampleIteratorType, class SampleItem>
  43. void test() {
  44. typedef PopulationIteratorType<PopulationItem *> PopulationIterator;
  45. typedef SampleIteratorType<SampleItem *> SampleIterator;
  46. PopulationItem ia[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  47. const unsigned is = sizeof(ia) / sizeof(ia[0]);
  48. typedef TestExpectations<typename std::iterator_traits<
  49. PopulationIterator>::iterator_category> Expectations;
  50. const unsigned os = Expectations::os;
  51. SampleItem oa[os];
  52. const int *oa1 = Expectations::oa1;
  53. ((void)oa1); // Prevent unused warning
  54. const int *oa2 = Expectations::oa2;
  55. ((void)oa2); // Prevent unused warning
  56. std::minstd_rand g;
  57. SampleIterator end;
  58. end = std::sample(PopulationIterator(ia),
  59. PopulationIterator(ia + is),
  60. SampleIterator(oa), os, g);
  61. assert(static_cast<std::size_t>(end.base() - oa) == std::min(os, is));
  62. // sample() is deterministic but non-reproducible;
  63. // its results can vary between implementations.
  64. LIBCPP_ASSERT(std::equal(oa, oa + os, oa1));
  65. end = std::sample(PopulationIterator(ia),
  66. PopulationIterator(ia + is),
  67. SampleIterator(oa), os, std::move(g));
  68. assert(static_cast<std::size_t>(end.base() - oa) == std::min(os, is));
  69. LIBCPP_ASSERT(std::equal(oa, oa + os, oa2));
  70. }
  71. template <template<class...> class PopulationIteratorType, class PopulationItem,
  72. template<class...> class SampleIteratorType, class SampleItem>
  73. void test_empty_population() {
  74. typedef PopulationIteratorType<PopulationItem *> PopulationIterator;
  75. typedef SampleIteratorType<SampleItem *> SampleIterator;
  76. PopulationItem ia[] = {42};
  77. const unsigned os = 4;
  78. SampleItem oa[os];
  79. std::minstd_rand g;
  80. SampleIterator end =
  81. std::sample(PopulationIterator(ia), PopulationIterator(ia),
  82. SampleIterator(oa), os, g);
  83. assert(end.base() == oa);
  84. }
  85. template <template<class...> class PopulationIteratorType, class PopulationItem,
  86. template<class...> class SampleIteratorType, class SampleItem>
  87. void test_empty_sample() {
  88. typedef PopulationIteratorType<PopulationItem *> PopulationIterator;
  89. typedef SampleIteratorType<SampleItem *> SampleIterator;
  90. PopulationItem ia[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  91. const unsigned is = sizeof(ia) / sizeof(ia[0]);
  92. SampleItem oa[1];
  93. std::minstd_rand g;
  94. SampleIterator end =
  95. std::sample(PopulationIterator(ia), PopulationIterator(ia + is),
  96. SampleIterator(oa), 0, g);
  97. assert(end.base() == oa);
  98. }
  99. template <template<class...> class PopulationIteratorType, class PopulationItem,
  100. template<class...> class SampleIteratorType, class SampleItem>
  101. void test_small_population() {
  102. // The population size is less than the sample size.
  103. typedef PopulationIteratorType<PopulationItem *> PopulationIterator;
  104. typedef SampleIteratorType<SampleItem *> SampleIterator;
  105. PopulationItem ia[] = {1, 2, 3, 4, 5};
  106. const unsigned is = sizeof(ia) / sizeof(ia[0]);
  107. const unsigned os = 8;
  108. SampleItem oa[os];
  109. const SampleItem oa1[] = {1, 2, 3, 4, 5};
  110. std::minstd_rand g;
  111. SampleIterator end;
  112. end = std::sample(PopulationIterator(ia),
  113. PopulationIterator(ia + is),
  114. SampleIterator(oa), os, g);
  115. assert(static_cast<std::size_t>(end.base() - oa) == std::min(os, is));
  116. typedef typename std::iterator_traits<PopulationIterator>::iterator_category PopulationCategory;
  117. if (std::is_base_of<std::forward_iterator_tag, PopulationCategory>::value) {
  118. assert(std::equal(oa, end.base(), oa1));
  119. } else {
  120. assert(std::is_permutation(oa, end.base(), oa1));
  121. }
  122. }
  123. int main(int, char**) {
  124. test<input_iterator, int, random_access_iterator, int>();
  125. test<forward_iterator, int, output_iterator, int>();
  126. test<forward_iterator, int, random_access_iterator, int>();
  127. test<input_iterator, int, random_access_iterator, double>();
  128. test<forward_iterator, int, output_iterator, double>();
  129. test<forward_iterator, int, random_access_iterator, double>();
  130. test_empty_population<input_iterator, int, random_access_iterator, int>();
  131. test_empty_population<forward_iterator, int, output_iterator, int>();
  132. test_empty_population<forward_iterator, int, random_access_iterator, int>();
  133. test_empty_sample<input_iterator, int, random_access_iterator, int>();
  134. test_empty_sample<forward_iterator, int, output_iterator, int>();
  135. test_empty_sample<forward_iterator, int, random_access_iterator, int>();
  136. test_small_population<input_iterator, int, random_access_iterator, int>();
  137. test_small_population<forward_iterator, int, output_iterator, int>();
  138. test_small_population<forward_iterator, int, random_access_iterator, int>();
  139. return 0;
  140. }