找回密码
 立即注册
首页 业界区 业界 SQLAlchemy中使用UPSERT

SQLAlchemy中使用UPSERT

侧胥咽 3 小时前
前言

SQLite 和 PostgreSQL 都支持 UPSERT 操作,即"有则更新,无则新增"。冲突列必须有唯一约束。
语法:

  • PostgreSQL: INSERT ... ON CONFLICT (column) DO UPDATE/NOTHING
  • SQLite: INSERT ... ON CONFLICT(column) DO UPDATE/NOTHING。注意括号位置
场景PostgreSQLSQLite说明基本 UPSERTON CONFLICT (col) DO UPDATE SET ...ON CONFLICT(col) DO UPDATE SET ...括号位置略有不同冲突忽略ON CONFLICT (col) DO NOTHINGON CONFLICT(col) DO NOTHING相同引用新值EXCLUDED.colexcluded.colPostgreSQL 大写,SQLite 小写返回结果RETURNING *RETURNING *相同条件更新WHERE condition不支持 WHERESQLite 限制注意事项


  • 冲突列必须有唯一约束
  • PostgreSQL 和 SQLite 的语法相似,但仍有细微差别。使用原生 SQL 时需要注意。
  • SQLite 在 UPSERT 时不支持 WHERE 子句,需要改用 CASE 表达式或应用层过滤。
  • SQLite 3.35+ 版本才支持 RETURNING
EXCLUDED 和 RETURNING

EXCLUDED

EXCLUDED 表示冲突时被拦截的新值。
  1. INSERT INTO users (email, name, age)
  2. VALUES ('test@example.com', '新名字', 30)
  3. ON CONFLICT (email) DO UPDATE SET
  4.     name = EXCLUDED.name,   -- ← 引用新值 "新名字"
  5.     age = EXCLUDED.age      -- ← 引用新值 30
复制代码
场景表达式含义示例值原表字段users.name冲突行的当前值"老名字"新值字段EXCLUDED.name试图插入的新值"新名字"混合计算users.age + EXCLUDED.age原值 + 新值25 + 30 = 55示例 1:累加库存
  1. -- 商品库存累加:原库存 100 + 新增 50 = 150
  2. INSERT INTO products (sku, stock)
  3. VALUES ('IPHONE15', 50)
  4. ON CONFLICT (sku) DO UPDATE SET
  5.     stock = products.stock + EXCLUDED.stock  -- 100 + 50
  6. RETURNING stock;
复制代码
示例 2:仅更新非空字段
  1. -- 如果新值为 NULL,保留原值
  2. INSERT INTO users (email, name, age)
  3. VALUES ('test@example.com', '新名字', NULL)
  4. ON CONFLICT (email) DO UPDATE SET
  5.     name = COALESCE(EXCLUDED.name, users.name),  -- 新名字
  6.     age = COALESCE(EXCLUDED.age, users.age)      -- 保留原 age
复制代码
示例 3:时间戳更新
  1. -- 更新时刷新 updated_at
  2. INSERT INTO users (email, name)
  3. VALUES ('test@example.com', '新名字')
  4. ON CONFLICT (email) DO UPDATE SET
  5.     name = EXCLUDED.name,
  6.     updated_at = NOW()  -- PostgreSQL
  7.     -- updated_at = CURRENT_TIMESTAMP  -- SQLite
复制代码
RETURNING

RETURNING 用于返回操作结果。在 INSERT/UPDATE/DELETE 后直接返回指定列,避免额外 SELECT 查询:
  1. INSERT INTO users (email, name)
  2. VALUES ('test@example.com', '张三')
  3. RETURNING id, email, name, created_at;
复制代码
示例 1:插入后立即获取 ID
  1. # PostgreSQL / SQLite 3.35+
  2. sql = text("""
  3.     INSERT INTO users (email, name)
  4.     VALUES (:email, :name)
  5.     RETURNING id, email, created_at
  6. """)
  7. result = await session.execute(sql, {"email": "test@example.com", "name": "张三"})
  8. user = result.mappings().first()
  9. print(user["id"])  # 直接获取 ID
