一、鏈表的基本原理

刷過力扣的讀者肯定對單鏈表非常熟悉,力扣上的單鏈表節點定義如下:

private:
    template <typename E>
    class Node {
    public:
        E val;
        Node* next;
        Node* prev;

        Node(Node* prev, E element, Node* next) {
            this->val = element;
            this->next = next;
            this->prev = prev;
        }
    };:

實際的定義:

private:
    template <typename E>
    class Node {
    public:
        E val;
        Node* next;
        Node* prev;

        Node(Node* prev, E element, Node* next) {
            this->val = element;
            this->next = next;
            this->prev = prev;
        }
    };

主要區別有兩個:

  • 1、編程語言標準庫一般都會提供泛型,即你可以指定 val 字段為任意類型,而力扣的單鏈表節點的 val 字段只有 int 類型。
  • 2、編程語言標準庫一般使用的都是雙鏈表而非單鏈表。單鏈表節點只有一個 next 指針,指向下一個節點;而雙鏈表節點有兩個指針,prev 指向前一個節點,next 指向下一個節點。

有了 prev 前驅指針,鏈表支持雙向遍歷,但由於要多維護一個指針,增刪查改時會稍微複雜一些,後面帶大家實現雙鏈表時會具體介紹。

為什麼需要鏈表

主要區別有兩個:

1、編程語言標準庫一般都會提供泛型,即你可以指定 val 字段為任意類型,而力扣的單鏈表節點的 val 字段只有 int 類型。

2、編程語言標準庫一般使用的都是雙鏈表而非單鏈表。單鏈表節點只有一個 next 指針,指向下一個節點;而雙鏈表節點有兩個指針,prev 指向前一個節點,next 指向下一個節點。

有了 prev 前驅指針,鏈表支持雙向遍歷,但由於要多維護一個指針,增刪查改時會稍微複雜一些,後面帶大家實現雙鏈表時會具體介紹

二、單鏈表的基本操作

class ListNode {
public:
    int val;
    ListNode *next;
    ListNode(int x) : val(x), next(NULL) {}
};

// 輸入一個數組,轉換為一條單鏈表
ListNode* createLinkedList(std::vector<int> arr) {
    if (arr.empty()) {
        return nullptr;
    }
    ListNode* head = new ListNode(arr[0]);
    ListNode* cur = head;
    for (int i = 1; i < arr.size(); i++) {
        cur->next = new ListNode(arr[i]);
        cur = cur->next;
    }
    return head;
}

1.查/改(單鏈表的遍歷/查找/修改)

// 創建一條單鏈表
ListNode* head = createLinkedList({1, 2, 3, 4, 5});

// 遍歷單鏈表
for (ListNode* p = head; p != nullptr; p = p->next) {
    std::cout << p->val << std::endl;
}

2.增

(1)在單鏈表頭部插入新元素

// 創建一條單鏈表
ListNode* head = createLinkedList({1, 2, 3, 4, 5});

// 在單鏈表頭部插入一個新節點 0
ListNode* newNode = new ListNode(0);
newNode->next = head;
head = newNode;

// 現在鏈表變成了 0 -> 1 -> 2 -> 3 -> 4 -> 5

(2)在單鏈表尾部插入新元素

// 創建一條單鏈表
ListNode* head = createLinkedList({1, 2, 3, 4, 5});

// 在單鏈表尾部插入一個新節點 6
ListNode* p = head;
// 先走到鏈表的最後一個節點
while (p->next != nullptr) {
    p = p->next;
}
// 現在 p 就是鏈表的最後一個節點
// 在 p 後面插入新節點
p->next = new ListNode(6);

// 現在鏈表變成了 1 -> 2 -> 3 -> 4 -> 5 -> 6

(3)在單鏈表中間插入新元素

// 創建一條單鏈表
ListNode* head = createLinkedList({1, 2, 3, 4, 5});

