动态数组代码实现
几个关键点
下面我会直接给出一个简单的动态数组代码实现,包含了基本的增删查改功能。这里先给出几个关键点,等会你看代码的时候可以着重注意一下。
关键点一、自动扩缩容
在上一章 数组基础 中只提到了数组添加元素时可能需要扩容,并没有提到缩容。
在实际使用动态数组时,缩容也是重要的优化手段。比方说一个动态数组开辟了能够存储 1000 个元素的连续内存空间,但是实际只存了 10 个元素,那就有 990 个空间是空闲的。为了避免资源浪费,我们其实可以适当缩小存储空间,这就是缩容。
我们这里就实现一个简单的扩缩容的策略:
- 当数组元素个数达到底层静态数组的容量上限时,扩容为原来的 2 倍;
- 当数组元素个数缩减到底层静态数组的容量的 1/4 时,缩容为原来的 1/2。
关键点二、索引越界的检查
下面的代码实现中,有两个检查越界的方法,分别是 checkElementIndex
和 checkPositionIndex
,你可以看到它俩的区别仅仅在于 index < size
和 index <= size
。
为什么 checkPositionIndex
可以允许 index == size
呢,因为这个 checkPositionIndex
是专门用来处理在数组中插入元素的情况。
比方说有这样一个 nums
数组,对于每个元素来说,合法的索引一定是 index < size
:
nums = [5, 6, 7, 8]
index 0 1 2 3
但如果是要在数组中插入新元素,那么新元素可能的插入位置并不是元素的索引,而是索引之间的空隙:
nums = [ | 5 | 6 | 7 | 8 | ]
index 0 1 2 3 4
这些空隙都是合法的插入位置,所以说 index == size
也是合法的。这就是 checkPositionIndex
和 checkElementIndex
的区别。
关键点三、删除元素谨防内存泄漏
单从算法的角度,其实并不需要关心被删掉的元素应该如何处理,但是具体到代码实现,我们需要注意可能出现的内存泄漏。
在我给出的代码实现中,删除元素时,我都会把被删除的元素置为 null
,以 Java 为例:
// 删
public E removeLast() {
E deletedVal = data[size - 1];
// 删除最后一个元素
// 必须给最后一个元素置为 null,否则会内存泄漏
data[size - 1] = null;
size--;
return deletedVal;
}
Java 的垃圾回收机制是基于 图算法 的可达性分析,如果一个对象再也无法被访问到,那么这个对象占用的内存才会被释放;否则,垃圾回收器会认为这个对象还在使用中,就不会释放这个对象占用的内存。
如果你不执行 data[size - 1] = null
这行代码,那么 data[size - 1]
这个引用就会一直存在,你可以通过 data[size - 1]
访问这个对象,所以这个对象被认为是可达的,它的内存就一直不会被释放,进而造成内存泄漏。
其他带垃圾回收功能的语言应该也是类似的,你可以具体了解一下你使用的编程语言的垃圾回收机制,这是写出无 bug 代码的基本要求。
其他细节优化
下面的代码当然不会是一个很完善的实现,会有不少可以进一步优化的点。比方说,我是用 for 循环复制数组数据的,实际上这种方式复制的效率比较差,大部分编程语言会提供更高效的数组复制方法,比如 Java 的 System.arraycopy
。
不过它再怎么优化,本质上也是要搬移数据,时间复杂度都是 。本文的重点在于让你理解数组增删查改 API 的基本实现思路以及时间复杂度,如果对这些细节感兴趣,可以找到编程语言标准库的源码深入研究。
如何验证你的实现?
你可以借助力扣第 707 题「设计链表」来验证自己的实现是否正确。虽然这道题是关于链表的,但是它其实也不知道你底层到底是不是用链表实现的。咱主要是借用它的测试用例,来验证你的增删查改功能是否正确。
动态数组代码实现
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
public class MyArrayList<E> {
// 真正存储数据的底层数组
private E[] data;
// 记录当前元素个数
private int size;
// 默认初始容量
private static final int INIT_CAP = 1;
public MyArrayList() {
this(INIT_CAP);
}
public MyArrayList(int initCapacity) {
data = (E[]) new Object[initCapacity];
size = 0;
}
// 增
public void addLast(E e) {
int cap = data.length;
// 看 data 数组容量够不够
if (size == cap) {
resize(2 * cap);
}
// 在尾部插入元素
data[size] = e;
size++;
}
public void add(int index, E e) {
// 检查索引越界
checkPositionIndex(index);
int cap = data.length;
// 看 data 数组容量够不够
if (size == cap) {
resize(2 * cap);
}
// 搬移数据 data[index..] -> data[index+1..]
// 给新元素腾出位置
for (int i = size - 1; i >= index; i--) {
data[i + 1] = data[i];
}
// 插入新元素
data[index] = e;
size++;
}
public void addFirst(E e) {
add(0, e);
}
// 删
public E removeLast() {
if (size == 0) {
throw new NoSuchElementException();
}
int cap = data.length;
// 可以缩容,节约空间
if (size == cap / 4) {
resize(cap / 2);
}
E deletedVal = data[size - 1];
// 删除最后一个元素
// 必须给最后一个元素置为 null,否则会内存泄漏
data[size - 1] = null;
size--;
return deletedVal;
}
public E remove(int index) {
// 检查索引越界
checkElementIndex(index);
int cap = data.length;
// 可以缩容,节约空间
if (size == cap / 4) {
resize(cap / 2);
}
E deletedVal = data[index];
// 搬移数据 data[index+1..] -> data[index..]
for (int i = index + 1; i < size; i++) {
data[i - 1] = data[i];
}
data[size - 1] = null;
size--;
return deletedVal;
}
public E removeFirst() {
return remove(0);
}
// 查
public E get(int index) {
// 检查索引越界
checkElementIndex(index);
return data[index];
}
// 改
public E set(int index, E element) {
// 检查索引越界
checkElementIndex(index);
// 修改数据
E oldVal = data[index];
data[index] = element;
return oldVal;
}
// 工具方法
public int size() {
return size;
}
public boolean isEmpty() {
return size == 0;
}
// 将 data 的容量改为 newCap
private void resize(int newCap) {
E[] temp = (E[]) new Object[newCap];
for (int i = 0; i < size; i++) {
temp[i] = data[i];
}
data = temp;
}
private boolean isElementIndex(int index) {
return index >= 0 && index < size;
}
private boolean isPositionIndex(int index) {
return index >= 0 && index <= size;
}
// 检查 index 索引位置是否可以存在元素
private void checkElementIndex(int index) {
if (!isElementIndex(index))
throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
}
// 检查 index 索引位置是否可以添加元素
private void checkPositionIndex(int index) {
if (!isPositionIndex(index))
throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
}
private void display() {
System.out.println("size = " + size + " cap = " + data.length);
System.out.println(Arrays.toString(data));
}
public static void main(String[] args) {
// 初始容量设置为 3
MyArrayList<Integer> arr = new MyArrayList<>(3);
// 添加 5 个元素
for (int i = 1; i <= 5; i++) {
arr.addLast(i);
}
arr.remove(3);
arr.add(1, 9);
arr.addFirst(100);
int val = arr.removeLast();
for (int i = 0; i < arr.size(); i++) {
System.out.println(arr.get(i));
}
}
}
#include <iostream>
#include <stdexcept>
#include <vector>
template<typename E>
class MyArrayList {
private:
// 真正存储数据的底层数组
E* data;
// 记录当前元素个数
int size;
// 默认初始容量
static const int INIT_CAP = 1;
public:
MyArrayList() {
this->data = new E[INIT_CAP];
this->size = 0;
}
MyArrayList(int initCapacity) {
this->data = new E[initCapacity];
this->size = 0;
}
// 增
void addLast(E e) {
int cap = sizeof(data) / sizeof(data[0]);
// 看 data 数组容量够不够
if (size == cap) {
resize(2 * cap);
}
// 在尾部插入元素
data[size] = e;
size++;
}
void add(int index, E e) {
// 检查索引越界
checkPositionIndex(index);
int cap = sizeof(data) / sizeof(data[0]);
// 看 data 数组容量够不够
if (size == cap) {
resize(2 * cap);
}
// 搬移数据 data[index..] -> data[index+1..]
// 给新元素腾出位置
for (int i = size - 1; i >= index; i--) {
data[i + 1] = data[i];
}
// 插入新元素
data[index] = e;
size++;
}
void addFirst(E e) {
add(0, e);
}
// 删
E removeLast() {
if (size == 0) {
throw std::out_of_range("NoSuchElementException");
}
int cap = sizeof(data) / sizeof(data[0]);
// 可以缩容,节约空间
if (size == cap / 4) {
resize(cap / 2);
}
E deletedVal = data[size - 1];
// 删除最后一个元素
// 必须给最后一个元素置为 null,否则会内存泄漏
data[size - 1] = NULL;
size--;
return deletedVal;
}
E remove(int index) {
// 检查索引越界
checkElementIndex(index);
int cap = sizeof(data) / sizeof(data[0]);
// 可以缩容,节约空间
if (size == cap / 4) {
resize(cap / 2);
}
E deletedVal = data[index];
// 搬移数据 data[index+1..] -> data[index..]
for (int i = index + 1; i < size; i++) {
data[i - 1] = data[i];
}
data[size - 1] = NULL;
size--;
return deletedVal;
}
E removeFirst() {
return remove(0);
}
// 查
E get(int index) {
// 检查索引越界
checkElementIndex(index);
return data[index];
}
// 改
E set(int index, E element) {
// 检查索引越界
checkElementIndex(index);
// 修改数据
E oldVal = data[index];
data[index] = element;
return oldVal;
}
// 工具方法
int getSize() {
return size;
}
bool isEmpty() {
return size == 0;
}
// 将 data 的容量改为 newCap
void resize(int newCap) {
E* temp = new E[newCap];
for (int i = 0; i < size; i++) {
temp[i] = data[i];
}
// 释放原数组内存
delete[] data;
data = temp;
}
bool isElementIndex(int index) {
return index >= 0 && index < size;
}
bool isPositionIndex(int index) {
return index >= 0 && index <= size;
}
// 检查 index 索引位置是否可以存在元素
void checkElementIndex(int index) {
if (!isElementIndex(index)) {
throw std::out_of_range("Index out of bounds");
}
}
// 检查 index 索引位置是否可以添加元素
void checkPositionIndex(int index) {
if (!isPositionIndex(index)) {
throw std::out_of_range("Index out of bounds");
}
}
void display() {
std::cout << "size = " << size << " cap = " << sizeof(data) / sizeof(data[0]) << std::endl;
for (int i = 0; i < size; i++) {
std::cout << data[i] << " ";
}
std::cout << std::endl;
}
~MyArrayList() {
delete[] data;
}
};
int main() {
// 初始容量设置为 3
MyArrayList<int> arr(3);
// 添加 5 个元素
for (int i = 1; i <= 5; i++) {
arr.addLast(i);
}
arr.remove(3);
arr.add(1, 9);
arr.addFirst(100);
int val = arr.removeLast();
// 100 1 9 2 3
for (int i = 0; i < arr.getSize(); i++) {
std::cout << arr.get(i) << std::endl;
}
return 0;
}
class MyArrayList:
# 默认初始容量
INIT_CAP = 1
def __init__(self, init_capacity=None):
self.data = [None] * (init_capacity if init_capacity is not None else self.__class__.INIT_CAP)
self.size = 0
# 增
def add_last(self, e):
cap = len(self.data)
# 看 data 数组容量够不够
if self.size == cap:
self._resize(2 * cap)
# 在尾部插入元素
self.data[self.size] = e
self.size += 1
def add(self, index, e):
# 检查索引越界
self._check_position_index(index)
cap = len(self.data)
# 看 data 数组容量够不够
if self.size == cap:
self._resize(2 * cap)
# 搬移数据 data[index..] -> data[index+1..]
# 给新元素腾出位置
for i in range(self.size-1, index-1, -1):
self.data[i+1] = self.data[i]
# 插入新元素
self.data[index] = e
self.size += 1
def add_first(self, e):
self.add(0, e)
# 删
def remove_last(self):
if self.size == 0:
raise NoSuchElementException
cap = len(self.data)
# 可以缩容,节约空间
if self.size == cap // 4:
self._resize(cap // 2)
deleted_val = self.data[self.size - 1]
# 删除最后一个元素
self.data[self.size - 1] = None
self.size -= 1
return deleted_val
def remove(self, index):
# 检查索引越界
self._check_element_index(index)
cap = len(self.data)
# 可以缩容,节约空间
if self.size == cap // 4:
self._resize(cap // 2)
deleted_val = self.data[index]
# 搬移数据 data[index+1..] -> data[index..]
for i in range(index + 1, self.size):
self.data[i - 1] = self.data[i]
self.data[self.size - 1] = None
self.size -= 1
return deleted_val
def remove_first(self):
return self.remove(0)
# 查
def get(self, index):
# 检查索引越界
self._check_element_index(index)
return self.data[index]
# 改
def set(self, index, element):
# 检查索引越界
self._check_element_index(index)
# 修改数据
old_val = self.data[index]
self.data[index] = element
return old_val
# 工具方法
def size(self):
return self.size
def is_empty(self):
return self.size == 0
# 将 data 的容量改为 newCap
def _resize(self, new_cap):
temp = [None] * new_cap
for i in range(self.size):
temp[i] = self.data[i]
self.data = temp
def _is_element_index(self, index):
return 0 <= index < self.size
def _is_position_index(self, index):
return 0 <= index <= self.size
def _check_element_index(self, index):
if not self._is_element_index(index):
raise IndexError(f"Index: {index}, Size: {self.size}")
def _check_position_index(self, index):
if not self._is_position_index(index):
raise IndexError(f"Index: {index}, Size: {self.size}")
def display(self):
print(f"size = {self.size}, cap = {len(self.data)}")
print(self.data)
# Usage example
if __name__ == "__main__":
arr = MyArrayList(init_capacity=3)
# 添加 5 个元素
for i in range(1, 6):
arr.add_last(i)
arr.remove(3)
arr.add(1, 9)
arr.add_first(100)
val = arr.remove_last()
# 100 1 9 2 3
for i in range(arr.size):
print(arr.get(i))
package main
import (
"errors"
"fmt"
)
type MyArrayList struct {
// 真正存储数据的底层数组
data []interface{}
// 记录当前元素个数
size int
}
const INIT_CAP = 1
func NewMyArrayList() *MyArrayList {
return NewMyArrayListWithCapacity(INIT_CAP)
}
func NewMyArrayListWithCapacity(initCapacity int) *MyArrayList {
return &MyArrayList{
data: make([]interface{}, initCapacity),
size: 0,
}
}
// 增
func (list *MyArrayList) AddLast(value interface{}) {
cap := len(list.data)
// 看 data 数组容量够不够
if list.size == cap {
list.resize(2 * cap)
}
// 在尾部插入元素
list.data[list.size] = value
list.size++
}
func (list *MyArrayList) Add(index int, value interface{}) error {
// 检查索引越界
if err := list.checkPositionIndex(index); err != nil {
return err
}
cap := len(list.data)
// 看 data 数组容量够不够
if list.size == cap {
list.resize(2 * cap)
}
// 搬移数据 data[index..] -> data[index+1..]
// 给新元素腾出位置
for i := list.size - 1; i >= index; i-- {
list.data[i+1] = list.data[i]
}
// 插入新元素
list.data[index] = value
list.size++
return nil
}
func (list *MyArrayList) AddFirst(value interface{}) error {
return list.Add(0, value)
}
// 删
func (list *MyArrayList) RemoveLast() (interface{}, error) {
if list.size == 0 {
return nil, errors.New("No such element")
}
cap := len(list.data)
// 可以缩容,节约空间
if list.size == cap/4 {
list.resize(cap / 2)
}
deletedVal := list.data[list.size-1]
// 删除最后一个元素
// 必须给最后一个元素置为 nil,否则会内存泄漏
list.data[list.size-1] = nil
list.size--
return deletedVal, nil
}
func (list *MyArrayList) Remove(index int) (interface{}, error) {
// 检查索引越界
if err := list.checkElementIndex(index); err != nil {
return nil, err
}
cap := len(list.data)
// 可以缩容,节约空间
if list.size == cap/4 {
list.resize(cap / 2)
}
deletedVal := list.data[index]
// 搬移数据 data[index+1..] -> data[index..]
for i := index + 1; i < list.size; i++ {
list.data[i-1] = list.data[i]
}
list.data[list.size-1] = nil
list.size--
return deletedVal, nil
}
func (list *MyArrayList) RemoveFirst() (interface{}, error) {
return list.Remove(0)
}
// 查
func (list *MyArrayList) Get(index int) (interface{}, error) {
// 检查索引越界
if err := list.checkElementIndex(index); err != nil {
return nil, err
}
return list.data[index], nil
}
// 改
func (list *MyArrayList) Set(index int, value interface{}) (interface{}, error) {
// 检查索引越界
if err := list.checkElementIndex(index); err != nil {
return nil, err
}
// 修改数据
oldVal := list.data[index]
list.data[index] = value
return oldVal, nil
}
// 工具方法
func (list *MyArrayList) Size() int {
return list.size
}
func (list *MyArrayList) IsEmpty() bool {
return list.size == 0
}
// 将 data 的容量改为 newCap
func (list *MyArrayList) resize(newCap int) {
temp := make([]interface{}, newCap)
for i := 0; i < list.size; i++ {
temp[i] = list.data[i]
}
list.data = temp
}
func (list *MyArrayList) isElementIndex(index int) bool {
return index >= 0 && index < list.size
}
func (list *MyArrayList) isPositionIndex(index int) bool {
return index >= 0 && index <= list.size
}
// 检查 index 索引位置是否可以存在元素
func (list *MyArrayList) checkElementIndex(index int) error {
if !list.isElementIndex(index) {
return fmt.Errorf("Index: %d, Size: %d", index, list.size)
}
return nil
}
// 检查 index 索引位置是否可以添加元素
func (list *MyArrayList) checkPositionIndex(index int) error {
if !list.isPositionIndex(index) {
return fmt.Errorf("Index: %d, Size: %d", index, list.size)
}
return nil
}
func (list *MyArrayList) Display() {
fmt.Printf("size = %d cap = %d\n", list.size, len(list.data))
fmt.Println(list.data)
}
func main() {
// 初始容量设为 3
arr := NewMyArrayListWithCapacity(3)
// 添加 5 个元素
for i := 1; i <= 5; i++ {
arr.AddLast(i)
}
arr.Remove(3)
arr.Add(1, 9)
arr.AddFirst(100)
arr.RemoveLast()
// 100 1 9 2 3
for i := 0; i < arr.Size(); i++ {
val, _ := arr.Get(i)
fmt.Println(val)
}
}
class MyArrayList {
constructor(initCapacity) {
// 真正存储数据的底层数组
this.data = [];
// 记录当前元素个数
this.size = 0;
// 默认初始容量
this.INIT_CAP = 1;
// 初始化
this.init(initCapacity);
}
init(initCapacity) {
const capacity = initCapacity || this.INIT_CAP;
this.data = new Array(capacity);
this.size = 0;
}
// 增
addLast(e) {
const cap = this.data.length;
// 看 data 数组容量够不够
if (this.size === cap) {
this.resize(2 * cap);
}
// 在尾部插入元素
this.data[this.size] = e;
this.size++;
}
add(index, e) {
// 检查索引越界
this.checkPositionIndex(index);
const cap = this.data.length;
// 看 data 数组容量够不够
if (this.size === cap) {
this.resize(2 * cap);
}
// 搬移数据 data[index..] -> data[index+1..]
// 给新元素腾出位置
for (let i = this.size - 1; i >= index; i--) {
this.data[i + 1] = this.data[i];
}
// 插入新元素
this.data[index] = e;
this.size++;
}
addFirst(e) {
this.add(0, e);
}
// 删
removeLast() {
if (this.size === 0) {
throw new Error("NoSuchElementException");
}
const cap = this.data.length;
// 可以缩容,节约空间
if (this.size === Math.floor(cap / 4)) {
this.resize(Math.floor(cap / 2));
}
const deletedVal = this.data[this.size - 1];
// 删除最后一个元素
// 必须给最后一个元素置为 null,否则会内存泄漏
this.data[this.size - 1] = null;
this.size--;
return deletedVal;
}
remove(index) {
// 检查索引越界
this.checkElementIndex(index);
const cap = this.data.length;
// 可以缩容,节约空间
if (this.size === Math.floor(cap / 4)) {
this.resize(Math.floor(cap / 2));
}
const deletedVal = this.data[index];
// 搬移数据 data[index+1..] -> data[index..]
for (let i = index + 1; i < this.size; i++) {
this.data[i - 1] = this.data[i];
}
this.data[this.size - 1] = null;
this.size--;
return deletedVal;
}
removeFirst() {
return this.remove(0);
}
// 查
get(index) {
// 检查索引越界
this.checkElementIndex(index);
return this.data[index];
}
// 改
set(index, element) {
// 检查索引越界
this.checkElementIndex(index);
// 修改数据
const oldVal = this.data[index];
this.data[index] = element;
return oldVal;
}
// 工具方法
size() {
return this.size;
}
isEmpty() {
return this.size === 0;
}
// 将 data 的容量改为 newCap
resize(newCap) {
const temp = new Array(newCap);
for (let i = 0; i < this.size; i++) {
temp[i] = this.data[i];
}
this.data = temp;
}
isElementIndex(index) {
return index >= 0 && index < this.size;
}
isPositionIndex(index) {
return index >= 0 && index <= this.size;
}
// 检查 index 索引位置是否可以存在元素
checkElementIndex(index) {
if (!this.isElementIndex(index)) {
throw new Error("Index: " + index + ", Size: " + this.size);
}
}
// 检查 index 索引位置是否可以添加元素
checkPositionIndex(index) {
if (!this.isPositionIndex(index)) {
throw new Error("Index: " + index + ", Size: " + this.size);
}
}
display() {
console.log("size = " + this.size + " cap = " + this.data.length);
console.log(this.data);
}
}
// 初始容量设置为 3
const arr = new MyArrayList(3);
// 添加 5 个元素
for (let i = 1; i <= 5; i++) {
arr.addLast(i);
}
arr.remove(3);
arr.add(1, 9);
arr.addFirst(100);
const val = arr.removeLast();
// 100 1 9 2 3
for (let i = 0; i < arr.size; i++) {
console.log(arr.get(i));
}