复制代码
示例 2:UPSERT 后统一返回
  1. -- 无论插入还是更新,都返回最终状态
  2. INSERT INTO users (email, name, login_count)
  3. VALUES ('test@example.com', '张三', 1)
  4. ON CONFLICT (email) DO UPDATE SET
  5.     name = EXCLUDED.name,
  6.     login_count = users.login_count + 1  -- 累加登录次数
  7. RETURNING
  8.     id,
  9.     email,
  10.     name,
  11.     login_count,
  12.     CASE
  13.         WHEN xmax = 0 THEN 'inserted'  -- PostgreSQL 特有:xmax=0 表示插入
  14.         ELSE 'updated'
  15.     END AS action
复制代码
示例 3:批量操作返回所有结果
  1. -- PostgreSQL 支持批量 RETURNING
  2. INSERT INTO users (email, name)
  3. VALUES
  4.     ('a@example.com', 'A'),
  5.     ('b@example.com', 'B')
  6. ON CONFLICT (email) DO UPDATE SET
  7.     name = EXCLUDED.name
  8. RETURNING id, email, name;
复制代码
Python 处理批量返回:
  1. result = await session.execute(sql)
  2. users = [dict(row) for row in result.mappings().all()]
  3. # [{'id': 1, 'email': 'a@example.com', 'name': 'A'}, ...]
复制代码
示例:用户登录计数器
  1. async def record_user_login(session: AsyncSession, email: str, name: str) -> dict:
  2.     """
  3.     用户登录计数器:
  4.     - 新用户:插入,login_count = 1
  5.     - 老用户:更新,login_count += 1
  6.     - 返回最终状态 + 操作类型
  7.     """
  8.     sql = text("""
  9.         INSERT INTO users (
  10.             email, name, login_count, last_login, created_at
  11.         ) VALUES (
  12.             :email, :name, 1, :now, :now
  13.         )
  14.         ON CONFLICT (email) DO UPDATE SET
  15.             name = EXCLUDED.name,                          -- 更新用户名
  16.             login_count = users.login_count + 1,           -- 累加登录次数
  17.             last_login = EXCLUDED.last_login               -- 更新最后登录时间
  18.         RETURNING
  19.             id,
  20.             email,
  21.             name,
  22.             login_count,
  23.             last_login,
  24.             created_at,
  25.             CASE
  26.                 WHEN xmax = 0 THEN 'inserted'
  27.                 ELSE 'updated'
  28.             END AS action  -- PostgreSQL 特有:区分插入/更新
  29.     """)
  30.    
  31.     now = datetime.utcnow()
  32.     result = await session.execute(
  33.         sql,
  34.         {"email": email, "name": name, "now": now}
  35.     )
  36.    
  37.     row = result.mappings().first()
  38.     return dict(row) if row else None
  39. # 使用示例
  40. user = await record_user_login(session, "test@example.com", "张三")
  41. print(f"{user['action']} user {user['email']} with {user['login_count']} logins")
  42. # 输出: inserted user test@example.com with 1 logins
  43. # 或: updated user test@example.com with 5 logins
复制代码
示例数据模型类
  1. from sqlalchemy import Column, Integer, String, UniqueConstraint
  2. from sqlalchemy.orm import DeclarativeBase
  3. class Base(DeclarativeBase):
  4.     pass
  5. class User(Base):
  6.     __tablename__ = "users"
  7.    
  8.     id = Column(Integer, primary_key=True, autoincrement=True)
  9.     email = Column(String(100), unique=True, nullable=False)  # 唯一约束
  10.     name = Column(String(50))
  11.     age = Column(Integer)
  12.     balance = Column(Integer, default=0)
  13.    
  14.     __table_args__ = (
  15.         UniqueConstraint("email", name="uq_users_email"),
  16.     )
  17. class Product(Base):
  18.     __tablename__ = "products"
  19.    
  20.     id = Column(Integer, primary_key=True)
  21.     sku = Column(String(50), unique=True, nullable=False)  # 唯一 SKU
  22.     name = Column(String(100))
  23.     stock = Column(Integer, default=0)
  24.     price = Column(Integer)