// 在第 3 個節點後面插入一個新節點 66
// 先要找到前驅節點,即第 3 個節點
ListNode* p = head;
for (int i = 0; i < 2; i++) {
    p = p->next;
}
// 此時 p 指向第 3 個節點
// 組裝新節點的後驅指針
ListNode* newNode = new ListNode(66);
newNode->next = p->next;

// 插入新節點
p->next = newNode;

// 現在鏈表變成了 1 -> 2 -> 3 -> 66 -> 4 -> 5

3.刪

(1)在單鏈表中刪除一個節點

刪除一個節點,首先要找到要被刪除節點的前驅節點,然後把這個前驅節點的 next 指針指向被刪除節點的下一個節點。這樣就能把被刪除節點從鏈表中摘除了。

// 創建一條單鏈表
ListNode* head = createLinkedList({1, 2, 3, 4, 5});

// 刪除第 4 個節點,要操作前驅節點
ListNode* p = head;
for (int i = 0; i < 2; i++) {
    p = p->next;
}

// 此時 p 指向第 3 個節點,即要刪除節點的前驅節點
// 把第 4 個節點從鏈表中摘除
p->next = p->next->next;

// 現在鏈表變成了 1 -> 2 -> 3 -> 5

(2)在單鏈表尾部刪除元素

這個操作比較簡單,找到倒數第二個節點,然後把它的 next 指針置為 null 就行了:

// 創建一條單鏈表
ListNode* head = createLinkedList({1, 2, 3, 4, 5});

// 刪除尾節點
ListNode* p = head;
// 找到倒數第二個節點
while (p->next->next != nullptr) {
    p = p->next;
}

// 此時 p 指向倒數第二個節點
// 把尾節點從鏈表中摘除
p->next = nullptr;

// 現在鏈表變成了 1 -> 2 -> 3 -> 4

(3)在單鏈表頭部刪除元素

// 創建一條單鏈表
ListNode* head = createLinkedList(vector<int>{1, 2, 3, 4, 5});

// 刪除頭結點
head = head->next;

// 現在鏈表變成了 2 -> 3 -> 4 -> 5

三、雙鏈表的基本操作

class DoublyListNode {
public:
    int val;
    DoublyListNode *next, *prev;
    DoublyListNode(int x) : val(x), next(NULL), prev(NULL) {}
};

DoublyListNode* createDoublyLinkedList(const vector<int>& arr) {
    if (arr.empty()) {
        return NULL;
    }
    DoublyListNode* head = new DoublyListNode(arr[0]);
    DoublyListNode* cur = head;
    // for 循環迭代創建雙鏈表
    for (int i = 1; i < arr.size(); i++) {
        DoublyListNode* newNode = new DoublyListNode(arr[i]);
        cur->next = newNode;
        newNode->prev = cur;
        cur = cur->next;
    }
    return head;
}

1.查/改(雙鏈表的遍歷/查找/修改)

// 創建一條雙鏈表
DoublyListNode* head = createDoublyLinkedList({1, 2, 3, 4, 5});
DoublyListNode* tail = nullptr;

// 從頭節點向後遍歷雙鏈表
for (DoublyListNode* p = head; p != nullptr; p = p->next) {
    cout << p->val << endl;
    tail = p;
}

// 從尾節點向前遍歷雙鏈表
for (DoublyListNode* p = tail; p != nullptr; p = p->prev) {
    cout << p->val << endl;
}

訪問或修改節點時,可以根據索引是靠近頭部還是尾部,選擇合適的方向遍歷,這樣可以一定程度上提高效率。

2.增

(1)在雙鏈表頭部插入新元素

// 創建一條雙鏈表
DoublyListNode* head = createDoublyLinkedList({1, 2, 3, 4, 5});

// 在雙鏈表頭部插入新節點 0
DoublyListNode* newHead = new DoublyListNode(0);
newHead->next = head;
head->prev = newHead;
head = newHead;

