SQLAlchemy ORM 框架详解
目录
- SQLAlchemy 简介
- 安装与环境准备
- 核心概念
- 数据库连接与配置
- 定义模型(ORM 映射)
- 创建表结构
- 基本 CRUD 操作
- 关系映射
- 查询操作详解
- 事务处理
- 测试数据集
- 完整示例项目
- 最佳实践
SQLAlchemy 简介
SQLAlchemy 是 Python 中最流行的 ORM(Object-Relational Mapping,对象关系映射)框架之一。它提供了:
- 强大的 ORM:将数据库表映射为 Python 类,行映射为对象实例
- SQL 表达式语言:可以编写原生 SQL 查询
- 数据库抽象:支持多种数据库(MySQL、PostgreSQL、SQLite、Oracle 等)
- 连接池管理:自动管理数据库连接
- 迁移支持:与 Alembic 集成,支持数据库版本控制
SQLAlchemy 的两种使用方式
- Core(核心层):直接使用 SQL 表达式语言,更接近原生 SQL
- ORM(对象关系映射):使用 Python 类映射数据库表,更符合面向对象编程
本文主要介绍 ORM 方式的使用。
安装与环境准备
安装 SQLAlchemy
pip install sqlalchemy
如果需要连接 MySQL,还需要安装对应的数据库驱动:
# 使用 PyMySQL(纯 Python 实现)
pip install pymysql
# 或使用 mysqlclient(需要编译)
pip install mysqlclient
安装其他可选依赖
# 用于数据库迁移
pip install alembic
# 用于连接池
pip install sqlalchemy[pool]
核心概念
在深入学习之前,需要理解 SQLAlchemy 的核心概念:
- Engine(引擎):数据库连接的核心,管理连接池
- Session(会话):数据库操作的上下文,管理事务
- Base(基类):所有模型类的基类
- Model(模型):映射到数据库表的 Python 类
- Query(查询):用于查询数据库的接口
数据库连接与配置
1. 创建数据库引擎
from sqlalchemy import create_engine
# MySQL 连接字符串格式
# mysql+pymysql://用户名:密码@主机:端口/数据库名?charset=utf8mb4
DATABASE_URL = "mysql+pymysql://testuser:testpass123@localhost:3306/testdb?charset=utf8mb4"
# 创建引擎
engine = create_engine(
DATABASE_URL,
echo=True, # 打印 SQL 语句(调试用)
pool_size=5, # 连接池大小
max_overflow=10, # 最大溢出连接数
pool_pre_ping=True, # 连接前检查连接是否有效
pool_recycle=3600, # 连接回收时间(秒)
)
# 测试连接
with engine.connect() as conn:
result = conn.execute("SELECT 1")
print("数据库连接成功!")
2. 使用配置文件
# config.py
import os
from sqlalchemy import create_engine
class DatabaseConfig:
"""数据库配置类"""
DB_HOST = os.getenv('DB_HOST', 'localhost')
DB_PORT = os.getenv('DB_PORT', '3306')
DB_USER = os.getenv('DB_USER', 'testuser')
DB_PASSWORD = os.getenv('DB_PASSWORD', 'testpass123')
DB_NAME = os.getenv('DB_NAME', 'testdb')
@classmethod
def get_database_url(cls):
"""获取数据库连接字符串"""
return (
f"mysql+pymysql://{cls.DB_USER}:{cls.DB_PASSWORD}"
f"@{cls.DB_HOST}:{cls.DB_PORT}/{cls.DB_NAME}?charset=utf8mb4"
)
@classmethod
def create_engine(cls, echo=False):
"""创建数据库引擎"""
return create_engine(
cls.get_database_url(),
echo=echo,
pool_size=5,
max_overflow=10,
pool_pre_ping=True,
pool_recycle=3600,
)
# 使用
engine = DatabaseConfig.create_engine(echo=True)
3. SQLite 连接(用于测试)
from sqlalchemy import create_engine
# SQLite 连接(不需要服务器,适合测试)
engine = create_engine(
'sqlite:///test.db', # 相对路径
# 'sqlite:////absolute/path/to/test.db', # 绝对路径
echo=True,
connect_args={'check_same_thread': False} # SQLite 特有参数
)
定义模型(ORM 映射)
1. 创建基类和会话工厂
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
# 创建基类
Base = declarative_base()
# 创建会话工厂
SessionLocal = sessionmaker(
bind=engine,
autocommit=False,
autoflush=False
)
# 创建会话的辅助函数
def get_session():
"""获取数据库会话"""
session = SessionLocal()
try:
yield session
finally:
session.close()
2. 定义简单的模型
from sqlalchemy import Column, Integer, String, DateTime, Text
from sqlalchemy.sql import func
from datetime import datetime
class User(Base):
"""用户模型"""
__tablename__ = 'users'
id = Column(Integer, primary_key=True, autoincrement=True)
username = Column(String(50), unique=True, nullable=False, index=True)
email = Column(String(100), unique=True, nullable=False, index=True)
password_hash = Column(String(255), nullable=False)
full_name = Column(String(100))
age = Column(Integer)
bio = Column(Text)
is_active = Column(Integer, default=1) # 1=激活, 0=禁用
created_at = Column(DateTime, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
def __repr__(self):
return f"<User(id={self.id}, username='{self.username}', email='{self.email}')>"
3. 定义带关系的模型
from sqlalchemy import Column, Integer, String, ForeignKey, Text, DateTime, Enum
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
import enum
class PostStatus(enum.Enum):
"""文章状态枚举"""
DRAFT = "draft"
PUBLISHED = "published"
ARCHIVED = "archived"
class Category(Base):
"""分类模型"""
__tablename__ = 'categories'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(50), unique=True, nullable=False)
description = Column(Text)
created_at = Column(DateTime, default=func.now())
# 关系:一个分类有多个文章
posts = relationship("Post", back_populates="category")
def __repr__(self):
return f"<Category(id={self.id}, name='{self.name}')>"
class Post(Base):
"""文章模型"""
__tablename__ = 'posts'
id = Column(Integer, primary_key=True, autoincrement=True)
title = Column(String(200), nullable=False, index=True)
content = Column(Text)
status = Column(Enum(PostStatus), default=PostStatus.DRAFT)
view_count = Column(Integer, default=0)
# 外键
author_id = Column(Integer, ForeignKey('users.id'), nullable=False)
category_id = Column(Integer, ForeignKey('categories.id'))
created_at = Column(DateTime, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
# 关系:多对一
author = relationship("User", back_populates="posts")
category = relationship("Category", back_populates="posts")
# 关系:多对多(通过中间表)
tags = relationship("Tag", secondary="post_tags", back_populates="posts")
def __repr__(self):
return f"<Post(id={self.id}, title='{self.title}', status='{self.status.value}')>"
class Tag(Base):
"""标签模型"""
__tablename__ = 'tags'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(50), unique=True, nullable=False)
created_at = Column(DateTime, default=func.now())
# 关系:多对多
posts = relationship("Post", secondary="post_tags", back_populates="tags")
def __repr__(self):
return f"<Tag(id={self.id}, name='{self.name}')>"
# 多对多中间表
post_tags = Table(
'post_tags',
Base.metadata,
Column('post_id', Integer, ForeignKey('posts.id'), primary_key=True),
Column('tag_id', Integer, ForeignKey('tags.id'), primary_key=True)
)
# 更新 User 模型,添加关系
User.posts = relationship("Post", back_populates="author")
创建表结构
1. 创建所有表
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
# 创建引擎
engine = create_engine("mysql+pymysql://testuser:testpass123@localhost:3306/testdb?charset=utf8mb4")
# 创建基类
Base = declarative_base()
# 导入所有模型(确保所有模型类都已定义)
# from models import User, Post, Category, Tag, post_tags
# 创建所有表
Base.metadata.create_all(engine)
print("所有表创建成功!")
2. 删除所有表(谨慎使用)
# 删除所有表
Base.metadata.drop_all(engine)
print("所有表已删除!")
3. 检查表是否存在
from sqlalchemy import inspect
inspector = inspect(engine)
tables = inspector.get_table_names()
print(f"数据库中的表: {tables}")
基本 CRUD 操作
1. 创建(Create)
from sqlalchemy.orm import Session
def create_user(session: Session, username: str, email: str, password_hash: str,
full_name: str = None, age: int = None):
"""创建用户"""
user = User(
username=username,
email=email,
password_hash=password_hash,
full_name=full_name,
age=age
)
session.add(user)
session.commit()
session.refresh(user) # 刷新以获取数据库生成的 ID
print(f"用户创建成功: {user}")
return user
# 使用示例
session = SessionLocal()
try:
user = create_user(
session,
username='alice',
email='alice@example.com',
password_hash='hashed_password_123',
full_name='Alice Smith',
age=25
)
finally:
session.close()
2. 批量创建
def create_users_batch(session: Session, users_data: list):
"""批量创建用户"""
users = [
User(
username=data['username'],
email=data['email'],
password_hash=data['password_hash'],
full_name=data.get('full_name'),
age=data.get('age')
)
for data in users_data
]
session.add_all(users)
session.commit()
print(f"批量创建了 {len(users)} 个用户")
return users
# 使用示例
users_data = [
{
'username': 'bob',
'email': 'bob@example.com',
'password_hash': 'hash123',
'full_name': 'Bob Johnson',
'age': 30
},
{
'username': 'charlie',
'email': 'charlie@example.com',
'password_hash': 'hash456',
'full_name': 'Charlie Brown',
'age': 28
}
]
session = SessionLocal()
try:
users = create_users_batch(session, users_data)
finally:
session.close()
3. 读取(Read)
查询单个对象
def get_user_by_id(session: Session, user_id: int):
"""根据 ID 获取用户"""
user = session.query(User).filter(User.id == user_id).first()
return user
def get_user_by_username(session: Session, username: str):
"""根据用户名获取用户"""
user = session.query(User).filter(User.username == username).first()
return user
# 使用示例
session = SessionLocal()
try:
user = get_user_by_id(session, 1)
if user:
print(f"找到用户: {user}")
else:
print("用户不存在")
finally:
session.close()
查询多个对象
def get_all_users(session: Session, limit: int = None):
"""获取所有用户"""
query = session.query(User)
if limit:
query = query.limit(limit)
return query.all()
def get_active_users(session: Session):
"""获取所有激活的用户"""
return session.query(User).filter(User.is_active == 1).all()
# 使用示例
session = SessionLocal()
try:
users = get_all_users(session, limit=10)
for user in users:
print(user)
finally:
session.close()
4. 更新(Update)
def update_user(session: Session, user_id: int, **kwargs):
"""更新用户信息"""
user = session.query(User).filter(User.id == user_id).first()
if not user:
print(f"用户 ID {user_id} 不存在")
return None
# 更新字段
for key, value in kwargs.items():
if hasattr(user, key):
setattr(user, key, value)
session.commit()
session.refresh(user)
print(f"用户更新成功: {user}")
return user
# 使用示例
session = SessionLocal()
try:
update_user(session, 1, age=26, full_name='Alice Johnson')
finally:
session.close()
5. 删除(Delete)
def delete_user(session: Session, user_id: int):
"""删除用户"""
user = session.query(User).filter(User.id == user_id).first()
if not user:
print(f"用户 ID {user_id} 不存在")
return False
session.delete(user)
session.commit()
print(f"用户 {user_id} 删除成功")
return True
# 使用示例
session = SessionLocal()
try:
delete_user(session, 1)
finally:
session.close()
关系映射
1. 一对多关系
# 创建文章(一对多:一个用户有多篇文章)
def create_post(session: Session, title: str, content: str, author_id: int,
category_id: int = None):
"""创建文章"""
post = Post(
title=title,
content=content,
author_id=author_id,
category_id=category_id,
status=PostStatus.PUBLISHED
)
session.add(post)
session.commit()
session.refresh(post)
return post
# 通过关系访问
def get_user_posts(session: Session, user_id: int):
"""获取用户的所有文章"""
user = session.query(User).filter(User.id == user_id).first()
if user:
return user.posts # 通过关系访问
return []
# 使用示例
session = SessionLocal()
try:
# 创建文章
post = create_post(
session,
title='Python 入门教程',
content='这是一篇关于 Python 的教程...',
author_id=1
)
# 通过用户对象访问文章
user = session.query(User).filter(User.id == 1).first()
if user:
print(f"用户 {user.username} 的文章:")
for post in user.posts:
print(f" - {post.title}")
finally:
session.close()
2. 多对多关系
def add_tags_to_post(session: Session, post_id: int, tag_names: list):
"""给文章添加标签"""
post = session.query(Post).filter(Post.id == post_id).first()
if not post:
print(f"文章 {post_id} 不存在")
return
# 查找或创建标签
tags = []
for tag_name in tag_names:
tag = session.query(Tag).filter(Tag.name == tag_name).first()
if not tag:
tag = Tag(name=tag_name)
session.add(tag)
tags.append(tag)
# 添加标签到文章
post.tags.extend(tags)
session.commit()
print(f"文章 {post_id} 添加了标签: {tag_names}")
# 使用示例
session = SessionLocal()
try:
add_tags_to_post(session, 1, ['Python', '教程', '入门'])
finally:
session.close()
查询操作详解
1. 基本查询
# 查询所有记录
users = session.query(User).all()
# 查询第一条记录
user = session.query(User).first()
# 查询一条记录(如果不存在会抛出异常)
user = session.query(User).one()
# 查询一条或没有(如果多条会抛出异常)
user = session.query(User).one_or_none()
2. 过滤查询
# 等于
users = session.query(User).filter(User.age == 25).all()
# 不等于
users = session.query(User).filter(User.age != 25).all()
# 大于、小于
users = session.query(User).filter(User.age > 25).all()
users = session.query(User).filter(User.age < 30).all()
users = session.query(User).filter(User.age >= 25).all()
users = session.query(User).filter(User.age <= 30).all()
# IN
users = session.query(User).filter(User.age.in_([25, 30, 35])).all()
# LIKE
users = session.query(User).filter(User.username.like('%alice%')).all()
# 多个条件(AND)
users = session.query(User).filter(
User.age >= 25,
User.is_active == 1
).all()
# 多个条件(OR)
from sqlalchemy import or_
users = session.query(User).filter(
or_(User.age < 25, User.age > 35)
).all()
# 空值检查
users = session.query(User).filter(User.full_name.is_(None)).all()
users = session.query(User).filter(User.full_name.isnot(None)).all()
3. 排序和限制
# 排序
users = session.query(User).order_by(User.age).all() # 升序
users = session.query(User).order_by(User.age.desc()).all() # 降序
# 多字段排序
users = session.query(User).order_by(User.age, User.username).all()
# 限制数量
users = session.query(User).limit(10).all()
# 偏移量(分页)
page = 1
page_size = 10
offset = (page - 1) * page_size
users = session.query(User).offset(offset).limit(page_size).all()
# 计数
count = session.query(User).count()
active_count = session.query(User).filter(User.is_active == 1).count()
4. 连接查询
# 内连接
posts = session.query(Post).join(User).filter(User.username == 'alice').all()
# 左连接
from sqlalchemy import outerjoin
posts = session.query(Post).outerjoin(Category).all()
# 使用关系进行连接
posts = session.query(Post).join(Post.author).filter(User.username == 'alice').all()
5. 聚合查询
from sqlalchemy import func
# 平均值
avg_age = session.query(func.avg(User.age)).scalar()
# 最大值、最小值
max_age = session.query(func.max(User.age)).scalar()
min_age = session.query(func.min(User.age)).scalar()
# 求和
total_views = session.query(func.sum(Post.view_count)).scalar()
# 分组
from sqlalchemy import func
result = session.query(
User.id,
User.username,
func.count(Post.id).label('post_count')
).join(Post).group_by(User.id).all()
for user_id, username, post_count in result:
print(f"{username}: {post_count} 篇文章")
6. 子查询
from sqlalchemy import func
# 子查询:查找文章数最多的用户
subquery = session.query(
Post.author_id,
func.count(Post.id).label('post_count')
).group_by(Post.author_id).subquery()
result = session.query(
User.username,
subquery.c.post_count
).join(subquery, User.id == subquery.c.author_id).order_by(
subquery.c.post_count.desc()
).all()
7. 预加载(Eager Loading)
from sqlalchemy.orm import joinedload, selectinload
# 使用 joinedload(使用 JOIN)
user = session.query(User).options(joinedload(User.posts)).filter(User.id == 1).first()
# 访问 user.posts 不会触发额外的查询
# 使用 selectinload(使用子查询)
user = session.query(User).options(selectinload(User.posts)).filter(User.id == 1).first()
# 加载多个关系
post = session.query(Post).options(
joinedload(Post.author),
joinedload(Post.category),
selectinload(Post.tags)
).filter(Post.id == 1).first()
事务处理
1. 基本事务
session = SessionLocal()
try:
# 开始事务(自动)
user1 = User(username='user1', email='user1@example.com', password_hash='hash1')
user2 = User(username='user2', email='user2@example.com', password_hash='hash2')
session.add(user1)
session.add(user2)
# 提交事务
session.commit()
print("事务提交成功")
except Exception as e:
# 回滚事务
session.rollback()
print(f"事务回滚: {e}")
finally:
session.close()
2. 嵌套事务(保存点)
session = SessionLocal()
try:
user = User(username='user1', email='user1@example.com', password_hash='hash1')
session.add(user)
session.flush() # 刷新到数据库但不提交
# 创建保存点
savepoint = session.begin_nested()
try:
post = Post(title='Test', content='Content', author_id=user.id)
session.add(post)
savepoint.commit()
except Exception:
savepoint.rollback()
raise
session.commit()
except Exception as e:
session.rollback()
print(f"错误: {e}")
finally:
session.close()
测试数据集
DDL 语句(数据定义语言)
以下是完整的测试数据库表结构 SQL 语句:
-- 创建数据库
CREATE DATABASE IF NOT EXISTS testdb CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
USE testdb;
-- 用户表
CREATE TABLE IF NOT EXISTS users (
id INT AUTO_INCREMENT PRIMARY KEY,
username VARCHAR(50) NOT NULL UNIQUE,
email VARCHAR(100) NOT NULL UNIQUE,
password_hash VARCHAR(255) NOT NULL,
full_name VARCHAR(100),
age INT,
bio TEXT,
is_active INT DEFAULT 1,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_username (username),
INDEX idx_email (email),
INDEX idx_is_active (is_active)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 分类表
CREATE TABLE IF NOT EXISTS categories (
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(50) NOT NULL UNIQUE,
description TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 文章表
CREATE TABLE IF NOT EXISTS posts (
id INT AUTO_INCREMENT PRIMARY KEY,
title VARCHAR(200) NOT NULL,
content TEXT,
status ENUM('draft', 'published', 'archived') DEFAULT 'draft',
view_count INT DEFAULT 0,
author_id INT NOT NULL,
category_id INT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
INDEX idx_title (title),
INDEX idx_author_id (author_id),
INDEX idx_category_id (category_id),
INDEX idx_status (status),
FOREIGN KEY (author_id) REFERENCES users(id) ON DELETE CASCADE,
FOREIGN KEY (category_id) REFERENCES categories(id) ON DELETE SET NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 标签表
CREATE TABLE IF NOT EXISTS tags (
id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(50) NOT NULL UNIQUE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
-- 文章标签关联表(多对多)
CREATE TABLE IF NOT EXISTS post_tags (
post_id INT NOT NULL,
tag_id INT NOT NULL,
PRIMARY KEY (post_id, tag_id),
FOREIGN KEY (post_id) REFERENCES posts(id) ON DELETE CASCADE,
FOREIGN KEY (tag_id) REFERENCES tags(id) ON DELETE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
测试数据(DML 语句)
以下是测试数据的插入语句:
-- 插入用户数据
INSERT INTO users (username, email, password_hash, full_name, age, bio, is_active) VALUES
('alice', 'alice@example.com', 'hashed_password_alice', 'Alice Smith', 25, 'Python 开发者,热爱编程', 1),
('bob', 'bob@example.com', 'hashed_password_bob', 'Bob Johnson', 30, '全栈工程师', 1),
('charlie', 'charlie@example.com', 'hashed_password_charlie', 'Charlie Brown', 28, '数据科学家', 1),
('david', 'david@example.com', 'hashed_password_david', 'David Wilson', 32, 'DevOps 工程师', 1),
('eve', 'eve@example.com', 'hashed_password_eve', 'Eve Davis', 26, '前端开发工程师', 1),
('frank', 'frank@example.com', 'hashed_password_frank', 'Frank Miller', 35, '架构师', 1),
('grace', 'grace@example.com', 'hashed_password_grace', 'Grace Lee', 24, 'UI/UX 设计师', 1),
('henry', 'henry@example.com', 'hashed_password_henry', 'Henry Taylor', 29, '移动应用开发者', 1);
-- 插入分类数据
INSERT INTO categories (name, description) VALUES
('Python', 'Python 编程相关文章'),
('JavaScript', 'JavaScript 和前端开发'),
('数据库', '数据库相关技术'),
('DevOps', 'DevOps 和运维'),
('算法', '算法和数据结构'),
('架构设计', '系统架构设计');
-- 插入文章数据
INSERT INTO posts (title, content, status, view_count, author_id, category_id) VALUES
('Python 入门教程', '这是一篇详细的 Python 入门教程,适合初学者...', 'published', 150, 1, 1),
('SQLAlchemy 使用指南', '深入讲解 SQLAlchemy ORM 框架的使用方法...', 'published', 89, 1, 1),
('JavaScript 异步编程', 'Promise、async/await 详解...', 'published', 120, 2, 2),
('MySQL 性能优化', '数据库查询优化技巧和实践...', 'published', 200, 3, 3),
('Docker 容器化部署', '使用 Docker 进行应用部署...', 'published', 95, 4, 4),
('快速排序算法详解', '快速排序的原理和实现...', 'published', 75, 5, 5),
('微服务架构设计', '微服务架构的最佳实践...', 'published', 180, 6, 6),
('Python 装饰器详解', '深入理解 Python 装饰器...', 'draft', 0, 1, 1),
('React Hooks 使用', 'React Hooks 的完整指南...', 'draft', 0, 2, 2);
-- 插入标签数据
INSERT INTO tags (name) VALUES
('Python'),
('教程'),
('入门'),
('ORM'),
('数据库'),
('JavaScript'),
('前端'),
('优化'),
('Docker'),
('算法'),
('架构'),
('微服务');
-- 插入文章标签关联数据
INSERT INTO post_tags (post_id, tag_id) VALUES
(1, 1), (1, 2), (1, 3), -- Python 入门教程: Python, 教程, 入门
(2, 1), (2, 4), (2, 5), -- SQLAlchemy 使用指南: Python, ORM, 数据库
(3, 6), (3, 7), -- JavaScript 异步编程: JavaScript, 前端
(4, 5), (4, 8), -- MySQL 性能优化: 数据库, 优化
(5, 9), -- Docker 容器化部署: Docker
(6, 10), -- 快速排序算法详解: 算法
(7, 11), (7, 12); -- 微服务架构设计: 架构, 微服务
完整示例项目
以下是一个完整的 SQLAlchemy 示例项目:
"""
SQLAlchemy 完整示例项目
包含模型定义、CRUD 操作、关系映射、查询等
"""
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, Enum, Table
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.sql import func
import enum
from datetime import datetime
# ==================== 配置 ====================
DATABASE_URL = "mysql+pymysql://testuser:testpass123@localhost:3306/testdb?charset=utf8mb4"
engine = create_engine(DATABASE_URL, echo=True)
Base = declarative_base()
SessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False)
# ==================== 枚举 ====================
class PostStatus(enum.Enum):
DRAFT = "draft"
PUBLISHED = "published"
ARCHIVED = "archived"
# ==================== 模型定义 ====================
# 多对多中间表
post_tags = Table(
'post_tags',
Base.metadata,
Column('post_id', Integer, ForeignKey('posts.id'), primary_key=True),
Column('tag_id', Integer, ForeignKey('tags.id'), primary_key=True)
)
class User(Base):
"""用户模型"""
__tablename__ = 'users'
id = Column(Integer, primary_key=True, autoincrement=True)
username = Column(String(50), unique=True, nullable=False, index=True)
email = Column(String(100), unique=True, nullable=False, index=True)
password_hash = Column(String(255), nullable=False)
full_name = Column(String(100))
age = Column(Integer)
bio = Column(Text)
is_active = Column(Integer, default=1)
created_at = Column(DateTime, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
# 关系
posts = relationship("Post", back_populates="author")
def __repr__(self):
return f"<User(id={self.id}, username='{self.username}')>"
class Category(Base):
"""分类模型"""
__tablename__ = 'categories'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(50), unique=True, nullable=False)
description = Column(Text)
created_at = Column(DateTime, default=func.now())
posts = relationship("Post", back_populates="category")
def __repr__(self):
return f"<Category(id={self.id}, name='{self.name}')>"
class Post(Base):
"""文章模型"""
__tablename__ = 'posts'
id = Column(Integer, primary_key=True, autoincrement=True)
title = Column(String(200), nullable=False, index=True)
content = Column(Text)
status = Column(Enum(PostStatus), default=PostStatus.DRAFT)
view_count = Column(Integer, default=0)
author_id = Column(Integer, ForeignKey('users.id'), nullable=False)
category_id = Column(Integer, ForeignKey('categories.id'))
created_at = Column(DateTime, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
# 关系
author = relationship("User", back_populates="posts")
category = relationship("Category", back_populates="posts")
tags = relationship("Tag", secondary=post_tags, back_populates="posts")
def __repr__(self):
return f"<Post(id={self.id}, title='{self.title}')>"
class Tag(Base):
"""标签模型"""
__tablename__ = 'tags'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(50), unique=True, nullable=False)
created_at = Column(DateTime, default=func.now())
posts = relationship("Post", secondary=post_tags, back_populates="tags")
def __repr__(self):
return f"<Tag(id={self.id}, name='{self.name}')>"
# ==================== 数据访问层 ====================
class UserDAO:
"""用户数据访问对象"""
def __init__(self, session):
self.session = session
def create(self, username, email, password_hash, full_name=None, age=None, bio=None):
"""创建用户"""
user = User(
username=username,
email=email,
password_hash=password_hash,
full_name=full_name,
age=age,
bio=bio
)
self.session.add(user)
self.session.commit()
self.session.refresh(user)
return user
def get_by_id(self, user_id):
"""根据 ID 获取用户"""
return self.session.query(User).filter(User.id == user_id).first()
def get_by_username(self, username):
"""根据用户名获取用户"""
return self.session.query(User).filter(User.username == username).first()
def get_all(self, limit=None):
"""获取所有用户"""
query = self.session.query(User)
if limit:
query = query.limit(limit)
return query.all()
def update(self, user_id, **kwargs):
"""更新用户"""
user = self.get_by_id(user_id)
if not user:
return None
for key, value in kwargs.items():
if hasattr(user, key):
setattr(user, key, value)
self.session.commit()
self.session.refresh(user)
return user
def delete(self, user_id):
"""删除用户"""
user = self.get_by_id(user_id)
if user:
self.session.delete(user)
self.session.commit()
return True
return False
class PostDAO:
"""文章数据访问对象"""
def __init__(self, session):
self.session = session
def create(self, title, content, author_id, category_id=None, status=PostStatus.PUBLISHED):
"""创建文章"""
post = Post(
title=title,
content=content,
author_id=author_id,
category_id=category_id,
status=status
)
self.session.add(post)
self.session.commit()
self.session.refresh(post)
return post
def get_by_id(self, post_id):
"""根据 ID 获取文章"""
return self.session.query(Post).filter(Post.id == post_id).first()
def get_published(self, limit=None):
"""获取已发布的文章"""
query = self.session.query(Post).filter(Post.status == PostStatus.PUBLISHED)
if limit:
query = query.limit(limit)
return query.order_by(Post.created_at.desc()).all()
def get_by_author(self, author_id):
"""获取作者的所有文章"""
return self.session.query(Post).filter(Post.author_id == author_id).all()
def add_tags(self, post_id, tag_names):
"""给文章添加标签"""
post = self.get_by_id(post_id)
if not post:
return False
for tag_name in tag_names:
tag = self.session.query(Tag).filter(Tag.name == tag_name).first()
if not tag:
tag = Tag(name=tag_name)
self.session.add(tag)
if tag not in post.tags:
post.tags.append(tag)
self.session.commit()
return True
def increment_view(self, post_id):
"""增加浏览量"""
post = self.get_by_id(post_id)
if post:
post.view_count += 1
self.session.commit()
return True
return False
# ==================== 主程序 ====================
def main():
"""主函数 - 演示所有功能"""
# 创建表
print("=" * 60)
print("1. 创建数据库表")
print("=" * 60)
Base.metadata.create_all(engine)
print("表创建成功!\n")
session = SessionLocal()
try:
# 创建 DAO 实例
user_dao = UserDAO(session)
post_dao = PostDAO(session)
# 创建用户
print("=" * 60)
print("2. 创建用户")
print("=" * 60)
user1 = user_dao.create(
username='alice',
email='alice@example.com',
password_hash='hash123',
full_name='Alice Smith',
age=25,
bio='Python 开发者'
)
print(f"创建用户: {user1}\n")
user2 = user_dao.create(
username='bob',
email='bob@example.com',
password_hash='hash456',
full_name='Bob Johnson',
age=30
)
print(f"创建用户: {user2}\n")
# 创建分类
print("=" * 60)
print("3. 创建分类")
print("=" * 60)
category1 = Category(name='Python', description='Python 编程')
category2 = Category(name='JavaScript', description='JavaScript 开发')
session.add_all([category1, category2])
session.commit()
print(f"创建分类: {category1}, {category2}\n")
# 创建文章
print("=" * 60)
print("4. 创建文章")
print("=" * 60)
post1 = post_dao.create(
title='Python 入门教程',
content='这是一篇 Python 入门教程...',
author_id=user1.id,
category_id=category1.id
)
print(f"创建文章: {post1}\n")
post2 = post_dao.create(
title='JavaScript 异步编程',
content='Promise 和 async/await 详解...',
author_id=user2.id,
category_id=category2.id
)
print(f"创建文章: {post2}\n")
# 添加标签
print("=" * 60)
print("5. 给文章添加标签")
print("=" * 60)
post_dao.add_tags(post1.id, ['Python', '教程', '入门'])
post_dao.add_tags(post2.id, ['JavaScript', '异步'])
print(f"文章 {post1.id} 的标签: {[tag.name for tag in post1.tags]}")
print(f"文章 {post2.id} 的标签: {[tag.name for tag in post2.tags]}\n")
# 查询用户及其文章
print("=" * 60)
print("6. 查询用户及其文章")
print("=" * 60)
user = user_dao.get_by_id(user1.id)
print(f"用户: {user.username}")
print(f"文章数量: {len(user.posts)}")
for post in user.posts:
print(f" - {post.title} ({post.status.value})")
print()
# 查询已发布的文章
print("=" * 60)
print("7. 查询已发布的文章")
print("=" * 60)
published_posts = post_dao.get_published()
for post in published_posts:
print(f"标题: {post.title}")
print(f"作者: {post.author.username}")
print(f"分类: {post.category.name if post.category else '无'}")
print(f"标签: {[tag.name for tag in post.tags]}")
print()
# 更新用户
print("=" * 60)
print("8. 更新用户信息")
print("=" * 60)
updated_user = user_dao.update(user1.id, age=26, bio='高级 Python 开发者')
print(f"更新后的用户: {updated_user}\n")
# 增加浏览量
print("=" * 60)
print("9. 增加文章浏览量")
print("=" * 60)
post_dao.increment_view(post1.id)
post = post_dao.get_by_id(post1.id)
print(f"文章 '{post.title}' 的浏览量: {post.view_count}\n")
# 统计查询
print("=" * 60)
print("10. 统计查询")
print("=" * 60)
from sqlalchemy import func
user_post_count = session.query(
User.username,
func.count(Post.id).label('post_count')
).join(Post).group_by(User.id).all()
for username, count in user_post_count:
print(f"{username}: {count} 篇文章")
finally:
session.close()
if __name__ == '__main__':
main()
最佳实践
1. 使用上下文管理器管理会话
from contextlib import contextmanager
@contextmanager
def get_db_session():
"""数据库会话上下文管理器"""
session = SessionLocal()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
# 使用
with get_db_session() as session:
user = User(username='test', email='test@example.com', password_hash='hash')
session.add(user)
2. 使用配置类管理数据库连接
import os
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
class DatabaseConfig:
"""数据库配置"""
DB_HOST = os.getenv('DB_HOST', 'localhost')
DB_PORT = os.getenv('DB_PORT', '3306')
DB_USER = os.getenv('DB_USER', 'testuser')
DB_PASSWORD = os.getenv('DB_PASSWORD', 'testpass123')
DB_NAME = os.getenv('DB_NAME', 'testdb')
@classmethod
def get_database_url(cls):
return (
f"mysql+pymysql://{cls.DB_USER}:{cls.DB_PASSWORD}"
f"@{cls.DB_HOST}:{cls.DB_PORT}/{cls.DB_NAME}?charset=utf8mb4"
)
@classmethod
def create_engine(cls, echo=False):
return create_engine(
cls.get_database_url(),
echo=echo,
pool_size=5,
max_overflow=10,
pool_pre_ping=True,
pool_recycle=3600,
)
3. 使用 Alembic 进行数据库迁移
# 安装 Alembic
pip install alembic
# 初始化 Alembic
alembic init alembic
# 创建迁移
alembic revision --autogenerate -m "创建用户表"
# 应用迁移
alembic upgrade head
# 回滚迁移
alembic downgrade -1
4. 性能优化建议
- 使用连接池:合理配置
pool_size和max_overflow - 预加载关系:使用
joinedload或selectinload避免 N+1 查询 - 批量操作:使用
bulk_insert_mappings进行批量插入 - 索引优化:为常用查询字段添加索引
- 查询优化:只查询需要的字段,使用
load_only()
5. 安全建议
- 参数化查询:始终使用参数化查询,防止 SQL 注入
- 密码加密:不要存储明文密码,使用哈希算法
- 权限控制:数据库用户只授予必要的权限
- 连接加密:生产环境使用 SSL 连接
总结
本文详细介绍了 SQLAlchemy ORM 框架的使用,包括:
- 核心概念:Engine、Session、Model 等
- 模型定义:如何定义模型和关系
- CRUD 操作:创建、读取、更新、删除
- 关系映射:一对多、多对多关系
- 查询操作:各种查询方法和技巧
- 事务处理:事务的使用和管理
- 测试数据:完整的 DDL 和测试数据
- 最佳实践:性能优化和安全建议
关键要点
- 使用 ORM 的优势:代码更简洁,类型安全,易于维护
- 关系映射:合理使用关系可以简化代码
- 查询优化:注意 N+1 查询问题,使用预加载
- 事务管理:正确处理事务,确保数据一致性
- 代码组织:使用 DAO 模式组织数据访问代码
下一步学习
- 学习 Alembic 进行数据库迁移
- 了解 SQLAlchemy Core 的使用
- 学习数据库性能优化
- 掌握 Flask-SQLAlchemy 或 Django ORM
希望这篇文章能帮助你掌握 SQLAlchemy 的使用!