复制代码
ORM 方式

注意 insert 的导入路径。
基本示例
  1. from sqlalchemy.dialects.postgresql import insert as pg_insert
  2. from sqlalchemy.dialects.sqlite import insert as sqlite_insert
  3. from sqlalchemy import insert
  4. async def upsert_user_orm(session: AsyncSession, user_data: dict) -> dict:
  5.     """
  6.     UPSERT 用户(ORM 风格)
  7.     如果 email 冲突则更新,否则插入
  8.     """
  9.    
  10.     # 方式 1:使用通用 insert(推荐⭐)
  11.     # SQLAlchemy 会根据方言自动选择正确的语法
  12.     stmt = (
  13.         insert(User)
  14.         .values(**user_data)
  15.         .on_conflict_do_update(
  16.             index_elements=["email"],  # 冲突检测列(唯一约束)
  17.             set_={
  18.                 "name": user_data["name"],
  19.                 "age": user_data.get("age"),
  20.                 "updated_at": func.now()  # 假设有 updated_at 列
  21.             }
  22.         )
  23.         .returning(User)  # 返回插入/更新后的行
  24.     )
  25.    
  26.     result = await session.execute(stmt)
  27.     user = result.scalar_one()
  28.    
  29.     return {
  30.         "id": user.id,
  31.         "email": user.email,
  32.         "name": user.name,
  33.         "age": user.age
  34.     }
  35. async def upsert_user_ignore(session: AsyncSession, user_data: dict) -> bool:
  36.     """
  37.     UPSERT 但冲突时忽略(DO NOTHING)
  38.     """
  39.     stmt = (
  40.         insert(User)
  41.         .values(**user_data)
  42.         .on_conflict_do_nothing(
  43.             index_elements=["email"]
  44.         )
  45.     )
  46.    
  47.     result = await session.execute(stmt)
  48.     return result.rowcount > 0  # 返回是否插入成功
复制代码
条件更新:仅更新特定字段
  1. async def upsert_user_conditional(session: AsyncSession, user_data: dict) -> dict:
  2.     """
  3.     UPSERT:冲突时只更新非空字段
  4.     """
  5.     stmt = (
  6.         insert(User)
  7.         .values(**user_data)
  8.         .on_conflict_do_update(
  9.             index_elements=["email"],
  10.             set_={
  11.                 "name": user_data["name"],
  12.                 # 条件:只有提供了 age 才更新
  13.                 "age": user_data.get("age", User.age),  # 保持原值
  14.             },
  15.             # 可选:添加 WHERE 条件
  16.             where=User.email == user_data["email"]
  17.         )
  18.         .returning(User)
  19.     )
  20.    
  21.     result = await session.execute(stmt)
  22.     return result.mappings().first()
复制代码
批量 UPSERT
  1. async def bulk_upsert_users(session: AsyncSession, users: list[dict]) -> int:
  2.     """
  3.     批量 UPSERT 用户
  4.     """
  5.     stmt = (
  6.         insert(User)
  7.         .values(users)
  8.         .on_conflict_do_update(
  9.             index_elements=["email"],
  10.             set_={
  11.                 "name": insert(User).excluded.name,  # 使用 excluded 表示新值
  12.                 "age": insert(User).excluded.age,
  13.             }
  14.         )
  15.     )
  16.    
  17.     result = await session.execute(stmt)
  18.     return result.rowcount