// 現在鏈表變成了 0 -> 1 -> 2 -> 3 -> 4 -> 5

(2)在雙鏈表尾部插入新元素

// 創建一條雙鏈表
DoublyListNode* head = createDoublyLinkedList({1, 2, 3, 4, 5});

DoublyListNode* tail = head;
// 先走到鏈表的最後一個節點
while (tail->next != nullptr) {
    tail = tail->next;
}

// 在雙鏈表尾部插入新節點 6
DoublyListNode* newNode = new DoublyListNode(6);
tail->next = newNode;
newNode->prev = tail;
// 更新尾節點引用
tail = newNode;

// 現在鏈表變成了 1 -> 2 -> 3 -> 4 -> 5 -> 6

(3)在雙鏈表中間插入新元素

// 創建一條雙鏈表
DoublyListNode* head = createDoublyLinkedList({1, 2, 3, 4, 5});

// 想要插入到索引 3(第 4 個節點)
// 需要操作索引 2(第 3 個節點)的指針
DoublyListNode* p = head;
for (int i = 0; i < 2; i++) {
    p = p->next;
}

// 組裝新節點
DoublyListNode* newNode = new DoublyListNode(66);
newNode->next = p->next;
newNode->prev = p;

// 插入新節點
p->next->prev = newNode;
p->next = newNode;

// 現在鏈表變成了 1 -> 2 -> 3 -> 66 -> 4 -> 5

3.刪

(1)在雙鏈表中刪除一個節點

// 創建一個雙鏈表
DoublyListNode* head = createDoublyLinkedList({1, 2, 3, 4, 5});

// 刪除第 4 個節點
// 先找到第 3 個節點
DoublyListNode* p = head;
for (int i = 0; i < 2; ++i) {
    p = p->next;
}

// 現在 p 指向第 3 個節點,我們將它後面那個節點摘除出去
DoublyListNode* toDelete = p->next;

// 把 toDelete 從鏈表中摘除
p->next = toDelete->next;
toDelete->next->prev = p;

// 把 toDelete 的前後指針都置為 null 是個好習慣(可選)
toDelete->next = nullptr;
toDelete->prev = nullptr;

// 現在鏈表變成了 1 -> 2 -> 3 -> 5

(2)在雙鏈表頭刪除一個結點

// 創建一條雙鏈表
DoublyListNode* head = createDoublyLinkedList({1, 2, 3, 4, 5});

// 刪除頭結點
DoublyListNode* toDelete = head;
head = head->next;
head->prev = nullptr;

// 清理已刪除節點的指針
toDelete->next = nullptr;

// 現在鏈表變成了 2 -> 3 -> 4 -> 5

(3)在雙鏈表尾部刪除元素

// 創建一條雙鏈表
DoublyListNode* head = createDoublyLinkedList({1, 2, 3, 4, 5});

// 刪除尾節點
DoublyListNode* p = head;
// 找到尾結點
while (p->next != nullptr) {
    p = p->next;
}

// 現在 p 指向尾節點
// 把尾節點從鏈表中摘除
p->prev->next = nullptr;

// 把被刪結點的指針都斷開是個好習慣(可選)
p->prev = nullptr;

// 現在鏈表變成了 1 -> 2 -> 3 -> 4

四、關鍵點

1.同時持有頭尾節點的引用

在力扣做題時,一般題目給我們傳入的就是單鏈表的頭指針。但是在實際開發中,用的都是雙鏈表,而雙鏈表一般會同時持有頭尾節點的引用。

因為在軟件開發中,在容器尾部添加元素是個非常高頻的操作,雙鏈表持有尾部節點的引用,就可以在 O(1) 的時間複雜度內完成尾部添加元素的操作。

對於單鏈表來説,持有尾部節點的引用也有優化效果。比如你要在單鏈表尾部添加元素,如果沒有尾部節點的引用,你就需要遍歷整個鏈表找到尾部節點,時間複雜度是
O(n);如果有尾部節點的引用,就可以在 O(1) 的時間複雜度內完成尾部添加元素的操作。

