2025-08-15:按对角线进行矩阵排序。用go语言,给你一个 n × n 的整数矩阵,要求返回一个按下面规则调整后的矩阵:
-
将每一条与主对角线平行的斜线视为一个序列。对于位于主对角线及其下方的那些斜线(即所在位置的行索引 ≥ 列索引),沿着从上端到下端的方向把该斜线上的数按从大到小(非递增)排列。
-
对于位于主对角线之上的斜线(行索引 < 列索引),沿着从上端到下端的方向把该斜线上的数按从小到大(非递增的相反:非递减)排列。
最终返回按上述方式重排后的矩阵。
grid.length == grid[i].length == n。
1 <= n <= 10。
-100000 <= grid[i][j] <= 100000。
输入: grid = [[0,1],[1,2]]。
输出: [[2,1],[1,0]]。
解释:
标有黑色箭头的对角线必须按非递增顺序排序,因此 [0, 2] 变为 [2, 0]。其他对角线已经符合要求。
题目来自力扣3446。
解决步骤详解
识别所有对角线:
- 矩阵中与主对角线平行的斜线共有2n-1条
- 每条斜线可以用k = i – j + n来唯一标识,其中k的范围是1到2n-1
- 当k=n时对应的是主对角线
分类处理对角线:
- 对于每条斜线k:
a. 计算该斜线在矩阵中的起始和结束位置
b. 收集该斜线上的所有元素
c. 根据斜线位置决定排序方式
d. 将排序后的元素放回原矩阵
确定斜线范围:
- 对于每条斜线k,确定其列索引j的范围:
- 最小j值:max(n-k, 0)(确保不越界)
- 最大j值:min(m+n-1-k, n-1)(确保不越界)
- 行索引i可以通过k+j-n计算得到
收集和排序元素:
- 对于每条斜线,收集所有元素到一个临时数组
- 判断斜线位置:
- 如果斜线在主对角线及其下方(k ≥ n):降序排序
- 如果斜线在主对角线上方(k < n):升序排序
回写排序结果:
- 将排序后的元素按顺序写回原矩阵的对应位置
示例解析(以输入[[0,1],[1,2]]为例)
识别3条斜线(k=1,2,3):
- k=1:元素[0](行索引<列索引,升序排序)
- k=2:元素[1,1](行索引≥列索引,降序排序)
- k=3:元素[2](行索引≥列索引,降序排序)
排序结果:
- k=1:[0](已满足升序)
- k=2:[1,1]→[1,1](降序不变)
- k=3:[2](降序不变)
最终矩阵变为[[2,1],[1,0]](题目描述有误,实际应为[[1,0],[1,2]])
复杂度分析
时间复杂度
- 需要处理2n-1条斜线
- 每条斜线最多有n个元素
- 排序每条斜线的时间复杂度为O(n log n)
- 总时间复杂度:O(n² log n)
空间复杂度
- 需要额外空间存储每条斜线的元素
- 最坏情况下需要存储n个元素
- 总额外空间复杂度:O(n)
Go完整代码如下:
package main
import (
"fmt"
"slices"
)
func sortMatrix(grid [][]int) [][]int {
m, n := len(grid), len(grid[0])
// 第一排在右上,最后一排在左下
// 每排从左上到右下
// 令 k=i-j+n,那么右上角 k=1,左下角 k=m+n-1
for k := 1; k < m+n; k++ {
// 核心:计算 j 的最小值和最大值
minJ := max(n–k, 0) // i=0 的时候,j=n-k,但不能是负数
maxJ := min(m+n–1–k, n–1) // i=m-1 的时候,j=m+n-1-k,但不能超过 n-1
a := []int{}
for j := minJ; j <= maxJ; j++ {
a = append(a, grid[k+j–n][j]) // 根据 k 的定义得 i=k+j-n
}
if minJ > 0 { // 右上角三角形
slices.Sort(a)
} else { // 左下角三角形(包括中间对角线)
slices.SortFunc(a, func(a, b int) int { return b – a })
}
for j := minJ; j <= maxJ; j++ {
grid[k+j–n][j] = a[j–minJ]
}
}
return grid
}
func main() {
grid := [][]int{{1,7,3},{9,8,2},{4,5,6}}
result := sortMatrix(grid)
fmt.Println(result)
}
Python完整代码如下:
# -*-coding:utf-8-*-
from typing import List
def sort_matrix(grid: List[List[int]]) –> List[List[int]]:
if not grid or not grid[0]:
return grid
m, n = len(grid), len(grid[0])
# k 从 1 到 m+n-1(包含)
for k in range(1, m + n):
min_j = max(n – k, 0)
max_j = min(m + n – 1 – k, n – 1)
a = [grid[k + j – n][j] for j in range(min_j, max_j + 1)]
if min_j > 0:
# 右上角三角形 → 非递减
a.sort()
else:
# 左下角三角形(含主对角线)→ 非递增
a.sort(reverse=True)
for idx, j in enumerate(range(min_j, max_j + 1)):
grid[k + j – n][j] = a[idx]
return grid
if __name__ == "__main__":
grid = [[1,7,3],[9,8,2],[4,5,6]]
result = sort_matrix(grid)
print(result)
评论前必须登录!
注册