复制代码
使用 EXCLUDED 引用新值
  1. async def upsert_product_with_stock(session: AsyncSession, product_data: dict) -> dict:
  2.     """
  3.     UPSERT 产品:冲突时累加库存
  4.     """
  5.     stmt = (
  6.         insert(Product)
  7.         .values(**product_data)
  8.         .on_conflict_do_update(
  9.             index_elements=["sku"],
  10.             set_={
  11.                 # 累加库存:原库存 + 新库存
  12.                 "stock": Product.stock + insert(Product).excluded.stock,
  13.                 # 更新其他字段
  14.                 "name": insert(Product).excluded.name,
  15.                 "price": insert(Product).excluded.price,
  16.             }
  17.         )
  18.         .returning(Product)
  19.     )
  20.    
  21.     result = await session.execute(stmt)
  22.     return result.mappings().first()
复制代码
用户服务
  1. class UserService:
  2.     """用户服务(支持 UPSERT)"""
  3.    
  4.     def __init__(self, session: AsyncSession):
  5.         self.session = session
  6.    
  7.     async def create_or_update(self, email: str, name: str, age: int | None = None) -> dict:
  8.         """创建或更新用户"""
  9.         stmt = (
  10.             insert(User)
  11.             .values(
  12.                 email=email,
  13.                 name=name,
  14.                 age=age,
  15.                 created_at=datetime.utcnow()
  16.             )
  17.             .on_conflict_do_update(
  18.                 index_elements=["email"],
  19.                 set_={
  20.                     "name": name,
  21.                     "age": age,
  22.                     "updated_at": datetime.utcnow()
  23.                 }
  24.             )
  25.             .returning(User)
  26.         )
  27.         
  28.         result = await self.session.execute(stmt)
  29.         user = result.scalar_one()
  30.         
  31.         return {
  32.             "id": user.id,
  33.             "email": user.email,
  34.             "name": user.name,
  35.             "age": user.age
  36.         }
  37.    
  38.     async def bulk_create_or_update(self, users: list[dict]) -> int:
  39.         """批量创建或更新"""
  40.         stmt = (
  41.             insert(User)
  42.             .values(users)
  43.             .on_conflict_do_update(
  44.                 index_elements=["email"],
  45.                 set_={
  46.                     "name": insert(User).excluded.name,
  47.                     "age": insert(User).excluded.age,
  48.                     "updated_at": datetime.utcnow()
  49.                 }
  50.             )
  51.         )
  52.         
  53.         result = await self.session.execute(stmt)
  54.         return result.rowcount
  55.    
  56.     async def create_if_not_exists(self, email: str, name: str) -> bool:
  57.         """仅当不存在时创建"""
  58.         stmt = (
  59.             insert(User)
  60.             .values(
  61.                 email=email,
  62.                 name=name,
  63.                 created_at=datetime.utcnow()
  64.             )
  65.             .on_conflict_do_nothing(
  66.                 index_elements=["email"]
  67.             )
  68.         )
  69.         
  70.         result = await self.session.execute(stmt)
  71.         return result.rowcount > 0  # True = 插入成功,False = 已存在
复制代码
原生 SQL

基本示例

PostgreSQL
  1. async def upsert_user_pg(session: AsyncSession, user_data: dict) -> dict | None:
  2.     """
  3.     PostgreSQL 原生 UPSERT
  4.     """
  5.     sql = text("""
  6.         INSERT INTO users (email, name, age, created_at)
  7.         VALUES (:email, :name, :age, :created_at)
  8.         ON CONFLICT (email) DO UPDATE  -- 冲突列
  9.         SET
  10.             name = EXCLUDED.name,      -- EXCLUDED 表示新插入的值
  11.             age = EXCLUDED.age,
  12.             updated_at = NOW()
  13.         RETURNING id, email, name, age
  14.     """)
  15.    
  16.     result = await session.execute(
  17.         sql,
  18.         {
  19.             "email": user_data["email"],
  20.             "name": user_data["name"],
  21.             "age": user_data.get("age"),
  22.             "created_at": datetime.utcnow()
  23.         }
  24.     )
  25.    
  26.     row = result.mappings().first()
  27.     return dict(row) if row else None