即便如此,如果刪除一次單鏈表的尾結點,那麼之前尾結點的引用就失效了,還是需要遍歷一遍鏈表找到尾結點。

是的,但你再仔細想想,刪除單鏈表尾結點的時候,是不是也得遍歷到倒數第二個節點(尾結點的前驅),才能通過指針操作把尾結點刪掉?那麼這個時候,你不就可以順便把尾結點的引用給更新了嗎

2.虛擬頭尾節點舉例來説,假設虛擬頭尾節點分別是 dummyHead 和 dummyTail,那麼一條空的雙鏈表長這樣:

dummyHead <-> dummyTail

如果你添加 1,2,3 幾個元素,那麼鏈表長這樣:

dummyHead <-> 1 <-> 2 <-> 3 <-> dummyTail

你以前要把在頭部插入元素、在尾部插入元素和在中間插入元素幾種情況分開討論,現在有了頭尾虛擬節點,無論鏈表是否為空,都只需要考慮在中間插入元素的情況就可以了,這樣代碼會簡潔很多。

當然,虛擬頭結點會多佔用一點內存空間,但是比起給你解決的麻煩,這點空間消耗是划算的。

對於單鏈表,虛擬頭結點有一定的簡化作用,但虛擬尾節點沒有太大作用。

虛擬節點是內部實現,對外不可見

虛擬節點是你內部實現數據結構的技巧,對外是不可見的。比如按照索引獲取元素的 get(index) 方法,都是從真實節點開始計算索引,而不是從虛擬節點開始計算

五、代碼實現

單鏈表

#include <iostream>
#include <stdexcept>

template <typename E>
class MyLinkedList2 {
private:
    // 節點結構
    struct Node {
        E val;
        Node* next;

        Node(E value) : val(value), next(nullptr) {}
    };

    Node* head;
    // 實際的尾部節點引用
    Node* tail;
    int size_;

public:
    MyLinkedList2() {
        head = new Node(E());
        tail = head;
        size_ = 0;
    }

    ~MyLinkedList2() {
        Node* current = head;
        while (current != nullptr) {
            Node* next = current->next;
            delete current;
            current = next;
        }
    }

    void addFirst(E e) {
        Node* newNode = new Node(e);
        newNode->next = head->next;
        head->next = newNode;
        if (size_ == 0) {
            tail = newNode;
        }
        size_++;
    }

    void addLast(E e) {
        Node* newNode = new Node(e);
        tail->next = newNode;
        tail = newNode;
        size_++;
    }

    void add(int index, E element) {
        checkPositionIndex(index);

        if (index == size_) {
            addLast(element);
            return;
        }

        Node* prev = head;
        for (int i = 0; i < index; i++) {
            prev = prev->next;
        }
        Node* newNode = new Node(element);
        newNode->next = prev->next;
        prev->next = newNode;
        size_++;
    }

    E removeFirst() {
        if (isEmpty()) {
            throw std::out_of_range("No elements to remove");
        }
        Node* first = head->next;
        head->next = first->next;
        if (size_ == 1) {
            tail = head;
        }
        size_--;
        E val = first->val;
        delete first;
        return val;
    }

    E removeLast() {
        if (isEmpty()) {
            throw std::out_of_range("No elements to remove");
        }

        Node* prev = head;
        while (prev->next != tail) {
            prev = prev->next;
        }
        E val = tail->val;
        delete tail;
        prev->next = nullptr;
        tail = prev;
        size_--;
        return val;
    }

    E remove(int index) {
        checkElementIndex(index);

        Node* prev = head;
        for (int i = 0; i < index; i++) {
            prev = prev->next;
        }

        Node* nodeToRemove = prev->next;
        prev->next = nodeToRemove->next;
        // 刪除的是最後一個元素
        if (index == size_ - 1) {
            tail = prev;
        }
        size_--;
        E val = nodeToRemove->val;
        delete nodeToRemove;
        return val;
    }

