I. Problem Description

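Given two arrays nums1 and nums2, sorted in ascending order and of sizes m and n respectively, find and return the median of the two sorted arrays.

The algorithm should run in O(log (m+n)) time.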

Example 1:

Input: nums1 = [1,3], nums2 = [2]
Output: 2.00000
Explanation: merged array = [1,2,3], median is 2

Example 2:

Input: nums1 = [1,2], nums2 = [3,4]
Output: 2.50000
Explanation: merged array = [1,2,3,4], median is (2 + 3) / 2 = 2.5

Constraints

  • nums1.length == m
  • nums2.length == n
  • 0 <= m <= 1000
  • 0 <= n <= 1000
  • 1 <= m + n <= 2000
  • -10^6 <= nums1[i], nums2[i] <= 10^6

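Median

The median is a statistic that describes the central tendency of a data set. It is the value in the middle position once the data are sorted by size. Because the median is largely unaffected by extreme values, it reflects the typical case better than the mean when the data are unevenly distributed or contain outliers.

For a set of data:

  • If the number of values is odd, the median is the middle value after sorting.
  • If the number of values is even, the median is the average of the two middle values after sorting.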
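
The two cases can also be written as one piecewise formula. As a notational assumption that is not in the original text, let x_(1) <= x_(2) <= ... <= x_(n) be the sorted data with 1-based indices:

```latex
\[
\operatorname{median}(x) =
\begin{cases}
x_{\left(\frac{n+1}{2}\right)}, & n \text{ odd},\\[4pt]
\dfrac{x_{\left(\frac{n}{2}\right)} + x_{\left(\frac{n}{2}+1\right)}}{2}, & n \text{ even}.
\end{cases}
\]
```

For [1,2,3] we have n = 3, so the median is x_(2) = 2. For [1,2,3,4] we have n = 4, so the median is (x_(2) + x_(3)) / 2 = (2 + 3) / 2 = 2.5, matching the two examples in the problem statement.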

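II. Solutions

1. Merge the arrays, then sort

Following the definition of the median, first concatenate the two arrays by copying them directly into one array, ignoring the relative order of their elements. Then sort the merged array and compute its median.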

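// Requires: import java.util.Arrays;
public double mergeArrayAndSort(int[] nums1, int[] nums2) {
    int[] mergeArray = new int[nums1.length + nums2.length];

    // Copy both arrays back to back; the result is not sorted yet.
    System.arraycopy(nums1, 0, mergeArray, 0, nums1.length);
    System.arraycopy(nums2, 0, mergeArray, nums1.length, nums2.length);

    Arrays.sort(mergeArray);

    // length / 2 and (length - 1) / 2 are the same index for odd lengths
    // and the two middle indices for even lengths.
    return (double) (mergeArray[mergeArray.length / 2] + mergeArray[(mergeArray.length - 1) / 2]) / 2;
}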

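Arrays.sort(int[] array) sorts the entire array. For primitive arrays the JDK uses a dual-pivot quicksort, which typically runs in O(n log n) time. This approach therefore costs O((m+n) log (m+n)) overall and does not meet the O(log (m+n)) target stated in the problem.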

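2. Produce a sorted array while merging

Merge the two arrays while keeping the result ordered: at each step, compare the current elements of the two arrays and place the smaller one at the current position of the merged array. Since both inputs are already sorted, this takes O(m+n) time.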

public double mergeSortedArray(int[] nums1, int[] nums2) {
    int[] mergeArray = new int[nums1.length + nums2.length];

    int mergeIdx = 0;
    int idx = 0, idx2 = 0;
    // Take the smaller head element until one array is exhausted.
    while (idx < nums1.length && idx2 < nums2.length) {
        if (nums1[idx] < nums2[idx2]) {
            mergeArray[mergeIdx++] = nums1[idx++];
        } else if (nums1[idx] > nums2[idx2]) {
            mergeArray[mergeIdx++] = nums2[idx2++];
        } else {
            // Equal heads: copy both.
            mergeArray[mergeIdx++] = nums1[idx++];
            mergeArray[mergeIdx++] = nums2[idx2++];
        }
    }
    // Copy whatever remains in either array.
    while (idx < nums1.length) {
        mergeArray[mergeIdx++] = nums1[idx++];
    }
    while (idx2 < nums2.length) {
        mergeArray[mergeIdx++] = nums2[idx2++];
    }

    return (double) (mergeArray[mergeArray.length / 2] + mergeArray[(mergeArray.length - 1) / 2]) / 2;
}
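
Neither merge-based approach reaches the O(log (m+n)) bound stated in the problem. The following is a minimal sketch of a commonly used alternative, not one of the solutions above: binary-search a partition point in the shorter array so that the left halves of both arrays together hold the smaller half of all elements. The class name MedianBinarySearchSketch and the method name findMedian are hypothetical, and the sketch assumes both inputs are sorted in ascending order and not both empty.

```java
class MedianBinarySearchSketch {

    // Hypothetical helper: median of two ascending arrays in O(log(min(m, n))) time.
    static double findMedian(int[] a, int[] b) {
        // Search the shorter array so the derived index into the longer one stays in range.
        if (a.length > b.length) {
            return findMedian(b, a);
        }
        int m = a.length;
        int n = b.length;
        if (n == 0) {
            throw new IllegalArgumentException("both arrays are empty");
        }
        int half = (m + n + 1) / 2; // number of elements in the combined left half
        int lo = 0;
        int hi = m;
        while (lo <= hi) {
            int i = (lo + hi) >>> 1; // elements taken from a for the left half
            int j = half - i;        // elements taken from b for the left half
            // Sentinels stand in for "no element on this side of the cut".
            int aLeft = (i == 0) ? Integer.MIN_VALUE : a[i - 1];
            int aRight = (i == m) ? Integer.MAX_VALUE : a[i];
            int bLeft = (j == 0) ? Integer.MIN_VALUE : b[j - 1];
            int bRight = (j == n) ? Integer.MAX_VALUE : b[j];
            if (aLeft > bRight) {
                hi = i - 1; // took too many elements from a
            } else if (bLeft > aRight) {
                lo = i + 1; // took too few elements from a
            } else {
                int leftMax = Math.max(aLeft, bLeft);
                if ((m + n) % 2 == 1) {
                    return leftMax;
                }
                int rightMin = Math.min(aRight, bRight);
                return ((long) leftMax + rightMin) / 2.0;
            }
        }
        throw new IllegalStateException("inputs must be sorted in ascending order");
    }

    public static void main(String[] args) {
        System.out.println(findMedian(new int[]{1, 3}, new int[]{2}));
        System.out.println(findMedian(new int[]{1, 2}, new int[]{3, 4}));
    }
}
```

Only the shorter array is searched, so the loop runs at most about log2(min(m, n)) + 1 times and no merged array is allocated.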

III. Test Programs

1. Test data generator

Test data object

@Data
@NoArgsConstructor
@AllArgsConstructor
@Accessors(chain = true)
class InputOutput {

    /**
     * Array 1
     */
    private int[] array;

    /**
     * Array 2
     */
    private int[] array2;

    /**
     * Median
     */
    private double median;

}

Test data generator, which writes the data to a CSV file

class CsvTestDataGenerator {

    private static final String TEST_DATA_PATH = "src/test/resources/csv/problem/NO_0004/";

    /**
     * <p>The 100,000-row and 1,000,000-row test data files must be generated locally;
     * they are not committed to git.
     */
    // private static final int TEST_DATA_SIZE = 1_000;
    // private static final int TEST_DATA_SIZE = 10_000;
    // private static final int TEST_DATA_SIZE = 100_000;
    private static final int TEST_DATA_SIZE = 1_000_000;

    private static final int ARRAY_MIN_LENGTH = 50;

    private static final int ARRAY_MAX_LENGTH = 100;

    private static final int MAX_VALUE = 10000;

    public static void main(String[] args) throws FileNotFoundException {
        File file = ResourceUtils.getFile(TEST_DATA_PATH + "test_data_" + TEST_DATA_SIZE + ".csv");

        FindMedianSortedArraysSolution findMedianSortedArraysSolution = new FindMedianSortedArraysSolution();

        CsvWriteConfig csvWriteConfig = CsvWriteConfig.defaultConfig();
        try (FileWriter fileWriter = new FileWriter(file);
             CsvWriter csvWriter = new CsvWriter(fileWriter, csvWriteConfig)) {

            for (int i = 0; i < TEST_DATA_SIZE; i++) {
                int[] array = generateSortedArray();
                int[] array2 = generateSortedArray();
                // The expected median is computed with the merge-based solution.
                double median = findMedianSortedArraysSolution.mergeSortedArray(array, array2);

                // One row per case: array;array2;median, with ';' separating array elements.
                csvWriter.writeLine(StringUtils.join(array, ';'),
                        StringUtils.join(array2, ';'),
                        String.valueOf(median));
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * <p>Generates an array sorted in ascending order.
     * <p>The array may contain duplicate values.
     */
    private static int[] generateSortedArray() {
        int arrayLength = RandomUtil.randomInt(ARRAY_MIN_LENGTH, ARRAY_MAX_LENGTH + 1);
        int[] array = new int[arrayLength];
        for (int i = 0; i < array.length; i++) {
            array[i] = RandomUtil.randomInt(0, MAX_VALUE);
        }
        Arrays.sort(array);
        return array;
    }

}

JUnit 5 test class

@Slf4j
class FindMedianSortedArraysSolutionTest {

    private final FindMedianSortedArraysSolution findMedianSortedArraysSolution = new FindMedianSortedArraysSolution();

    // Held as a List rather than a Stream: both parameterized tests read this
    // source, and a Stream can only be consumed once.
    @Getter
    private static List<Arguments> testData;

    @BeforeAll
    public static void init() {
        CsvTestDataInit csvTestDataInit = new CsvTestDataInit();
        testData = csvTestDataInit.getTestData().stream().map(inputOutput ->
                Arguments.of(inputOutput.getArray(), inputOutput.getArray2(), inputOutput.getMedian()))
                .toList();
    }

    @ParameterizedTest
    @MethodSource("getTestData")
    void mergeArrayAndSort(int[] nums, int[] nums2, double result) {
        double medianSortedArrays = findMedianSortedArraysSolution.mergeArrayAndSort(nums, nums2);
        Assertions.assertEquals(result, medianSortedArrays);
    }

    @ParameterizedTest
    @MethodSource("getTestData")
    void mergeSortedArray(int[] nums, int[] nums2, double result) {
        double medianSortedArrays = findMedianSortedArraysSolution.mergeSortedArray(nums, nums2);
        Assertions.assertEquals(result, medianSortedArrays);
    }

}
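
The same cross-checking idea can be tried without the CSV files or any third-party library. The sketch below is an illustration only: the class CrossCheckSketch and its methods are hypothetical names, the two median methods are compact re-implementations of the two approaches from section II rather than the original code, and the array lengths (1 to 100) and value range (0 to 9999) are assumptions chosen to resemble the generator.

```java
import java.util.Arrays;
import java.util.Random;

class CrossCheckSketch {

    // Approach 1: concatenate, sort, then read the middle element(s).
    static double mergeArrayAndSort(int[] nums1, int[] nums2) {
        int[] merged = new int[nums1.length + nums2.length];
        System.arraycopy(nums1, 0, merged, 0, nums1.length);
        System.arraycopy(nums2, 0, merged, nums1.length, nums2.length);
        Arrays.sort(merged);
        return (merged[merged.length / 2] + merged[(merged.length - 1) / 2]) / 2.0;
    }

    // Approach 2: two-pointer merge that keeps the output sorted.
    static double mergeSortedArray(int[] nums1, int[] nums2) {
        int[] merged = new int[nums1.length + nums2.length];
        int i = 0, j = 0, k = 0;
        while (i < nums1.length && j < nums2.length) {
            merged[k++] = (nums1[i] <= nums2[j]) ? nums1[i++] : nums2[j++];
        }
        while (i < nums1.length) {
            merged[k++] = nums1[i++];
        }
        while (j < nums2.length) {
            merged[k++] = nums2[j++];
        }
        return (merged[merged.length / 2] + merged[(merged.length - 1) / 2]) / 2.0;
    }

    // Random ascending array with length in [1, 100] and values in [0, 9999].
    static int[] randomSortedArray(Random random) {
        int[] array = new int[1 + random.nextInt(100)];
        for (int i = 0; i < array.length; i++) {
            array[i] = random.nextInt(10_000);
        }
        Arrays.sort(array);
        return array;
    }

    // Counts the rounds in which the two approaches disagree; a fixed seed keeps runs repeatable.
    static int countMismatches(int rounds, long seed) {
        Random random = new Random(seed);
        int mismatches = 0;
        for (int r = 0; r < rounds; r++) {
            int[] a = randomSortedArray(random);
            int[] b = randomSortedArray(random);
            if (mergeArrayAndSort(a, b) != mergeSortedArray(a, b)) {
                mismatches++;
            }
        }
        return mismatches;
    }

    public static void main(String[] args) {
        System.out.println("mismatches: " + countMismatches(1_000, 42L));
    }
}
```

Both methods apply the same final expression to the same sorted contents, so an exact comparison of the two doubles is safe here.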

Reading the CSV data

class CsvTestDataInit {

    private static final String TEST_DATA_CLASS_PATH =
            ResourceUtils.CLASSPATH_URL_PREFIX + "csv/problem/NO_0004/";

    public List<InputOutput> getTestData() {
        return this.getTestData(1000);
    }

    public List<InputOutput> getTestData(int dataSize) {

        CsvReadConfig csvReadConfig = CsvReadConfig.defaultConfig();
        csvReadConfig.setBeginLineNo(0);

        String filePath = TEST_DATA_CLASS_PATH + "test_data_" + dataSize + ".csv";
        try (InputStreamReader fileReader = new FileReader(ResourceUtils.getFile(filePath));
             CsvReader reader = CsvUtil.getReader(fileReader, csvReadConfig)) {

            CsvData csvData = reader.read();

            // Each row holds array;array2;median, with ';' separating array elements.
            List<InputOutput> queryInputOutputList = csvData.getRows().parallelStream()
                    .map(row -> new InputOutput(StrUtil.splitToInt(row.getFirst(), ';'),
                            StrUtil.splitToInt(row.get(1), ';'),
                            Double.parseDouble(row.get(2)))
                    )
                    .toList();

            // Sanity check: both input arrays of every row must be sorted.
            boolean sortedFlag = queryInputOutputList.stream().allMatch(sortInputOutput ->
                    ArrayUtils.isSorted(sortInputOutput.getArray()) && ArrayUtils.isSorted(sortInputOutput.getArray2())
            );
            if (!sortedFlag) {
                throw new RuntimeException("csv data result array is not sorted");
            }
            return queryInputOutputList;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

}

Testing with a larger data set

@Slf4j
class SpeedTest {

    private final FindMedianSortedArraysSolution findMedianSortedArraysSolution = new FindMedianSortedArraysSolution();

    /**
     * Number of test data rows
     */
    private static final int TEST_DATA_SIZE = 1_000_000;

    private static List<InputOutput> testDataList;

    @BeforeAll
    public static void init() {
        long start, end;
        start = System.currentTimeMillis();

        testDataList = new CsvTestDataInit().getTestData(TEST_DATA_SIZE);

        end = System.currentTimeMillis();
        log.info("Read {} rows from the CSV file in {} ms", testDataList.size(), end - start);
    }

    @Test
    void mergeArrayAndSort() {
        long start, end;
        start = System.currentTimeMillis();

        testDataList.forEach(inputOutput -> {
            double result = findMedianSortedArraysSolution.mergeArrayAndSort(inputOutput.getArray(), inputOutput.getArray2());
            Assertions.assertEquals(inputOutput.getMedian(), result);
        });

        end = System.currentTimeMillis();
        log.info("Processed {} rows in {} ms", testDataList.size(), end - start);
    }

    @Test
    void mergeSortedArray() {
        long start, end;
        start = System.currentTimeMillis();

        testDataList.forEach(inputOutput -> {
            double result = findMedianSortedArraysSolution.mergeSortedArray(inputOutput.getArray(), inputOutput.getArray2());
            Assertions.assertEquals(inputOutput.getMedian(), result);
        });

        end = System.currentTimeMillis();
        log.info("Processed {} rows in {} ms", testDataList.size(), end - start);
    }

}

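IV. Miscellaneous

1. Computing the median

There are two cases when computing the median: the array length is odd, or it is even. Handling the two cases separately looks like this: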

int[] mergeArray = new int[nums1.length + nums2.length];
double result;
if (mergeArray.length % 2 == 0) {
    // Even length: average the two middle elements (no float cast, to keep double precision).
    result = (mergeArray[mergeArray.length / 2 - 1] + mergeArray[mergeArray.length / 2]) / 2.0;
} else {
    // Odd length: take the single middle element.
    result = mergeArray[mergeArray.length / 2];
}

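The two cases can, however, be merged into one expression, because with integer division length / 2 and (length - 1) / 2 are the same index when the length is odd and the two middle indices when it is even: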

int[] mergeArray = new int[nums1.length + nums2.length];
double result = (double) (mergeArray[mergeArray.length / 2] + mergeArray[(mergeArray.length - 1) / 2]) / 2;
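
To see that the merged expression agrees with the two-branch version, the following sketch evaluates both on already sorted arrays. The class name MedianIndexSketch and both method names are hypothetical and used only for illustration.

```java
class MedianIndexSketch {

    // Two-branch version: handle odd and even lengths separately.
    static double medianTwoBranches(int[] sorted) {
        int n = sorted.length;
        if (n % 2 == 0) {
            return (sorted[n / 2 - 1] + sorted[n / 2]) / 2.0;
        }
        return sorted[n / 2];
    }

    // Single-expression version: (n - 1) / 2 and n / 2 coincide when n is odd.
    static double medianSingleExpression(int[] sorted) {
        int n = sorted.length;
        return (sorted[(n - 1) / 2] + sorted[n / 2]) / 2.0;
    }

    public static void main(String[] args) {
        int[] odd = {1, 2, 3};     // n = 3: both indices are 1
        int[] even = {1, 2, 3, 4}; // n = 4: indices 1 and 2
        System.out.println(medianTwoBranches(odd) + " " + medianSingleExpression(odd));   // 2.0 2.0
        System.out.println(medianTwoBranches(even) + " " + medianSingleExpression(even)); // 2.5 2.5
    }
}
```

For an odd length the single expression adds the middle element to itself and halves the sum, which returns the same value as long as the sum does not overflow int. With the stated bound of 10^6 on element values it cannot.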