复制代码
SQLite
  1. async def upsert_user_sqlite(session: AsyncSession, user_data: dict) -> dict | None:
  2.     """
  3.     SQLite 原生 UPSERT(语法与 PostgreSQL 几乎相同)
  4.     """
  5.     sql = text("""
  6.         INSERT INTO users (email, name, age, created_at)
  7.         VALUES (:email, :name, :age, :created_at)
  8.         ON CONFLICT(email) DO UPDATE SET  -- SQLite 语法稍有不同
  9.             name = excluded.name,
  10.             age = excluded.age,
  11.             updated_at = CURRENT_TIMESTAMP
  12.         RETURNING id, email, name, age
  13.     """)
  14.    
  15.     result = await session.execute(
  16.         sql,
  17.         {
  18.             "email": user_data["email"],
  19.             "name": user_data["name"],
  20.             "age": user_data.get("age"),
  21.             "created_at": datetime.utcnow()
  22.         }
  23.     )
  24.    
  25.     row = result.mappings().first()
  26.     return dict(row) if row else None
复制代码
冲突时忽略
  1. async def insert_or_ignore_user(session: AsyncSession, user_data: dict) -> bool:
  2.     """
  3.     插入用户,如果冲突则忽略
  4.     """
  5.     # PostgreSQL
  6.     sql = text("""
  7.         INSERT INTO users (email, name, age, created_at)
  8.         VALUES (:email, :name, :age, :created_at)
  9.         ON CONFLICT (email) DO NOTHING
  10.     """)
  11.    
  12.     # SQLite(语法相同)
  13.     # sql = text("""
  14.     #     INSERT INTO users (email, name, age, created_at)
  15.     #     VALUES (:email, :name, :age, :created_at)
  16.     #     ON CONFLICT(email) DO NOTHING
  17.     # """)
  18.    
  19.     result = await session.execute(
  20.         sql,
  21.         {
  22.             "email": user_data["email"],
  23.             "name": user_data["name"],
  24.             "age": user_data.get("age"),
  25.             "created_at": datetime.utcnow()
  26.         }
  27.     )
  28.    
  29.     return result.rowcount > 0  # 返回是否插入成功
复制代码
批量 UPSERT
  1. async def bulk_upsert_products(session: AsyncSession, products: list[dict]) -> int:
  2.     """
  3.     批量 UPSERT 产品(原生 SQL)
  4.     """
  5.     # PostgreSQL
  6.     sql = text("""
  7.         INSERT INTO products (sku, name, stock, price, created_at)
  8.         VALUES (
  9.             :sku, :name, :stock, :price, :created_at
  10.         )
  11.         ON CONFLICT (sku) DO UPDATE SET
  12.             name = EXCLUDED.name,
  13.             stock = products.stock + EXCLUDED.stock,  -- 累加库存
  14.             price = EXCLUDED.price,
  15.             updated_at = NOW()
  16.     """)
  17.    
  18.     # 批量执行
  19.     for product in products:
  20.         await session.execute(
  21.             sql,
  22.             {
  23.                 "sku": product["sku"],
  24.                 "name": product["name"],
  25.                 "stock": product.get("stock", 0),
  26.                 "price": product.get("price", 0),
  27.                 "created_at": datetime.utcnow()
  28.             }
  29.         )
  30.    
  31.     return len(products)
复制代码
部分更新 + 条件判断
  1. async def upsert_user_smart(session: AsyncSession, user_data: dict) -> dict | None:
  2.     """
  3.     智能 UPSERT:
  4.     - 如果提供了 age,才更新 age
  5.     - 如果提供了 name,才更新 name
  6.     - 更新 updated_at
  7.     """
  8.     sql = text("""
  9.         INSERT INTO users (email, name, age, created_at)
  10.         VALUES (:email, :name, :age, :created_at)
  11.         ON CONFLICT (email) DO UPDATE SET
  12.             name = COALESCE(:name, users.name),  -- 如果新值为 NULL,保持原值
  13.             age = COALESCE(:age, users.age),
  14.             updated_at = NOW()
  15.         RETURNING id, email, name, age, updated_at
  16.     """)
  17.    
  18.     result = await session.execute(
  19.         sql,
  20.         {
  21.             "email": user_data["email"],
  22.             "name": user_data.get("name"),  # 可能为 None
  23.             "age": user_data.get("age"),    # 可能为 None
  24.             "created_at": datetime.utcnow()
  25.         }
  26.     )
  27.    
  28.     row = result.mappings().first()
  29.     return dict(row) if row else None