    // ***** 查 *****

    E getFirst() {
        if (isEmpty()) {
            throw std::out_of_range("No elements in the list");
        }
        return head->next->val;
    }

    E getLast() {
        if (isEmpty()) {
            throw std::out_of_range("No elements in the list");
        }
        return tail->val;
    }

    E get(int index) {
        checkElementIndex(index);
        Node* p = getNode(index);
        return p->val;
    }

    // ***** 改 *****

    E set(int index, E element) {
        checkElementIndex(index);
        Node* p = getNode(index);

        E oldVal = p->val;
        p->val = element;

        return oldVal;
    }

    // ***** 其他工具函數 *****
    int size() {
        return size_;
    }

    bool isEmpty() {
        return size_ == 0;
    }

private:
    bool isElementIndex(int index) {
        return index >= 0 && index < size_;
    }

    bool isPositionIndex(int index) {
        return index >= 0 && index <= size_;
    }

    // 檢查 index 索引位置是否可以存在元素
    void checkElementIndex(int index) {
        if (!isElementIndex(index)) {
            throw std::out_of_range("Index: " + std::to_string(index) + ", size_: " + std::to_string(size_));
        }
    }

    // 檢查 index 索引位置是否可以添加元素
    void checkPositionIndex(int index) {
        if (!isPositionIndex(index)) {
            throw std::out_of_range("Index: " + std::to_string(index) + ", size_: " + std::to_string(size_));
        }
    }

    // 返回 index 對應的 Node
    // 注意:請保證傳入的 index 是合法的
    Node* getNode(int index) {
        Node* p = head->next;
        for (int i = 0; i < index; i++) {
            p = p->next;
        }
        return p;
    }
};

int main() {
    MyLinkedList2<int> list;
    list.addFirst(1);
    list.addFirst(2);
    list.addLast(3);
    list.addLast(4);
    list.add(2, 5);

    std::cout << list.removeFirst() << std::endl; // 2
    std::cout << list.removeLast() << std::endl; // 4
    std::cout << list.remove(1) << std::endl; // 5

    std::cout << list.getFirst() << std::endl; // 1
    std::cout << list.getLast() << std::endl; // 3
    std::cout << list.get(1) << std::endl; // 3

    return 0;
}

雙鏈表

#include <iostream>
#include <stdexcept>

template<typename E>
class MyLinkedList {
    // 虛擬頭尾節點
    struct Node {
        E val;
        Node* next;
        Node* prev;

        Node(E value) : val(value), next(nullptr), prev(nullptr) {}
    };

    Node* head;
    Node* tail;
    int size;

public:
    // 構造函數初始化虛擬頭尾節點
    MyLinkedList() {
        head = new Node(E());
        tail = new Node(E());
        head->next = tail;
        tail->prev = head;
        size = 0;
    }

    ~MyLinkedList() {
        while (size > 0) {
            removeFirst();
        }
        delete head;
        delete tail;
    }

    // ***** 增 *****

    void addLast(E e) {
        Node* x = new Node(e);
        Node* temp = tail->prev;

        temp->next = x;
        x->prev = temp;
        // temp <-> x

        x->next = tail;
        tail->prev = x;
        // temp <-> x <-> tail
        size++;
    }

    void addFirst(E e) {
        Node* x = new Node(e);
        Node* temp = head->next;
        // head <-> temp
        temp->prev = x;
        x->next = temp;

        head->next = x;
        x->prev = head;
        // head <-> x <-> temp
        size++;
    }

