数据库测试¶
信息
这些文档即将更新。🎉
当前版本假设 Pydantic v1 和小于 2.0 的 SQLAlchemy 版本。
新文档将包含 Pydantic v2,并在更新为也使用 Pydantic v2 后使用 SQLModel(它也基于 SQLAlchemy)。
您可以使用 使用覆盖测试依赖项 中相同的依赖项覆盖来更改数据库以进行测试。
您可能希望为测试设置一个不同的数据库,在测试后回滚数据,用一些测试数据预填充它等。
主要思想与您在上一章中看到的完全相同。
为 SQL 应用程序添加测试¶
让我们更新来自 SQL(关系型)数据库 的示例以使用测试数据库。
所有应用程序代码都相同,您可以返回该章节查看它是如何工作的。
这里唯一更改的是新的测试文件。
您的普通依赖项 get_db()
将返回一个数据库会话。
在测试中,您可以使用依赖项覆盖来返回您的自定义数据库会话,而不是通常使用的会话。
在此示例中,我们将仅为测试创建一个临时数据库。
文件结构¶
我们在 sql_app/tests/test_sql_app.py
中创建一个新文件。
因此,新的文件结构如下所示
.
└── sql_app
├── __init__.py
├── crud.py
├── database.py
├── main.py
├── models.py
├── schemas.py
└── tests
├── __init__.py
└── test_sql_app.py
创建新的数据库会话¶
首先,我们使用新数据库创建一个新的数据库会话。
我们将使用一个在测试期间持续存在的内存数据库,而不是本地文件 sql_app.db
。
但是其余的会话代码或多或少相同,我们只需复制它。
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from ..database import Base
from ..main import app, get_db
SQLALCHEMY_DATABASE_URL = "sqlite://"
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
def override_get_db():
try:
db = TestingSessionLocal()
yield db
finally:
db.close()
app.dependency_overrides[get_db] = override_get_db
client = TestClient(app)
def test_create_user():
response = client.post(
"/users/",
json={"email": "deadpool@example.com", "password": "chimichangas4life"},
)
assert response.status_code == 200, response.text
data = response.json()
assert data["email"] == "deadpool@example.com"
assert "id" in data
user_id = data["id"]
response = client.get(f"/users/{user_id}")
assert response.status_code == 200, response.text
data = response.json()
assert data["email"] == "deadpool@example.com"
assert data["id"] == user_id
提示
您可以通过将该代码放在一个函数中并在 database.py
和 tests/test_sql_app.py
中使用它来减少代码重复。
为简单起见,并专注于特定的测试代码,我们只是复制它。
创建数据库¶
因为现在我们将在一个新文件中使用一个新的数据库,我们需要确保我们使用以下命令创建数据库
Base.metadata.create_all(bind=engine)
这通常在 main.py
中调用,但 main.py
中的行使用数据库文件 sql_app.db
,我们需要确保我们为测试创建 test.db
。
因此,我们在这里添加了该行,并使用新文件。
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from ..database import Base
from ..main import app, get_db
SQLALCHEMY_DATABASE_URL = "sqlite://"
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
def override_get_db():
try:
db = TestingSessionLocal()
yield db
finally:
db.close()
app.dependency_overrides[get_db] = override_get_db
client = TestClient(app)
def test_create_user():
response = client.post(
"/users/",
json={"email": "deadpool@example.com", "password": "chimichangas4life"},
)
assert response.status_code == 200, response.text
data = response.json()
assert data["email"] == "deadpool@example.com"
assert "id" in data
user_id = data["id"]
response = client.get(f"/users/{user_id}")
assert response.status_code == 200, response.text
data = response.json()
assert data["email"] == "deadpool@example.com"
assert data["id"] == user_id
依赖项覆盖¶
现在我们创建依赖项覆盖并将其添加到我们应用程序的覆盖中。
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from ..database import Base
from ..main import app, get_db
SQLALCHEMY_DATABASE_URL = "sqlite://"
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
def override_get_db():
try:
db = TestingSessionLocal()
yield db
finally:
db.close()
app.dependency_overrides[get_db] = override_get_db
client = TestClient(app)
def test_create_user():
response = client.post(
"/users/",
json={"email": "deadpool@example.com", "password": "chimichangas4life"},
)
assert response.status_code == 200, response.text
data = response.json()
assert data["email"] == "deadpool@example.com"
assert "id" in data
user_id = data["id"]
response = client.get(f"/users/{user_id}")
assert response.status_code == 200, response.text
data = response.json()
assert data["email"] == "deadpool@example.com"
assert data["id"] == user_id
提示
override_get_db()
的代码与 get_db()
的代码几乎完全相同,但在 override_get_db()
中,我们使用 TestingSessionLocal
用于测试数据库。
测试应用程序¶
然后我们可以像往常一样测试应用程序。
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from ..database import Base
from ..main import app, get_db
SQLALCHEMY_DATABASE_URL = "sqlite://"
engine = create_engine(
SQLALCHEMY_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
def override_get_db():
try:
db = TestingSessionLocal()
yield db
finally:
db.close()
app.dependency_overrides[get_db] = override_get_db
client = TestClient(app)
def test_create_user():
response = client.post(
"/users/",
json={"email": "deadpool@example.com", "password": "chimichangas4life"},
)
assert response.status_code == 200, response.text
data = response.json()
assert data["email"] == "deadpool@example.com"
assert "id" in data
user_id = data["id"]
response = client.get(f"/users/{user_id}")
assert response.status_code == 200, response.text
data = response.json()
assert data["email"] == "deadpool@example.com"
assert data["id"] == user_id
并且我们在测试期间对数据库进行的所有修改都将在 test.db
数据库中,而不是在主 sql_app.db
中。