复制代码
用户注册/登录:存在则更新最后登录时间
  1. async def register_or_login(session: AsyncSession, email: str, name: str) -> dict:
  2.     """
  3.     用户注册或登录:
  4.     - 新用户:插入
  5.     - 老用户:更新最后登录时间
  6.     """
  7.     sql = text("""
  8.         INSERT INTO users (email, name, last_login, created_at)
  9.         VALUES (:email, :name, :now, :now)
  10.         ON CONFLICT (email) DO UPDATE SET
  11.             last_login = EXCLUDED.last_login,
  12.             name = EXCLUDED.name  -- 可选:更新用户名
  13.         RETURNING id, email, name, last_login, created_at
  14.     """)
  15.    
  16.     now = datetime.utcnow()
  17.     result = await session.execute(
  18.         sql,
  19.         {"email": email, "name": name, "now": now}
  20.     )
  21.    
  22.     return dict(result.mappings().first())
复制代码
库存累加
  1. async def add_product_stock(session: AsyncSession, sku: str, quantity: int) -> bool:
  2.     """
  3.     增加商品库存:
  4.     - 商品不存在:插入
  5.     - 商品存在:累加库存
  6.     """
  7.     sql = text("""
  8.         INSERT INTO products (sku, stock, created_at)
  9.         VALUES (:sku, :quantity, :now)
  10.         ON CONFLICT (sku) DO UPDATE SET
  11.             stock = products.stock + EXCLUDED.stock,
  12.             updated_at = NOW()
  13.     """)
  14.    
  15.     result = await session.execute(
  16.         sql,
  17.         {
  18.             "sku": sku,
  19.             "quantity": quantity,
  20.             "now": datetime.utcnow()
  21.         }
  22.     )
  23.    
  24.     return result.rowcount > 0
复制代码
用户积分累加
  1. async def add_user_points(session: AsyncSession, user_id: int, points: int) -> dict | None:
  2.     """
  3.     增加用户积分(累加)
  4.     """
  5.     sql = text("""
  6.         INSERT INTO user_points (user_id, points, created_at)
  7.         VALUES (:user_id, :points, :now)
  8.         ON CONFLICT (user_id) DO UPDATE SET
  9.             points = user_points.points + EXCLUDED.points,
  10.             updated_at = NOW()
  11.         RETURNING user_id, points
  12.     """)
  13.    
  14.     result = await session.execute(
  15.         sql,
  16.         {
  17.             "user_id": user_id,
  18.             "points": points,
  19.             "now": datetime.utcnow()
  20.         }
  21.     )
  22.    
  23.     row = result.mappings().first()
  24.     return dict(row) if row else None
复制代码
标签计数

存在则 +1,不存在则创建:
  1. async def increment_tag_count(session: AsyncSession, tag_name: str) -> int:
  2.     """
  3.     标签计数:
  4.     - 标签不存在:插入 count=1
  5.     - 标签存在:count += 1
  6.     """
  7.     sql = text("""
  8.         INSERT INTO tags (name, count, created_at)
  9.         VALUES (:name, 1, :now)
  10.         ON CONFLICT (name) DO UPDATE SET
  11.             count = tags.count + 1,
  12.             updated_at = NOW()
  13.         RETURNING count
  14.     """)
  15.    
  16.     result = await session.execute(
  17.         sql,
  18.         {"name": tag_name, "now": datetime.utcnow()}
  19.     )
  20.    
  21.     return result.scalar() or 0
复制代码
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

您需要登录后才可以回帖 登录 | 立即注册