    void add(int index, E element) {
        checkPositionIndex(index);
        if (index == size) {
            addLast(element);
            return;
        }

        // 找到 index 對應的 Node
        Node* p = getNode(index);
        Node* temp = p->prev;
        // temp <-> p

        // 新要插入的 Node
        Node* x = new Node(element);

        p->prev = x;
        temp->next = x;

        x->prev = temp;
        x->next = p;

        // temp <-> x <-> p

        size++;
    }

    // ***** 刪 *****

    E removeFirst() {
        if (size < 1) {
            throw std::out_of_range("No elements to remove");
        }
        // 虛擬節點的存在是我們不用考慮空指針的問題
        Node* x = head->next;
        Node* temp = x->next;
        // head <-> x <-> temp
        head->next = temp;
        temp->prev = head;

        E val = x->val;
        delete x;
        // head <-> temp

        size--;
        return val;
    }

    E removeLast() {
        if (size < 1) {
            throw std::out_of_range("No elements to remove");
        }
        Node* x = tail->prev;
        Node* temp = tail->prev->prev;
        // temp <-> x <-> tail

        tail->prev = temp;
        temp->next = tail;

        E val = x->val;
        x->prev = nullptr;
        x->next = nullptr;
        delete x;
        // temp <-> tail

        size--;
        return val;
    }

    E remove(int index) {
        checkElementIndex(index);
        // 找到 index 對應的 Node
        Node* x = getNode(index);
        Node* prev = x->prev;
        Node* next = x->next;
        // prev <-> x <-> next
        prev->next = next;
        next->prev = prev;

        E val = x->val;
        x->prev = nullptr;
        x->next = nullptr;
        delete x;
        // prev <-> next

        size--;
        return val;
    }

    // ***** 查 *****

    E get(int index) {
        checkElementIndex(index);
        // 找到 index 對應的 Node
        Node* p = getNode(index);

        return p->val;
    }

    E getFirst() {
        if (size < 1) {
            throw std::out_of_range("No elements in the list");
        }

        return head->next->val;
    }

    E getLast() {
        if (size < 1) {
            throw std::out_of_range("No elements in the list");
        }

        return tail->prev->val;
    }

    // ***** 改 *****

    E set(int index, E val) {
        checkElementIndex(index);
        // 找到 index 對應的 Node
        Node* p = getNode(index);

        E oldVal = p->val;
        p->val = val;

        return oldVal;
    }

    // ***** 其他工具函數 *****

    int getSize() const {
        return size;
    }

    bool isEmpty() const {
        return size == 0;
    }

    void display() {
        std::cout << "size = " << size << std::endl;
        for (Node* p = head->next; p != tail; p = p->next) {
            std::cout << p->val << " <-> ";
        }
        std::cout << "nullptr" << std::endl;
        std::cout << std::endl;
    }

private:
    Node* getNode(int index) {
        checkElementIndex(index);
        Node* p = head->next;
        // TODO: 可以優化,通過 index 判斷從 head 還是 tail 開始遍歷
        for (int i = 0; i < index; i++) {
            p = p->next;
        }
        return p;
    }

    bool isElementIndex(int index) const {
        return index >= 0 && index < size;
    }

    bool isPositionIndex(int index) const {
        return index >= 0 && index <= size;
    }

    // 檢查 index 索引位置是否可以存在元素
    void checkElementIndex(int index) const {
        if (!isElementIndex(index))
            throw std::out_of_range("Index: " + std::to_string(index) + ", Size: " + std::to_string(size));
    }

    // 檢查 index 索引位置是否可以添加元素
    void checkPositionIndex(int index) const {
        if (!isPositionIndex(index))
            throw std::out_of_range("Index: " + std::to_string(index) + ", Size: " + std::to_string(size));
    }
};

int main() {
    MyLinkedList<int> list;
    list.addLast(1);
    list.addLast(2);
    list.addLast(3);
    list.addFirst(0);
    list.add(2, 100);

    list.display();
    // size = 5
    // 0 <-> 1 <-> 100 <-> 2 <-> 3 <-> null

    return 0;
}