
Using UPSERT in SQLAlchemy

Introduction

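Both SQLite (3.24+) and PostgreSQL (9.5+) support UPSERT, i.e. "update the row if it exists, insert it if it does not". The conflict target column(s) must be covered by a unique constraint or unique index.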

Syntax:

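  • PostgreSQL: INSERT ... ON CONFLICT (column) DO UPDATE SET ... / DO NOTHING
  • SQLite: INSERT ... ON CONFLICT(column) DO UPDATE SET ... / DO NOTHING. This is the same clause; the space before the parenthesis is optional in both databases.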
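| Scenario | PostgreSQL | SQLite | Notes |
|---|---|---|---|
| Basic UPSERT | ON CONFLICT (col) DO UPDATE SET ... | ON CONFLICT(col) DO UPDATE SET ... | Same clause; spacing is cosmetic |
| Ignore conflicts | ON CONFLICT (col) DO NOTHING | ON CONFLICT(col) DO NOTHING | Same |
| Reference the new value | EXCLUDED.col | excluded.col | Same; keywords and unquoted names are case-insensitive in both, so the case is only a convention |
| Return results | RETURNING * | RETURNING * | SQLite needs 3.35+ |
| Conditional update | DO UPDATE SET ... WHERE condition | DO UPDATE SET ... WHERE condition | Supported by both |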

Notes

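  • The conflict target column(s) must have a unique constraint or unique index.
  • PostgreSQL and SQLite share the ON CONFLICT syntax, but the surrounding SQL differs: SQLite has no NOW() (use CURRENT_TIMESTAMP, which PostgreSQL also accepts), and PostgreSQL system columns such as xmax do not exist in SQLite. Watch for this when writing raw SQL.
  • Both databases support a WHERE clause after DO UPDATE SET ...; when the condition is false, the existing row is left unchanged. One SQLite quirk: in INSERT ... SELECT ... ON CONFLICT, give the SELECT a WHERE clause (even just WHERE true) to avoid a parsing ambiguity.
  • RETURNING requires SQLite 3.35 or later.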
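A minimal sketch of a conditional upsert, using Python's built-in sqlite3 module so it runs without a database server. The docs table, its version column, and the save() helper are hypothetical: the update only happens when the incoming version is newer, so a late, stale write cannot overwrite fresher data.

```python
import sqlite3

# Hypothetical table: one row per document, keyed by slug, with a version number
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE docs (slug TEXT PRIMARY KEY, body TEXT, version INTEGER NOT NULL)")

UPSERT_IF_NEWER = """
    INSERT INTO docs (slug, body, version)
    VALUES (:slug, :body, :version)
    ON CONFLICT(slug) DO UPDATE SET
        body = excluded.body,
        version = excluded.version
    WHERE excluded.version > docs.version  -- stale writes leave the row unchanged
"""

def save(slug: str, body: str, version: int) -> None:
    conn.execute(UPSERT_IF_NEWER, {"slug": slug, "body": body, "version": version})

save("intro", "v1 text", 1)  # no row yet: inserted
save("intro", "v3 text", 3)  # 3 > 1: updated
save("intro", "v2 text", 2)  # 2 > 3 is false: skipped
print(conn.execute("SELECT body, version FROM docs").fetchone())  # ('v3 text', 3)
```

The same WHERE clause works unchanged in PostgreSQL.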

EXCLUDED and RETURNING

EXCLUDED

EXCLUDED is the row that was proposed for insertion and rejected because of the conflict, i.e. the new values.

INSERT INTO users (email, name, age)
VALUES ('test@example.com', 'New Name', 30)
ON CONFLICT (email) DO UPDATE SET
    name = EXCLUDED.name,   -- ← the new value 'New Name'
    age = EXCLUDED.age      -- ← the new value 30
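| Scenario | Expression | Meaning | Example value |
|---|---|---|---|
| Existing column | users.name | Current value in the conflicting row | "Old Name" |
| New value | EXCLUDED.name | Value the INSERT tried to write | "New Name" |
| Mixed calculation | users.age + EXCLUDED.age | Existing value + new value | 25 + 30 = 55 |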
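The table above can be reproduced with the stdlib sqlite3 module. A minimal sketch, assuming a users table with email as its unique key and an existing row ('Old Name', 25): inside DO UPDATE SET, users.<column> reads the stored row and excluded.<column> reads the rejected new row.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (email TEXT PRIMARY KEY, name TEXT, age INTEGER)")
# The stored row ("existing column" values in the table above)
conn.execute("INSERT INTO users VALUES ('test@example.com', 'Old Name', 25)")

conn.execute(
    """
    INSERT INTO users (email, name, age)
    VALUES ('test@example.com', 'New Name', 30)
    ON CONFLICT(email) DO UPDATE SET
        name = excluded.name,            -- new value wins
        age = users.age + excluded.age   -- existing value + new value
    """
)
print(conn.execute("SELECT name, age FROM users").fetchone())  # ('New Name', 55)
```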

Example 1: Accumulating stock

-- Accumulate stock: current stock 100 + incoming 50 = 150
INSERT INTO products (sku, stock)
VALUES ('IPHONE15', 50)
ON CONFLICT (sku) DO UPDATE SET
    stock = products.stock + EXCLUDED.stock  -- 100 + 50
RETURNING stock;

Example 2: Only overwrite with non-NULL values

-- If the new value is NULL, keep the existing value
INSERT INTO users (email, name, age)
VALUES ('test@example.com', 'New Name', NULL)
ON CONFLICT (email) DO UPDATE SET
    name = COALESCE(EXCLUDED.name, users.name),  -- 'New Name'
    age = COALESCE(EXCLUDED.age, users.age)      -- keeps the existing age
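Example 2 can be checked with the stdlib sqlite3 module. A minimal sketch, assuming a users table keyed by email and an existing row ('Old Name', 25): the NULL age leaves 25 in place while the name is replaced. A consequence of this pattern is that the upsert can never deliberately clear a column back to NULL.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (email TEXT PRIMARY KEY, name TEXT, age INTEGER)")
conn.execute("INSERT INTO users VALUES ('test@example.com', 'Old Name', 25)")

PARTIAL_UPSERT = """
    INSERT INTO users (email, name, age)
    VALUES (?, ?, ?)
    ON CONFLICT(email) DO UPDATE SET
        name = COALESCE(excluded.name, users.name),
        age  = COALESCE(excluded.age,  users.age)
"""
# age is None (NULL): the stored 25 is kept, the name is replaced
conn.execute(PARTIAL_UPSERT, ("test@example.com", "New Name", None))
print(conn.execute("SELECT name, age FROM users").fetchone())  # ('New Name', 25)
```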

Example 3: Refreshing a timestamp

-- Refresh updated_at whenever the row is updated
INSERT INTO users (email, name)
VALUES ('test@example.com', 'New Name')
ON CONFLICT (email) DO UPDATE SET
    name = EXCLUDED.name,
    updated_at = NOW()  -- PostgreSQL
    -- updated_at = CURRENT_TIMESTAMP  -- SQLite (PostgreSQL accepts this too)
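Because CURRENT_TIMESTAMP is accepted by both PostgreSQL and SQLite, one SQL string can serve both databases. A sketch with the stdlib sqlite3 module, assuming a users table whose created_at defaults to CURRENT_TIMESTAMP (SQLite stores it as UTC text 'YYYY-MM-DD HH:MM:SS'): columns left out of the SET list, like created_at, keep their stored value on conflict. The snapshot() helper is hypothetical.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE users (
        email TEXT PRIMARY KEY,
        name TEXT,
        created_at TEXT DEFAULT CURRENT_TIMESTAMP,
        updated_at TEXT
    )
    """
)

UPSERT = """
    INSERT INTO users (email, name) VALUES (?, ?)
    ON CONFLICT(email) DO UPDATE SET
        name = excluded.name,
        updated_at = CURRENT_TIMESTAMP  -- also valid in PostgreSQL
"""

def snapshot() -> tuple:
    return conn.execute("SELECT name, created_at, updated_at FROM users").fetchone()

conn.execute(UPSERT, ("test@example.com", "Old Name"))
before = snapshot()  # plain insert: updated_at is still NULL
conn.execute(UPSERT, ("test@example.com", "New Name"))
after = snapshot()   # conflict: name and updated_at change, created_at does not
print(before[2], after[0], after[1] == before[1], after[2] is not None)  # None New Name True True
```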

RETURNING

RETURNING returns the chosen columns of the rows affected by an INSERT, UPDATE, or DELETE, which saves a separate SELECT:

INSERT INTO users (email, name)
VALUES ('test@example.com', 'Zhang San')
RETURNING id, email, name, created_at;

Example 1: Get the ID right after inserting

# PostgreSQL / SQLite 3.35+
from sqlalchemy import text

sql = text("""
    INSERT INTO users (email, name)
    VALUES (:email, :name)
    RETURNING id, email, created_at
""")

# inside an async function that has an AsyncSession named `session`
result = await session.execute(sql, {"email": "test@example.com", "name": "Zhang San"})
user = result.mappings().first()
print(user["id"])  # the generated ID, no extra SELECT needed

Example 2: Return the final state after an UPSERT

-- Return the final state of the row, whether it was inserted or updated
INSERT INTO users (email, name, login_count)
VALUES ('test@example.com', 'Zhang San', 1)
ON CONFLICT (email) DO UPDATE SET
    name = EXCLUDED.name,
    login_count = users.login_count + 1  -- increment the login count
RETURNING
    id,
    email,
    name,
    login_count,
    CASE
        WHEN xmax = 0 THEN 'inserted'  -- PostgreSQL only: xmax = 0 means a fresh insert (relies on an internal system column, not a documented API)
        ELSE 'updated'
    END AS action

Example 3: Returning every row from a batch

-- Multi-row RETURNING (PostgreSQL; SQLite 3.35+ supports it as well)
INSERT INTO users (email, name)
VALUES 
    ('a@example.com', 'A'),
    ('b@example.com', 'B')
ON CONFLICT (email) DO UPDATE SET
    name = EXCLUDED.name
RETURNING id, email, name;

Processing the batch result in Python:

result = await session.execute(sql)
users = [dict(row) for row in result.mappings().all()]
# [{'id': 1, 'email': 'a@example.com', 'name': 'A'}, ...]
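A minimal sketch of a multi-row upsert with RETURNING, using the stdlib sqlite3 module (SQLite 3.35+). One row comes back per input row, whether it was updated or inserted. SQLite documents the order of RETURNING rows as arbitrary, so the sketch sorts by id instead of relying on input order.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row  # rows can be turned into dicts
conn.execute(
    "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT UNIQUE NOT NULL, name TEXT)"
)
conn.execute("INSERT INTO users (email, name) VALUES ('a@example.com', 'old A')")

cur = conn.execute(
    """
    INSERT INTO users (email, name)
    VALUES ('a@example.com', 'A'), ('b@example.com', 'B')
    ON CONFLICT(email) DO UPDATE SET name = excluded.name
    RETURNING id, email, name
    """
)
# a@example.com was updated in place (id 1), b@example.com was inserted (id 2)
users = sorted((dict(row) for row in cur.fetchall()), key=lambda u: u["id"])
print(users)
# [{'id': 1, 'email': 'a@example.com', 'name': 'A'}, {'id': 2, 'email': 'b@example.com', 'name': 'B'}]
```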

Example: user login counter

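from datetime import datetime

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

async def record_user_login(session: AsyncSession, email: str, name: str) -> dict | None:
    """
    User login counter (PostgreSQL only, because it reads xmax):
    - New user: insert with login_count = 1
    - Existing user: update with login_count += 1
    - Return the final row plus the kind of operation
    """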
    sql = text("""
        INSERT INTO users (
            email, name, login_count, last_login, created_at
        ) VALUES (
            :email, :name, 1, :now, :now
        )
        ON CONFLICT (email) DO UPDATE SET
            name = EXCLUDED.name,                          -- update the name
            login_count = users.login_count + 1,           -- increment the login count
            last_login = EXCLUDED.last_login               -- update the last login time
        RETURNING
            id,
            email,
            name,
            login_count,
            last_login,
            created_at,
            CASE 
                WHEN xmax = 0 THEN 'inserted' 
                ELSE 'updated' 
            END AS action  -- PostgreSQL only: distinguishes insert from update
    """)
    
    now = datetime.utcnow()
    result = await session.execute(
        sql,
        {"email": email, "name": name, "now": now}
    )
    
    row = result.mappings().first()
    return dict(row) if row else None

# Usage
user = await record_user_login(session, "test@example.com", "Zhang San")
print(f"{user['action']} user {user['email']} with {user['login_count']} logins")
# Output: inserted user test@example.com with 1 logins
# or:     updated user test@example.com with 5 logins
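SQLite has no xmax column, so the function above only runs on PostgreSQL. A portable alternative for this particular counter, sketched with the stdlib sqlite3 module: a freshly inserted row always has login_count = 1 and every update raises it to at least 2, so RETURNING can derive the action from the new count. This relies entirely on that invariant (and the record_login() helper is hypothetical); it is not a general insert/update detector.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE users (
        id INTEGER PRIMARY KEY,
        email TEXT UNIQUE NOT NULL,
        name TEXT,
        login_count INTEGER NOT NULL DEFAULT 0
    )
    """
)

LOGIN_SQL = """
    INSERT INTO users (email, name, login_count)
    VALUES (:email, :name, 1)
    ON CONFLICT(email) DO UPDATE SET
        name = excluded.name,
        login_count = users.login_count + 1
    RETURNING id, email, login_count,
        CASE WHEN login_count = 1 THEN 'inserted' ELSE 'updated' END AS action
"""

def record_login(email: str, name: str) -> dict:
    cur = conn.execute(LOGIN_SQL, {"email": email, "name": name})
    columns = [d[0] for d in cur.description]
    return dict(zip(columns, cur.fetchone()))

first = record_login("test@example.com", "Zhang San")
second = record_login("test@example.com", "Zhang San")
print(first["action"], second["action"], second["login_count"])  # inserted updated 2
```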

Example data models

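from sqlalchemy import Column, DateTime, Integer, String, UniqueConstraint, func
from sqlalchemy.orm import DeclarativeBase

class Base(DeclarativeBase):
    pass

class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True, autoincrement=True)
    email = Column(String(100), nullable=False)  # unique via uq_users_email below
    name = Column(String(50))
    age = Column(Integer)
    balance = Column(Integer, default=0)
    # Columns used by the examples in this post
    login_count = Column(Integer, nullable=False, server_default="0")
    last_login = Column(DateTime)
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime)

    __table_args__ = (
        # One named unique constraint; also setting unique=True on the column
        # would create a second, redundant one.
        UniqueConstraint("email", name="uq_users_email"),
    )

class Product(Base):
    __tablename__ = "products"

    id = Column(Integer, primary_key=True)
    sku = Column(String(50), unique=True, nullable=False)  # unique SKU
    name = Column(String(100))
    stock = Column(Integer, default=0)
    price = Column(Integer)
    created_at = Column(DateTime, server_default=func.now())
    updated_at = Column(DateTime)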

ORM style

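Pay attention to where insert is imported from: on_conflict_do_update() and on_conflict_do_nothing() exist only on the dialect-specific insert() from sqlalchemy.dialects.postgresql or sqlalchemy.dialects.sqlite, not on the generic sqlalchemy.insert().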

Basic example

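from datetime import datetime

from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession

# Dialect-specific insert(): provides on_conflict_do_update() / on_conflict_do_nothing()
# and the .excluded attribute. The generic sqlalchemy.insert() has none of these.
from sqlalchemy.dialects.postgresql import insert
# For SQLite, import this one instead (same on_conflict_* methods; no constraint= argument):
# from sqlalchemy.dialects.sqlite import insert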

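async def upsert_user_orm(session: AsyncSession, user_data: dict) -> dict:
    """
    UPSERT a user (ORM style):
    update the row if the email already exists, insert it otherwise
    """

    # insert() must be the dialect-specific one imported above; the generic
    # sqlalchemy.insert() has no on_conflict_do_update() method.
    stmt = (
        insert(User)
        .values(**user_data)
        .on_conflict_do_update(
            index_elements=["email"],  # conflict target (must be unique)
            set_={
                "name": user_data["name"],
                "age": user_data.get("age"),
                "updated_at": func.now(),  # assumes an updated_at column
            },
        )
        .returning(User)  # the inserted or updated row
    )

    # populate_existing refreshes a User object already loaded in this session
    result = await session.execute(stmt, execution_options={"populate_existing": True})
    user = result.scalar_one()

    return {
        "id": user.id,
        "email": user.email,
        "name": user.name,
        "age": user.age,
    }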

async def upsert_user_ignore(session: AsyncSession, user_data: dict) -> bool:
    """
    UPSERT that ignores conflicts (DO NOTHING)
    """
    stmt = (
        insert(User)
        .values(**user_data)
        .on_conflict_do_nothing(
            index_elements=["email"]
        )
    )
    
    result = await session.execute(stmt)
    return result.rowcount > 0  # True if a row was inserted, False on conflict
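The rowcount check behaves the same way in the stdlib sqlite3 module, where a DO NOTHING conflict reports zero changed rows. A minimal sketch (create_if_not_exists() here is a hypothetical standalone function, not the service method):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (email TEXT PRIMARY KEY, name TEXT)")

def create_if_not_exists(email: str, name: str) -> bool:
    cur = conn.execute(
        "INSERT INTO users (email, name) VALUES (?, ?) ON CONFLICT(email) DO NOTHING",
        (email, name),
    )
    return cur.rowcount > 0  # 1 = inserted, 0 = conflict, nothing written

print(create_if_not_exists("test@example.com", "first"))   # True
print(create_if_not_exists("test@example.com", "second"))  # False
print(conn.execute("SELECT name FROM users").fetchone()[0])  # first
```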

Conditional update: only overwrite non-NULL fields

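async def upsert_user_conditional(session: AsyncSession, user_data: dict) -> User | None:
    """
    UPSERT: on conflict, only overwrite fields whose new value is not NULL.
    Returns None when the conflict row was left untouched (WHERE was false).
    """
    stmt = insert(User).values(**user_data)
    stmt = stmt.on_conflict_do_update(
        index_elements=["email"],
        set_={
            # NULL in the new row -> keep the stored value
            "name": func.coalesce(stmt.excluded.name, User.name),
            "age": func.coalesce(stmt.excluded.age, User.age),
        },
        # Optional WHERE: skip the update when no non-NULL value was sent.
        # When it is false, nothing is updated and RETURNING yields no row.
        where=stmt.excluded.name.is_not(None) | stmt.excluded.age.is_not(None),
    ).returning(User)

    result = await session.execute(stmt, execution_options={"populate_existing": True})
    return result.scalar_one_or_none()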

Bulk UPSERT

async def bulk_upsert_users(session: AsyncSession, users: list[dict]) -> int:
    """
    Bulk UPSERT users in a single multi-row INSERT
    """
    stmt = insert(User).values(users)
    stmt = stmt.on_conflict_do_update(
        index_elements=["email"],
        set_={
            "name": stmt.excluded.name,  # excluded = the proposed row
            "age": stmt.excluded.age,
        },
    )

    result = await session.execute(stmt)
    return result.rowcount
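One PostgreSQL caveat for a single multi-row upsert: if the same email appears twice in users, the statement fails with "ON CONFLICT DO UPDATE command cannot affect row a second time". A minimal sketch of a hypothetical pre-processing helper that keeps only the last dict per key, which matches "last write wins" for overwrite-style SET clauses (for accumulating SET clauses such as stock + excluded.stock you would sum the duplicates instead):

```python
from typing import Any, Iterable

def dedupe_by_key(rows: Iterable[dict[str, Any]], key: str) -> list[dict[str, Any]]:
    """Keep the last row for each value of `key`, in first-seen key order."""
    latest: dict[Any, dict[str, Any]] = {}
    for row in rows:
        # Reassigning an existing dict key keeps its original position
        latest[row[key]] = row
    return list(latest.values())

users = [
    {"email": "a@example.com", "name": "A", "age": 20},
    {"email": "b@example.com", "name": "B", "age": 30},
    {"email": "a@example.com", "name": "A2", "age": 21},
]
deduped = dedupe_by_key(users, "email")
print(deduped)
# [{'email': 'a@example.com', 'name': 'A2', 'age': 21}, {'email': 'b@example.com', 'name': 'B', 'age': 30}]
```

Usage: await bulk_upsert_users(session, dedupe_by_key(users, "email")).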

Referencing the new values with EXCLUDED

async def upsert_product_with_stock(session: AsyncSession, product_data: dict) -> Product:
    """
    UPSERT a product: on conflict, add to the existing stock
    """
    stmt = insert(Product).values(**product_data)
    stmt = stmt.on_conflict_do_update(
        index_elements=["sku"],
        set_={
            # existing stock + incoming stock
            "stock": Product.stock + stmt.excluded.stock,
            # overwrite the other fields with the new values
            "name": stmt.excluded.name,
            "price": stmt.excluded.price,
        },
    ).returning(Product)

    result = await session.execute(stmt, execution_options={"populate_existing": True})
    return result.scalar_one()

User service

class UserService:
    """User service (with UPSERT support)"""

    def __init__(self, session: AsyncSession):
        self.session = session

    async def create_or_update(self, email: str, name: str, age: int | None = None) -> dict:
        """Create the user, or update it if the email already exists"""
        stmt = (
            insert(User)
            .values(
                email=email,
                name=name,
                age=age,
                created_at=datetime.utcnow()
            )
            .on_conflict_do_update(
                index_elements=["email"],
                set_={
                    "name": name,
                    "age": age,
                    "updated_at": datetime.utcnow()
                }
            )
            .returning(User)
        )
        
        result = await self.session.execute(stmt)
        user = result.scalar_one()
        
        return {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "age": user.age
        }
    
    async def bulk_create_or_update(self, users: list[dict]) -> int:
        """Create or update many users in one statement"""
        stmt = insert(User).values(users)
        stmt = stmt.on_conflict_do_update(
            index_elements=["email"],
            set_={
                "name": stmt.excluded.name,
                "age": stmt.excluded.age,
                "updated_at": datetime.utcnow(),
            },
        )
        
        result = await self.session.execute(stmt)
        return result.rowcount
    
    async def create_if_not_exists(self, email: str, name: str) -> bool:
        """僅當不存在時創建"""
        stmt = (
            insert(User)
            .values(
                email=email,
                name=name,
                created_at=datetime.utcnow()
            )
            .on_conflict_do_nothing(
                index_elements=["email"]
            )
        )
        
        result = await self.session.execute(stmt)
        return result.rowcount > 0  # True = inserted, False = already existed

Raw SQL

Basic example

PostgreSQL

async def upsert_user_pg(session: AsyncSession, user_data: dict) -> dict | None:
    """
    Raw-SQL UPSERT for PostgreSQL
    """
    sql = text("""
        INSERT INTO users (email, name, age, created_at)
        VALUES (:email, :name, :age, :created_at)
        ON CONFLICT (email) DO UPDATE  -- conflict target
        SET
            name = EXCLUDED.name,      -- EXCLUDED holds the proposed values
            age = EXCLUDED.age,
            updated_at = NOW()
        RETURNING id, email, name, age
    """)
    
    result = await session.execute(
        sql,
        {
            "email": user_data["email"],
            "name": user_data["name"],
            "age": user_data.get("age"),
            "created_at": datetime.utcnow()
        }
    )
    
    row = result.mappings().first()
    return dict(row) if row else None

SQLite

async def upsert_user_sqlite(session: AsyncSession, user_data: dict) -> dict | None:
    """
    Raw-SQL UPSERT for SQLite (same ON CONFLICT clause as PostgreSQL; only NOW() becomes CURRENT_TIMESTAMP)
    """
    sql = text("""
        INSERT INTO users (email, name, age, created_at)
        VALUES (:email, :name, :age, :created_at)
        ON CONFLICT(email) DO UPDATE SET  -- same clause as PostgreSQL
            name = excluded.name,
            age = excluded.age,
            updated_at = CURRENT_TIMESTAMP
        RETURNING id, email, name, age
    """)
    
    result = await session.execute(
        sql,
        {
            "email": user_data["email"],
            "name": user_data["name"],
            "age": user_data.get("age"),
            "created_at": datetime.utcnow()
        }
    )
    
    row = result.mappings().first()
    return dict(row) if row else None

Ignore on conflict

async def insert_or_ignore_user(session: AsyncSession, user_data: dict) -> bool:
    """
    Insert a user; do nothing if the email already exists
    """
    # PostgreSQL
    sql = text("""
        INSERT INTO users (email, name, age, created_at)
        VALUES (:email, :name, :age, :created_at)
        ON CONFLICT (email) DO NOTHING
    """)
    
    # SQLite (same syntax)
    # sql = text("""
    #     INSERT INTO users (email, name, age, created_at)
    #     VALUES (:email, :name, :age, :created_at)
    #     ON CONFLICT(email) DO NOTHING
    # """)
    
    result = await session.execute(
        sql,
        {
            "email": user_data["email"],
            "name": user_data["name"],
            "age": user_data.get("age"),
            "created_at": datetime.utcnow()
        }
    )
    
    return result.rowcount > 0  # True if a row was inserted

Bulk UPSERT

async def bulk_upsert_products(session: AsyncSession, products: list[dict]) -> int:
    """
    批量 UPSERT 產品(原生 SQL)
    """
    # PostgreSQL
    sql = text("""
        INSERT INTO products (sku, name, stock, price, created_at)
        VALUES (
            :sku, :name, :stock, :price, :created_at
        )
        ON CONFLICT (sku) DO UPDATE SET
            name = EXCLUDED.name,
            stock = products.stock + EXCLUDED.stock,  -- 累加庫存
            price = EXCLUDED.price,
            updated_at = NOW()
    """)
    
    # 批量執行
    for product in products:
        await session.execute(
            sql,
            {
                "sku": product["sku"],
                "name": product["name"],
                "stock": product.get("stock", 0),
                "price": product.get("price", 0),
                "created_at": datetime.utcnow()
            }
        )
    
    return len(products)
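What executemany means for duplicates, sketched with the stdlib sqlite3 module (the products schema is trimmed to sku, name and stock as an assumption): each parameter dict runs as its own statement, so a SKU that appears twice in the batch is simply accumulated. A single multi-row VALUES upsert behaves differently, because PostgreSQL rejects duplicate conflict keys within one statement.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE products (sku TEXT PRIMARY KEY, name TEXT, stock INTEGER NOT NULL)")

ADD_STOCK = """
    INSERT INTO products (sku, name, stock)
    VALUES (:sku, :name, :stock)
    ON CONFLICT(sku) DO UPDATE SET
        name = excluded.name,
        stock = products.stock + excluded.stock
"""

batch = [
    {"sku": "IPHONE15", "name": "iPhone 15", "stock": 100},
    {"sku": "IPHONE15", "name": "iPhone 15", "stock": 50},  # same SKU again
    {"sku": "PIXEL8", "name": "Pixel 8", "stock": 10},
]
conn.executemany(ADD_STOCK, batch)  # one statement execution per dict
print(conn.execute("SELECT sku, stock FROM products ORDER BY sku").fetchall())
# [('IPHONE15', 150), ('PIXEL8', 10)]
```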

Partial update + conditions

async def upsert_user_smart(session: AsyncSession, user_data: dict) -> dict | None:
    """
    智能 UPSERT:
    - 如果提供了 age,才更新 age
    - 如果提供了 name,才更新 name
    - 更新 updated_at
    """
    sql = text("""
        INSERT INTO users (email, name, age, created_at)
        VALUES (:email, :name, :age, :created_at)
        ON CONFLICT (email) DO UPDATE SET
            name = COALESCE(:name, users.name),  -- NULL new value: keep the existing one
            age = COALESCE(:age, users.age),
            updated_at = NOW()
        RETURNING id, email, name, age, updated_at
    """)
    
    result = await session.execute(
        sql,
        {
            "email": user_data["email"],
            "name": user_data.get("name"),  # 可能為 None
            "age": user_data.get("age"),    # 可能為 None
            "created_at": datetime.utcnow()
        }
    )
    
    row = result.mappings().first()
    return dict(row) if row else None

User registration/login: if the user exists, update the last login time

async def register_or_login(session: AsyncSession, email: str, name: str) -> dict:
    """
    User registration or login:
    - New user: insert
    - Existing user: update the last login time
    """
    sql = text("""
        INSERT INTO users (email, name, last_login, created_at)
        VALUES (:email, :name, :now, :now)
        ON CONFLICT (email) DO UPDATE SET
            last_login = EXCLUDED.last_login,
            name = EXCLUDED.name  -- optional: also update the name
        RETURNING id, email, name, last_login, created_at
    """)
    
    now = datetime.utcnow()
    result = await session.execute(
        sql,
        {"email": email, "name": name, "now": now}
    )
    
    return dict(result.mappings().first())

Accumulating stock

async def add_product_stock(session: AsyncSession, sku: str, quantity: int) -> bool:
    """
    Add stock to a product:
    - Product does not exist: insert it
    - Product exists: add to its stock
    """
    sql = text("""
        INSERT INTO products (sku, stock, created_at)
        VALUES (:sku, :quantity, :now)
        ON CONFLICT (sku) DO UPDATE SET
            stock = products.stock + EXCLUDED.stock,
            updated_at = NOW()
    """)
    
    result = await session.execute(
        sql,
        {
            "sku": sku,
            "quantity": quantity,
            "now": datetime.utcnow()
        }
    )
    
    return result.rowcount > 0  # always True here: DO UPDATE writes a row either way

Accumulating user points

async def add_user_points(session: AsyncSession, user_id: int, points: int) -> dict | None:
    """
    Add points to a user (accumulating); user_points.user_id needs a unique constraint
    """
    sql = text("""
        INSERT INTO user_points (user_id, points, created_at)
        VALUES (:user_id, :points, :now)
        ON CONFLICT (user_id) DO UPDATE SET
            points = user_points.points + EXCLUDED.points,
            updated_at = NOW()
        RETURNING user_id, points
    """)
    
    result = await session.execute(
        sql,
        {
            "user_id": user_id,
            "points": points,
            "now": datetime.utcnow()
        }
    )
    
    row = result.mappings().first()
    return dict(row) if row else None

Tag counter

Increment the count if the tag exists, create it otherwise:

async def increment_tag_count(session: AsyncSession, tag_name: str) -> int:
    """
    Tag counter:
    - Tag does not exist: insert with count = 1
    - Tag exists: count += 1
    """
    sql = text("""
        INSERT INTO tags (name, count, created_at)
        VALUES (:name, 1, :now)
        ON CONFLICT (name) DO UPDATE SET
            count = tags.count + 1,
            updated_at = NOW()
        RETURNING count
    """)
    
    result = await session.execute(
        sql,
        {"name": tag_name, "now": datetime.utcnow()}
    )
    
    return result.scalar() or 0