FastAPI 后端开发口味
这是我自己在做 FastAPI 后端服务时反复用到的模式集合,来自我的 Claude Code skill。 不是"FastAPI 教程",而是我个人的口味:选择项目布局的判断、模型分层方式、返回包装、 日志/测试/部署的具体偏好。如果你跟我口味相近可以直接抄;不一样的地方就改成你的。
总则
- Python 最新版本 + FastAPI 最新语法
- 优先依赖注入
- 不过度工程:根据规模选合适的布局
- 不写
README.md(自己写) - 不安装环境(用 uv)
项目布局选择
根据项目规模选择,不要过度工程:
| 规模 | 布局 | 适用场景 |
|---|---|---|
| 单文件 | main.py |
脚本、demo、原型、5 个接口以内 |
| 轻量多文件 | 按职责拆文件,不按领域分包 | 小型服务、1-2 个业务领域 |
| 领域分包 | 按领域分文件夹 + 统一文件结构 | 多业务领域、团队协作的大项目 |
默认选单文件或轻量多文件。只有用户明确说"规划项目架构"或项目明显涉及多个独立业务领域时才用领域分包,详见下面的"项目布局"折叠块。
项目基础结构
from contextlib import asynccontextmanager
from fastapi import FastAPI
@asynccontextmanager
async def lifespan(app: FastAPI):
# 启动前
yield
# 关闭后
app = FastAPI(lifespan=lifespan)
统一返回包装
所有 API 返回用同一个泛型 Response[T]:
class Response[T](BaseModel):
code: int = 0
message: str = "success"
data: T | dict = {}
success: bool = True
路由使用:
模型分层
复杂模型分四层,用继承组合 —— 数据库表、输入参数、返回值彼此解耦:
from sqlmodel import SQLModel, Field
from pydantic import BaseModel
from uuid import uuid4, UUID
from datetime import datetime
# 1. Base:共享字段,不建表
class BaseMessage(SQLModel):
session_id: str
message: str
sender: SenderType
# 2. DB Table:继承 Base,加主键 / 索引 / 时间戳
class MessageTable(BaseMessage, table=True):
message_id: UUID = Field(default_factory=uuid4, primary_key=True)
created_at: datetime = Field(default_factory=datetime.now)
trace_id: str | None = None
# 3. Input:继承 Base + 混入其他 mixin
class ReceivedMessage(BaseMessage, LangFuseRecord, LanggraphConfig):
"""接收参数,多继承混入额外字段"""
pass
# 4. Output:独立定义,与 DB 结构不同
class ReplyData(BaseModel):
name: str
detail: str
is_final: bool = False
turn_to_human: bool = False
工具偏好
- 日志库:loguru,报错用
logger.exception(e),只输出到控制台不写文件 - ORM:SQLModel(FastAPI 作者出的,本质是 SQLAlchemy + Pydantic 的融合,最优雅)
- 环境管理:uv
按场景的完整 reference
下面 14 个折叠块对应 14 种典型场景,每一块都是我实际在用的完整代码片段。 默认全部收起,只展开你当下需要的。
项目布局:轻量多文件 / 领域分包
FastAPI 项目布局详解
轻量多文件
适用于小型服务、10 个接口左右、1-2 个业务领域:
my_app/
├── __init__.py
├── main.py # FastAPI app 入口 + 路由注册
├── router.py # 路由(接口多时可拆成 router_xxx.py)
├── models.py # SQLModel 数据库模型
├── schemas.py # Pydantic 入参/出参
├── service.py # 业务逻辑
├── deps.py # 依赖注入
└── config.py # 配置
作为一个 Python 包组织,包内用相对导入 from .models import Xxx,外部通过 pip install -e . 安装后用绝对导入。
main.py 示例:
from contextlib import asynccontextmanager
from fastapi import FastAPI
from .router import router
@asynccontextmanager
async def lifespan(app: FastAPI):
yield
app = FastAPI(lifespan=lifespan)
app.include_router(router)
领域分包
适用于多业务领域、团队协作、长期维护的大项目:
src/
├── main.py # 入口:创建 app,注册所有 router
├── config.py # 全局配置(DATABASE_URL, REDIS_URL 等)
├── database.py # 数据库连接、Session 工厂
├── models.py # 全局基础模型(Base, 时间戳 mixin)
├── exceptions.py # 全局异常处理器
├── pagination.py # 通用分页
│
├── auth/ # ─── 认证领域 ───
│ ├── __init__.py
│ ├── router.py # 端点定义
│ ├── schemas.py # Pydantic 入参/出参
│ ├── models.py # 数据库模型
│ ├── service.py # 业务逻辑
│ ├── dependencies.py # 依赖注入(如 get_current_user)
│ ├── constants.py # 常量和错误码
│ ├── config.py # 模块专属环境变量
│ ├── utils.py # 工具函数(如 hash_password)
│ └── exceptions.py # 模块专属异常(如 InvalidCredentials)
│
├── posts/ # ─── 帖子领域(同样的 9 个文件)───
│ └── ...
└── users/ # ─── 用户领域(同样的 9 个文件)───
└── ...
9 个文件各自的职责
| 文件 | 职责 | 示例 |
|---|---|---|
| router.py | 定义 HTTP 端点,只做接/返,不写业务 | @router.get("/{post_id}") |
| schemas.py | Pydantic 模型,定义入参出参的形状 | PostCreate, PostResponse |
| models.py | SQLModel/SQLAlchemy 数据库表定义 | class Post(Base, table=True) |
| service.py | 业务逻辑,查库、计算、组合数据 | create(), get_by_id() |
| dependencies.py | 依赖注入,做请求级校验 | valid_post_id(), valid_owned_post() |
| constants.py | 常量、枚举、错误码 | ErrorCode.POST_NOT_FOUND |
| config.py | 模块专属的 BaseSettings | PostsConfig(MAX_POSTS_PER_USER=100) |
| utils.py | 非业务工具函数 | slugify(), truncate() |
| exceptions.py | 模块专属异常类 | PostNotFound(HTTPException) |
关键规则
依赖方向单向,绝不反向:
schemas / constants / config / utils / exceptions 是底层工具,任何层都可以用,但它们不依赖上面的层。跨模块用绝对导入 + 别名:
main.py 只做注册:
异步数据库(SQLModel)
异步数据库 (SQLModel)
需要和用户确认使用哪种数据库以及是使用同步还是异步数据库
依赖
只需在最后提醒需要安装什么,不要自己安装,由用户自行安装
创建异步引擎
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession
### PostgreSQL
DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/db"
async_engine = create_async_engine(DATABASE_URL, echo=True)
### SQLite
DATABASE_URL = "sqlite+aiosqlite:///./database.db"
async_engine = create_async_engine(
DATABASE_URL,
echo=True,
connect_args={"check_same_thread": False}, # SQLite 必需
)
async def get_session():
async with AsyncSession(async_engine, expire_on_commit=False) as session:
yield session
与 lifespan 集成
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends
@asynccontextmanager
async def lifespan(app: FastAPI):
# 异步建表
async with async_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
yield
await async_engine.dispose()
app = FastAPI(lifespan=lifespan)
### 使用依赖注入
@app.get("/users/{user_id}")
async def get_user(user_id: int, session: AsyncSession = Depends(get_session)):
user = await session.get(User, user_id)
return user
同步数据库 (SQLModel)
同步数据库已经比较成熟,自由发挥即可
同步 + 异步双引擎模式
当项目需要同时使用异步引擎(主业务)和同步引擎(如 sqlite-vec 向量库)时,需要注意:
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel import create_engine, Session
### 异步引擎 - 主业务数据库
engine = create_async_engine("sqlite+aiosqlite:///volume/dbs/chat.db", echo=False)
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
### 同步引擎 - 向量库等不支持异步的场景
engine_faq = create_engine("sqlite:///volume/dbs/faq.db")
### 分别注入
async def get_session() -> AsyncGenerator[AsyncSession, None]:
async with async_session(expire_on_commit=False) as session:
yield session
def get_session_faq() -> Session:
with Session(engine_faq) as session:
yield session
建表时必须用 tables= 参数指定具体表,防止在错误的数据库中建表:
RAG / 向量检索(sqlite-vec / pgvector)
RAG 向量检索
提供两套方案进行选择,需要主动询问用户使用哪一种,一种是使用sqlite,一种是使用pgvector,默认选择sqlite
方案一:SQLite + sqlite-vec(推荐轻量部署)
依赖
数据库引擎设置
FAQ 向量库使用同步引擎(sqlite-vec 不支持异步),与主业务的异步引擎分开:
from sqlmodel import create_engine, Session
from sqlite_vec_sqlalchemy import enable_sqlite_vec
### FAQ 向量库 - 同步引擎
engine_faq = create_engine("sqlite:///volume/dbs/faq.db")
enable_sqlite_vec(engine_faq) # 必须调用,加载 sqlite-vec 扩展
### 依赖注入
def get_session_faq() -> Session:
with Session(engine_faq) as session:
yield session
模型定义
采用继承方式,分离 Base / DB Table / 查询结果:
from sqlmodel import SQLModel, Field, Column, JSON
from sqlite_vec_sqlalchemy import Vector
length = 3072 # embedding 维度,根据模型调整
### 1. Base:共享业务字段
class BaseFAQ(SQLModel):
id: int = Field(primary_key=True)
question: str
answer: str
image_urls: list[str] = Field(default=[], sa_column=Column(JSON))
### 2. DB Table:带向量列,用于建表和写入
class FAQ_CREATE(BaseFAQ, table=True):
__tablename__ = 'faq'
embedding1: list[float] = Field(sa_column=Column(Vector(length))) # 问题向量
embedding2: list[float] = Field(sa_column=Column(Vector(length))) # 答案向量
### 3. 查询结果:带距离字段,不建表
class FAQ_RES(BaseFAQ):
distance: float | None = None
@classmethod
def model_validate2(cls, data: tuple):
faq, distance = data
a = cls.model_validate(faq)
a.distance = distance
return a
### 4. 未命中记录表(可选)
class FAQ_NO_ANSWER(SQLModel, table=True):
__tablename__ = 'faq_no_answer'
id: int = Field(primary_key=True)
question: str
created_at: datetime = Field(default_factory=datetime.now)
created_date: date = Field(default_factory=lambda: datetime.now().date(), index=True)
建表
from sqlmodel import SQLModel
SQLModel.metadata.create_all(engine_faq, tables=[FAQ_CREATE.__table__, FAQ_NO_ANSWER.__table__])
注意:如果同时存在异步引擎的 create_all,需要用 tables= 参数指定具体的表,避免在错误的数据库中建表。
向量检索
使用 vec_distance_cosine 计算距离(值越小越相似):
from sqlite_vec_sqlalchemy import vec_distance_cosine
from sqlmodel import select, Session
def search_faq(question_embedding: list[float], top_k: int = 8) -> list[FAQ_RES]:
with Session(engine_faq, expire_on_commit=False) as session:
# 搜问题向量
dist1 = vec_distance_cosine(FAQ_CREATE.embedding1, question_embedding).label("distance")
stmt1 = select(FAQ_CREATE, dist1).order_by(dist1).limit(top_k)
res1 = session.exec(stmt1).all()
# 搜答案向量
dist2 = vec_distance_cosine(FAQ_CREATE.embedding2, question_embedding).label("distance")
stmt2 = select(FAQ_CREATE, dist2).order_by(dist2).limit(top_k)
res2 = session.exec(stmt2).all()
# 合并去重
seen = set()
results = []
for item in [FAQ_RES.model_validate2(i) for i in res1 + res2]:
if item.id not in seen:
seen.add(item.id)
results.append(item)
return results
FAQ 写入(单条)
embedding 生成 + 写入数据库:
from langchain_openai import OpenAIEmbeddings
model_embedding = OpenAIEmbeddings(model='text-embedding-3-large')
async def create_faq(q: str, a: str, image_urls: list[str] = []) -> FAQ_CREATE:
faq = FAQ_CREATE(question=q, answer=a, image_urls=image_urls)
embeddings = await model_embedding.aembed_documents([faq.question, faq.answer])
faq.embedding1 = embeddings[0]
faq.embedding2 = embeddings[1]
return faq
FAQ 批量同步(从外部数据源)
增量同步模式:对比新旧数据,只增删改变化部分,避免全量重建。
核心思路:以 (question, answer) 作为唯一键,比较新旧数据集,分为三种变化:
- 删除:旧有新无 → 从数据库删除
- 新增:新有旧无 → 生成 embedding 后写入
- 更新:键相同但附属字段(如 image_urls)变化 → 只更新字段,无需重新生成 embedding
from langchain_core.runnables import RunnableLambda
from tqdm import tqdm
### 用 RunnableLambda 包装,便于 abatch 并行调用
@RunnableLambda
async def add_faq(foo):
faq = FAQ_CREATE(question=foo['q'], answer=foo['a'], image_urls=foo.get('image_urls', []))
embeddings = await model_embedding.aembed_documents([faq.question, faq.answer])
faq.embedding1 = embeddings[0]
faq.embedding2 = embeddings[1]
return faq
async def sync_faq(session: Session, new_data: list[dict]):
"""
new_data: [{'q': str, 'a': str, 'image_urls': list[str]}, ...]
"""
import pandas as pd
# 1. 读取现有数据
df_old = pd.read_sql('faq', engine_faq)
# 2. 构建新旧映射 (q, a) -> image_urls
now_values = {(d['q'], d['a']): d.get('image_urls', []) for d in new_data}
old_values = {(row['question'], row['answer']): row.get('image_urls', [])
for _, row in df_old.iterrows()}
# 3. 删除:旧有新无
delete_ids = [
row['id'] for _, row in df_old.iterrows()
if (row['question'], row['answer']) not in now_values
]
if delete_ids:
logger.debug(f'删除 {len(delete_ids)} 条FAQ')
result = session.exec(select(FAQ_CREATE).where(FAQ_CREATE.id.in_(delete_ids)))
for faq in result.all():
session.delete(faq)
session.commit()
# 4. 新增:新有旧无
add_list = [d for d in new_data if (d['q'], d['a']) not in old_values]
# 5. 批量写入(每10条一批,abatch 并行生成 embedding)
logger.debug(f'添加 {len(add_list)} 条FAQ')
for i in tqdm(range(0, len(add_list), 10)):
batch = add_list[i:i+10]
faqs = await add_faq.abatch(batch)
session.add_all(faqs)
session.commit()
# 6. 更新:键相同但附属字段变化(无需重新 embedding)
update_list = [
{'q': q, 'a': a, 'image_urls': urls}
for (q, a), urls in now_values.items()
if (q, a) in old_values and sorted(urls) != sorted(old_values[(q, a)])
]
if update_list:
logger.debug(f'更新 {len(update_list)} 条FAQ的附属字段')
from sqlmodel import and_
for item in tqdm(update_list):
result = session.exec(
select(FAQ_CREATE).where(
and_(FAQ_CREATE.question == item['q'], FAQ_CREATE.answer == item['a'])
)
)
faq = result.first()
if faq:
faq.image_urls = item['image_urls']
session.add(faq)
session.commit()
将同步逻辑挂载为 FastAPI 端点:
@router.put("/sync_faq")
async def sync_faq_endpoint(session: Annotated[Session, Depends(get_session_faq)]):
try:
new_data = load_from_external_source() # 从外部数据源读取
await sync_faq(session, new_data)
return {'status': 'ok'}
except Exception as e:
logger.exception(e)
# 可选:发送异常通知(钉钉/飞书/Slack 等)
raise
FAQ 匹配:双路粗召 + LLM 精排
完整的 FAQ 匹配流程分三步:
1. 双路粗召:分别用 embedding1(问题向量)和 embedding2(答案向量)各召回 top_k 条,合并去重
2. LLM 精排:将粗召结果 + 对话历史交给大模型,用 JSON mode 输出最相关的 FAQ 编号
3. 结果输出:根据编号过滤出最终 FAQ,未命中则记录到 faq_no_answer 表
Rerank Prompt(单条/多条两种模式)
from jinja2 import Template
### 单条模式:只返回最贴近的一条
rerank_system_single = r"""
你是一个专业助手,正在筛选能解答用户问题的FAQ。
请使用json输出FAQ的id,不要输出其他信息,格式为{"faq_id": int},
如果有多条能解答用户提问的FAQ也只输出最贴近的一条的id,
如果没有能回答提问的FAQ,请输出{"faq_id": -1}。
筛选必须严格,禁止答非所问,禁止货不对板(例如用户问a产品,但是回答的是b产品)
"""
### 多条模式:返回所有能回答的
rerank_system_multi = r"""
你是一个专业助手,正在筛选能解答用户问题的FAQ。
请使用json输出FAQ的id,不要输出其他信息,格式为{"faq_id": list[int]},
如果有多条能解答用户提问的FAQ就要全部输出,
如果没有能回答提问的FAQ,请输出{"faq_id": []}。
筛选必须严格,禁止答非所问,禁止货不对板(例如用户问a产品,但是回答的是b产品)
"""
### 粗召结果模板
rerank_human_template = Template(r"""
粗召到的FAQ有:
{% for faq in faqs %}
faq_id: {{ faq.id }}:
问题: `{{ faq.question }}`
回答: `{{ faq.answer }}`
---
{% endfor %}
""")
完整匹配函数
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, BaseMessage
async def query_faq(question: str, history: list[BaseMessage], model, count: int = 1) -> list[FAQ_RES]:
"""
FAQ 匹配主函数
Args:
question: 用户问题
history: 对话历史(用于给 LLM 更多上下文)
model: 用于精排的 LLM
count: 最多返回几条(1=单条模式,>1=多条模式)
Returns:
匹配到的 FAQ 列表,空列表表示未命中
"""
# ── 1. 双路粗召 ──
emb = model_embedding.embed_query(question)
candidates = search_faq(emb, top_k=8) # 复用前面定义的 search_faq
logger.debug(f'粗召到 {len(candidates)} 条FAQ')
# ── 2. LLM 精排 ──
rendered = rerank_human_template.render(faqs=[i.model_dump() for i in candidates])
# 过滤对话历史:只保留 Human/AI 文本消息,排除 tool_calls
filtered_history = [
m for m in history
if isinstance(m, (HumanMessage, AIMessage))
and not (isinstance(m, AIMessage) and m.tool_calls)
]
system_prompt = rerank_system_single if count == 1 else rerank_system_multi
messages = [
SystemMessage(content=system_prompt),
*filtered_history,
HumanMessage(content=rendered),
]
result = await model.with_structured_output(method='json_mode').ainvoke(messages)
faq_id = result['faq_id']
# 统一为列表
faq_id_list = [faq_id] if isinstance(faq_id, int) else faq_id
# ── 3. 过滤输出 ──
matched = [f for f in candidates if f.id in faq_id_list][:count]
if not matched:
# 记录未命中,便于后续分析 FAQ 覆盖率
with Session(engine_faq, expire_on_commit=False) as session:
session.add(FAQ_NO_ANSWER(question=question))
session.commit()
return matched
从对话历史中提取问题(可选)
当用户消息不是直接的问题(如 "帮我看看这个"),可以先用小模型从对话历史中提取出真正的问题:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from operator import itemgetter
prompt_extract_question = ChatPromptTemplate.from_messages([
("system", "用一句话总结用户现在的问题。不要思考太多, 思考控制在50字以内。请使用json输出,格式为{\"question\": str}"),
MessagesPlaceholder(variable_name='history'),
("human", "请提取出用户的问题,输出json"),
], template_format='jinja2')
chain_extract_question = (
prompt_extract_question
| model_small.with_structured_output(method='json_mode')
| itemgetter('question')
)
### 使用:question = await chain_extract_question.ainvoke({'history': messages})
方案二:PostgreSQL + pgvector
依赖
要点
- 使用异步引擎
postgresql+asyncpg:// - Vector 类型来自
pgvector.sqlalchemy - 支持 IVFFlat / HNSW 索引,适合大规模数据
- 其余模型定义和查询逻辑与 sqlite-vec 类似,主要区别是距离函数用法
如果用户选择 pgvector,需根据具体场景调整索引策略和连接池配置。
MCP 开发
需要使用最新的fastmcp 3 来写服务
有几个要点: 1. 使用中间件或者装饰器来实现认证和日志记录(下面的例子中只是logger输出出来,但实际项目里需要询问用户是使用什么方式来记录) 2. 复用核心逻辑函数,同时支持普通http请求和mcp请求
"""FastMCP + FastAPI 混合服务(JWT 认证版)
使用 Ed25519 JWT Token 进行用户身份验证。
环境变量:
JWT_ED25519_PUBLIC_KEY: Base64 编码的 Ed25519 公钥
使用方法:
export JWT_ED25519_PUBLIC_KEY="你的Base64公钥"
uv run main.py
"""
import os
import time
from fastapi import FastAPI, Depends, HTTPException, Request, Response
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastmcp import FastMCP
from fastmcp.server.auth.auth import TokenVerifier
from fastmcp.server.middleware import Middleware, MiddlewareContext
from fastmcp.server.dependencies import get_access_token
from fastmcp.dependencies import CurrentAccessToken
from fastmcp.server.auth import AccessToken
from loguru import logger
from jwt_verifier import (
JWTTokenVerifier,
TokenExpiredError,
TokenFormatError,
TokenInvalidError,
VerifiedUser,
)
### ============================================================
### 配置
### ============================================================
PUBLIC_KEY = os.environ.get("JWT_ED25519_PUBLIC_KEY", "")
if not PUBLIC_KEY:
raise RuntimeError("请设置 JWT_ED25519_PUBLIC_KEY 环境变量")
verifier = JWTTokenVerifier(PUBLIC_KEY)
### ============================================================
### 自定义 JWT 验证器(用于 FastMCP 认证)
### ============================================================
class JWTMCPVerifier(TokenVerifier):
"""
JWT 验证器,适配 FastMCP 的 TokenVerifier 协议。
接收 Bearer Token(JWT),验证签名后返回 AccessToken。
"""
def __init__(self, jwt_verifier: JWTTokenVerifier, required_scopes: list[str] | None = None):
super().__init__(required_scopes=required_scopes or ["tools"])
self.jwt_verifier = jwt_verifier
async def verify_token(self, token: str) -> AccessToken | None:
"""
验证 JWT Token,返回 AccessToken 对象。
Args:
token: JWT Token 字符串
Returns:
AccessToken 对象(验证成功)或 None(验证失败)
"""
try:
user = self.jwt_verifier.verify_token(token)
logger.info(f"✅ JWT 验证成功!用户: {user.name} ({user.email})")
return AccessToken(
token=token,
client_id=user.email or user.user_id,
scopes=self.required_scopes,
expires_at=None,
claims=user.claims,
)
except TokenExpiredError:
logger.warning("❌ JWT Token 已过期")
return None
except TokenInvalidError:
logger.warning("❌ JWT Token 签名无效")
return None
except TokenFormatError as e:
logger.warning(f"❌ JWT Token 格式错误: {e}")
return None
### ============================================================
### MCP 日志中间件(记录谁调用了什么工具、参数、耗时)
### ============================================================
class AuditLoggingMiddleware(Middleware):
"""
MCP 审计日志中间件。
拦截所有 tool 调用,记录:
- 调用者身份(client_id / email)
- 工具名称
- 调用参数
- 执行耗时
- 执行结果(成功/失败)
"""
async def on_call_tool(self, context: MiddlewareContext, call_next):
tool_name = context.message.name
arguments = context.message.arguments
# 尝试获取调用者身份
caller = "anonymous"
try:
token = get_access_token()
if token:
caller = token.client_id or "unknown"
except Exception:
pass
logger.debug(
"[MCP Audit] 🔧 tool={tool} | caller={caller} | args={args}",
tool=tool_name,
caller=caller,
args=arguments,
)
start = time.perf_counter()
try:
result = await call_next(context)
elapsed = (time.perf_counter() - start) * 1000
logger.debug(
"[MCP Audit] ✅ tool={tool} | caller={caller} | elapsed={elapsed:.1f}ms",
tool=tool_name,
caller=caller,
elapsed=elapsed,
)
return result
except Exception as exc:
elapsed = (time.perf_counter() - start) * 1000
logger.debug(
"[MCP Audit] ❌ tool={tool} | caller={caller} | elapsed={elapsed:.1f}ms | error={error}",
tool=tool_name,
caller=caller,
elapsed=elapsed,
error=str(exc),
)
raise
### ============================================================
### 核心业务逻辑(复用)
### ============================================================
def core_add_logic(a: int, b: int) -> int:
return a + b
### ============================================================
### 1. 初始化 FastMCP(带 JWT 认证)
### ============================================================
mcp_auth = JWTMCPVerifier(
jwt_verifier=verifier,
required_scopes=["tools"],
)
mcp = FastMCP("Demo 🚀", auth=mcp_auth)
### 注册 MCP 日志中间件
mcp.add_middleware(AuditLoggingMiddleware())
### ============================================================
### 2. 注册 MCP 工具(通过 CurrentAccessToken 获取用户信息)
### ============================================================
@mcp.tool()
async def add(
a: int,
b: int,
token: AccessToken = CurrentAccessToken(),
) -> str:
"""Add two numbers (Available for LLM)
两数相加,返回结果和调用者信息。
"""
user_email = token.client_id
result = core_add_logic(a, b)
return f"{a} + {b} = {result} (called by {user_email})"
@mcp.tool()
async def whoami(
token: AccessToken = CurrentAccessToken(),
) -> dict:
"""查看当前认证用户的信息"""
logger.info(f"token: {token}")
return {
"email": token.client_id,
"scopes": token.scopes,
"claims": token.claims,
}
### ============================================================
### 3. 初始化 FastAPI + 挂载 MCP
### ============================================================
mcp_app = mcp.http_app(path="/")
app = FastAPI(
title="Hybrid API & MCP Server (JWT Auth)",
lifespan=mcp_app.lifespan,
)
### 挂载 MCP 到 /mcp 路径
app.mount("/mcp", mcp_app)
### ============================================================
### FastAPI 日志中间件(记录 REST 接口调用)
### ============================================================
@app.middleware("http")
async def fastapi_audit_logging(request: Request, call_next):
"""
FastAPI 审计日志中间件。
记录:
- 请求方法 + 路径 + 查询参数
- 调用者身份(从 Authorization header 解析 JWT)
- 响应状态码
- 执行耗时
"""
caller = "anonymous"
auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "):
token_str = auth_header[7:]
try:
user = verifier.verify_token(token_str)
caller = user.email or user.user_id
except Exception:
caller = "invalid-token"
method = request.method
path = request.url.path
query = str(request.query_params) if request.query_params else ""
logger.debug(
"[REST Audit] → {method} {path} | caller={caller} | query={query}",
method=method,
path=path,
caller=caller,
query=query,
)
start = time.perf_counter()
response: Response = await call_next(request)
elapsed = (time.perf_counter() - start) * 1000
logger.debug(
"[REST Audit] ← {method} {path} | caller={caller} | status={status} | elapsed={elapsed:.1f}ms",
method=method,
path=path,
caller=caller,
status=response.status_code,
elapsed=elapsed,
)
return response
### ============================================================
### 4. FastAPI REST 端点认证(使用 JWT 验证)
### ============================================================
bearer_scheme = HTTPBearer(
scheme_name="JWT Bearer Token",
description="输入 JWT Token(不需要加 Bearer 前缀)",
)
async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
) -> VerifiedUser:
"""从 Authorization: Bearer <token> 提取并验证用户身份"""
token = credentials.credentials
try:
return verifier.verify_token(token)
except TokenExpiredError:
raise HTTPException(status_code=401, detail="Token expired")
except TokenInvalidError:
raise HTTPException(status_code=401, detail="Invalid token")
except TokenFormatError as e:
raise HTTPException(status_code=400, detail=f"Malformed token: {e}")
### ============================================================
### 5. 注册 FastAPI REST 接口(通过 Depends 获取用户信息)
### ============================================================
@app.get("/api/add")
async def add_api(a: int, b: int, user: VerifiedUser = Depends(get_current_user)):
"""Standard REST API (Available for Web/Mobile)"""
result = core_add_logic(a, b)
return {
"result": result,
"user": {
"user_id": user.user_id,
"email": user.email,
"name": user.name,
},
"source": "fastapi",
}
@app.get("/api/whoami")
async def whoami_api(user: VerifiedUser = Depends(get_current_user)):
"""查看当前认证用户的信息(REST 版)"""
return {
"user": {
"user_id": user.user_id,
"email": user.email,
"name": user.name,
"claims": user.claims,
},
"source": "fastapi",
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
### jwt_verifier.py
"""JWT Ed25519 Token 验证器
使用 Ed25519 公钥验证 JWT Token,提取用户信息。
依赖:
uv add PyJWT cryptography
环境变量:
JWT_ED25519_PUBLIC_KEY: Base64 编码的 Ed25519 公钥
"""
import base64
from dataclasses import dataclass, field
from typing import Any
import jwt
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
from cryptography.hazmat.primitives.serialization import load_pem_public_key
import os
### ============================================================
### 异常定义
### ============================================================
class TokenExpiredError(Exception):
"""Token 已过期"""
pass
class TokenInvalidError(Exception):
"""Token 签名无效或不可信"""
pass
class TokenFormatError(Exception):
"""Token 格式错误,无法解析"""
pass
### ============================================================
### 验证后的用户信息
### ============================================================
@dataclass
class VerifiedUser:
"""JWT 验证成功后返回的用户信息"""
user_id: str
email: str
name: str
claims: dict[str, Any] = field(default_factory=dict)
### ============================================================
### 辅助函数
### ============================================================
def _load_from_pem(pem_bytes: bytes) -> Ed25519PublicKey:
"""从 PEM 字节加载 Ed25519 公钥"""
key = load_pem_public_key(pem_bytes)
if not isinstance(key, Ed25519PublicKey):
raise ValueError(f"PEM 中包含的不是 Ed25519 公钥,而是 {type(key).__name__}")
return key
### ============================================================
### JWT Token 验证器
### ============================================================
class JWTTokenVerifier:
"""使用 Ed25519 公钥验证 JWT Token
支持三种公钥格式:
1. Base64(PEM) — PEM 字符串的 Base64 编码(常用于环境变量)
2. 原始 PEM — -----BEGIN PUBLIC KEY----- 开头的字符串
3. Base64(raw) — 32 字节原始公钥的 Base64 编码
Usage:
verifier = JWTTokenVerifier("Base64EncodedPublicKey...")
user = verifier.verify_token("eyJhbGciOi...")
"""
def __init__(self, public_key_b64: str):
self._public_key = self._load_public_key(public_key_b64.strip())
@staticmethod
def _load_public_key(key_input: str) -> Ed25519PublicKey:
"""智能加载 Ed25519 公钥,自动检测格式"""
# 情况 1:直接就是 PEM 明文
if key_input.startswith("-----BEGIN"):
return _load_from_pem(key_input.encode())
# Base64 解码
try:
key_bytes = base64.b64decode(key_input)
except Exception as e:
raise ValueError(f"无法 Base64 解码公钥: {e}")
# 情况 2:解码后是 PEM 字符串(Base64 包裹的 PEM)
try:
pem_text = key_bytes.decode("utf-8")
if pem_text.strip().startswith("-----BEGIN"):
return _load_from_pem(pem_text.encode())
except UnicodeDecodeError:
pass # 不是文本,继续尝试原始字节
# 情况 3:解码后正好是 32 字节原始公钥
if len(key_bytes) == 32:
try:
return Ed25519PublicKey.from_public_bytes(key_bytes)
except Exception as e:
raise ValueError(f"无法从原始字节加载公钥: {e}")
raise ValueError(
f"无法识别公钥格式(解码后 {len(key_bytes)} 字节)。"
"支持: PEM 明文 / Base64(PEM) / Base64(32字节原始公钥)"
)
def verify_token(self, token: str) -> VerifiedUser:
"""验证 JWT Token 并返回用户信息
Args:
token: JWT Token 字符串
Returns:
VerifiedUser 对象
Raises:
TokenExpiredError: Token 已过期
TokenInvalidError: Token 签名无效
TokenFormatError: Token 格式错误
"""
try:
payload = jwt.decode(
token,
self._public_key,
algorithms=["EdDSA"],
)
except jwt.ExpiredSignatureError:
if not os.getenv("DEBUG"):
raise TokenExpiredError("Token 已过期")
# DEBUG 模式下允许过期 Token,重新解码但跳过过期验证
payload = jwt.decode(
token,
self._public_key,
algorithms=["EdDSA"],
options={"verify_exp": False},
)
except jwt.InvalidTokenError as e:
raise TokenInvalidError(f"Token 无效: {e}")
except Exception as e:
raise TokenFormatError(f"Token 格式错误: {e}")
return VerifiedUser(
user_id=payload.get("sub", ""),
email=payload.get("email", ""),
name=payload.get("name", ""),
claims=payload,
)
优雅终止 + 流式中断 + 乐观锁(连发消息)
优雅终止
在生产环境中,服务需要在收到终止信号时优雅地完成正在处理的请求,而不是直接中断。
GracefulRunner 模式
核心思路:捕获 SIGINT/SIGTERM → 设置标志位 → 转发信号给 uvicorn → 业务代码轮询标志位来中断长任务。
from loguru import logger
import signal
import os
import time
class GracefulRunner:
def __init__(self):
self._keep_running = True
self.original_sigint_handler = None
self.original_sigterm_handler = None
def __bool__(self):
"""支持 `if not runner:` 这种优雅写法"""
return self._keep_running
def setup_signal_handlers(self):
self.original_sigint_handler = signal.signal(signal.SIGINT, self._handle_signal)
self.original_sigterm_handler = signal.signal(signal.SIGTERM, self._handle_signal)
def _handle_signal(self, signum, frame):
if not self._keep_running:
return
logger.info(f"收到信号 {signum},标记停止运行...")
self._keep_running = False
import random
time.sleep(random.random()) # 避免多worker同时抢占文件描述符
# 转发信号给 uvicorn,让它正常走关闭流程
if signum == signal.SIGINT and self.original_sigint_handler:
if callable(self.original_sigint_handler):
self.original_sigint_handler(signum, frame)
else:
signal.signal(signal.SIGINT, signal.SIG_DFL)
os.kill(os.getpid(), signal.SIGINT)
elif signum == signal.SIGTERM and self.original_sigterm_handler:
if callable(self.original_sigterm_handler):
self.original_sigterm_handler(signum, frame)
else:
signal.signal(signal.SIGTERM, signal.SIG_DFL)
os.kill(os.getpid(), signal.SIGTERM)
runner = GracefulRunner()
在 lifespan 中注册
@asynccontextmanager
async def lifespan(app: FastAPI):
runner.setup_signal_handlers()
yield
logger.success('服务已优雅关闭')
在流式响应/长任务中检查
关键点:任何长时间运行的循环都应该检查 runner 状态。
from .utils.runner import runner
### 简单用法:在循环中检查
async for chunk in stream:
if not runner:
logger.warning("收到终止信号,正在优雅退出...")
break
yield chunk
StreamController:流式响应的中断控制
对于流式 API,通常需要检查多种中断条件(不仅是后端终止,还有用户发新消息等)。可以封装成控制器:
from dataclasses import dataclass, field
from enum import Enum, auto
class StopReason(Enum):
BACKEND_SHUTDOWN = auto() # 后端终止信号
NEW_MESSAGE = auto() # 用户发送新消息
@dataclass
class StopSignal:
reason: StopReason
data: str # 要返回给客户端的数据
@dataclass
class StreamController:
session: AsyncSession
session_id: str
lock: Any
check_interval: float = 5.0
_last_check_time: float = field(default=0.0, init=False)
async def check_should_stop(self) -> StopSignal | None:
# 1. 后端终止 — 立即检查
if not runner:
return StopSignal(reason=StopReason.BACKEND_SHUTDOWN, data=...)
# 2. 其他条件 — 按间隔检查(避免频繁查库)
now = time.time()
if now - self._last_check_time < self.check_interval:
return None
self._last_check_time = now
# 例:检查是否有新消息覆盖了当前回复
await self.session.refresh(self._ticket)
if self._ticket.lock != self.lock:
return StopSignal(reason=StopReason.NEW_MESSAGE, data=...)
return None
使用方式:
controller = StreamController(session, session_id, lock)
async for chunk in stream:
if stop := await controller.check_should_stop():
yield stop.data
await stream.aclose()
return
yield process(chunk)
乐观锁:允许用户连发消息
在对话场景中,用户可以连续发送多条消息。新消息到达时,正在生成的旧回复应该被中断,让 AI 基于完整上下文重新回复。通过乐观锁实现这个机制。
核心设计
在 Session 表上放一个 lock 字段(UUID),每次收到新消息时更新它。正在运行的流式任务定期检查这个值是否变了,变了就说明有新消息进来,应该中断当前生成。
class SessionTable(SQLModel, table=True):
session_id: str = Field(primary_key=True)
lock: UUID | None = Field(default=None) # 乐观锁:每条新消息刷新
# ...
新消息到达时:写入消息 + 刷新锁
关键:消息写入和锁更新必须在同一次 commit 中,保证原子性。
@router.post("/send")
async def send_message(request: ReceivedMessage, session: ...):
lock = uuid4() # 生成新锁
# 获取 session,更新锁
ticket = await session.get(SessionTable, session_id)
ticket.lock = lock
session.add(ticket)
# 同时写入新消息
db_message = MessageTable.model_validate(request)
session.add(db_message)
# 一次 commit,保证锁和消息同时生效
await session.commit()
# 后续用这个 lock 创建 StreamController
controller = StreamController(session, session_id, lock)
流式生成中:检查锁是否被覆盖
StreamController 定期从数据库 refresh session 记录,对比 lock 值:
### 在 StreamController.check_should_stop() 中:
await self.session.refresh(self._ticket)
if self._ticket.lock != self.lock:
# lock 变了 → 有新消息进来了 → 中断当前生成
return StopSignal(reason=StopReason.NEW_MESSAGE, data=...)
连发消息时的消息合并
用户连发多条消息后,最后一个请求拿到锁,它需要把之前所有未回复的消息一起作为输入:
### 查询该 session 下所有有效消息
stmt = (
select(MessageTable)
.where(
MessageTable.session_id == session_id,
MessageTable.send_status.in_([
MessageStatus.SUCCESS,
MessageStatus.PROCESSING,
MessageStatus.INITIALIZING
]),
)
.order_by(MessageTable.created_at)
)
db_messages = (await session.exec(stmt)).all()
### 从最后一个 checkpoint 开始,收集所有新消息作为输入
new_messages = get_new_messages(db_messages)
时序保护
极端情况下,新消息可能比老消息先写入数据库。检测到这种情况时不报错,而是容忍——因为锁机制保证了只有最新的请求会继续执行:
if db_messages[-1].message_id != db_message.message_id:
logger.warning('新消息比老消息更先写入数据库')
# 不影响正确性,后排部队转作先头部队
整体流程图
定时任务(FastScheduler)
定时任务
推荐方案:FastScheduler
使用 fastscheduler,自带 FastAPI 控制面板,支持 cron 表达式、重试、持久化等。
安装
与 FastAPI 集成
from fastapi import FastAPI
from fastscheduler import FastScheduler
from fastscheduler.fastapi_integration import create_scheduler_routes
app = FastAPI()
scheduler = FastScheduler(quiet=True)
### 挂载控制面板到 /scheduler/
app.include_router(create_scheduler_routes(scheduler))
@scheduler.every(30).seconds
async def background_task():
print("Background work")
@scheduler.daily.at("02:00")
async def nightly_cleanup():
await clean_old_sessions(days=10)
scheduler.start()
访问 http://localhost:8000/scheduler/ 查看控制面板(实时状态、暂停/恢复、执行历史、失败队列)。
常用调度方式
### 间隔调度
@scheduler.every(10).seconds
@scheduler.every(5).minutes
@scheduler.every(2).hours
### 定时调度
@scheduler.daily.at("09:00")
@scheduler.hourly.at(":30")
@scheduler.weekly.monday.at("10:00")
@scheduler.weekly.weekdays.at("09:00")
### Cron 表达式
@scheduler.cron("0 9 * * MON-FRI") # 工作日9点
@scheduler.cron("*/15 * * * *") # 每15分钟
### 一次性任务
@scheduler.once(60) # 60秒后执行一次
@scheduler.at("2026-12-25 00:00:00") # 指定时间执行
### 时区支持
@scheduler.daily.at("09:00", tz="Asia/Shanghai")
任务控制
### 超时 & 重试(指数退避)
@scheduler.every(5).minutes.timeout(30).retries(3)
def flaky_api_call():
...
### 暂停 / 恢复 / 取消
scheduler.pause_job("job_0")
scheduler.resume_job("job_0")
scheduler.cancel_job("job_0")
生产配置
scheduler = FastScheduler(
storage="sqlmodel", # 持久化到数据库(默认 json 文件)
database_url="sqlite:///scheduler.db", # 或 postgresql://...
max_workers=20, # 并发任务数
max_history=5000, # 最大历史记录数
history_retention_days=8, # 历史保留天数
max_dead_letters=500, # 最大失败记录数
)
控制面板 API
面板自带 REST API,可用于外部集成:
| 端点 | 方法 | 说明 |
|---|---|---|
/scheduler/api/jobs |
GET | 所有任务列表 |
/scheduler/api/jobs/{id}/pause |
POST | 暂停任务 |
/scheduler/api/jobs/{id}/resume |
POST | 恢复任务 |
/scheduler/api/jobs/{id}/run |
POST | 立即执行 |
/scheduler/api/history |
GET | 执行历史 |
/scheduler/api/dead-letters |
GET | 失败队列 |
配置管理(内存缓存 + 热更新)
配置管理
核心思路:配置表从数据库加载到内存(pandas DataFrame),运行时通过文件锁检测变更并热更新,避免每次请求都查库。
ConfigTable 模式
架构
核心实现
import pandas as pd
import json
import time
from pathlib import Path
from functools import wraps
from sqlmodel import create_engine
engine_config = create_engine('sqlite:///volume/dbs/config.db')
class ConfigTable:
def __init__(self):
self.path = Path('volume/dbs/update.lock')
if not self.path.exists():
self.path.touch()
self.update_time = self.path.stat().st_mtime
self.master_table = self._load_master_table()
# 启动时一次性加载所有配置表到内存
self.tables = {
tn: pd.read_sql(tn, engine_config, index_col=index_col)
if index_col else pd.read_sql(tn, engine_config)
for tn, index_col in self.master_table['index_col'].items()
}
@property
def config(self):
"""通过 tables.config['表名', '列名'] 访问"""
return self.master_table.loc
@staticmethod
def _load_master_table():
mt = pd.read_sql_table('table_master', engine_config, index_col='table_name')
mt['index_col'] = mt['index_col'].apply(json.loads)
mt['bool_col'] = mt['bool_col'].apply(json.loads)
return mt
@staticmethod
def ensure_new(func):
"""装饰器:访问前检查文件锁时间戳,有变化才重新加载"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.path.stat().st_mtime != self.update_time:
new_master = self._load_master_table()
# 只重新加载发生变化的表(对比 update_time 列)
old_time = self.master_table['update_time']
new_time = new_master['update_time']
changed = new_time[new_time != old_time.reindex(new_time.index)].index
for tn in changed:
self.tables[tn] = pd.read_sql(
tn, engine_config,
index_col=new_master.loc[tn, 'index_col']
)
self.master_table = new_master
self.update_time = self.path.stat().st_mtime
return func(self, *args, **kwargs)
return wrapper
@ensure_new
def __getitem__(self, table_name: str) -> pd.DataFrame:
if table_name not in self.tables:
raise KeyError(f'Table {table_name} not found')
return self.tables[table_name]
tables = ConfigTable()
更新机制
通过 API 端点触发更新,从外部数据源拉取最新数据:
async def update_table(self, table_name: str):
# 1. 从外部源拉取数据
df = pd.read_csv(f'{source_url}/csv/{table_name}', ...)
# 2. 写入 SQLite(持久化)
df.to_sql(table_name, engine_config, index=True, if_exists='replace')
# 3. 更新 master_table 的时间戳
self.master_table.loc[table_name, 'update_time'] = int(time.time() * 1000)
self.master_table.to_sql('master_table', engine_config, ...)
# 4. 更新内存中的表
self.tables[table_name] = df
# 5. 触摸文件锁,通知其他进程
self.path.touch()
self.update_time = self.path.stat().st_mtime
safe_at 安全取值(带回退)
注册 pandas accessor,当索引不存在时按回退规则查找:
@register_dataframe_accessor("safe_at")
class SafeLocAccessor:
# 回退规则:索引找不到时,尝试这些默认值
fallback_rules = {
'game_cd': 0, # 找不到特定游戏配置 → 用默认配置
'lang': 'en', # 找不到特定语言 → 用英语
}
def __getitem__(self, key: tuple):
k0, k1 = key # (索引值, 列名)
if k0 in self._df.index:
return self._df.at[key]
# 多级索引:用笛卡尔积遍历所有回退组合
...
使用方式:
### 直接取值(可能 KeyError)
value = tables['prompt_config'].at['game_a', 'system_prompt']
### 安全取值(自动回退)
value = tables['prompt_config'].safe_at['game_a', 'system_prompt']
### 如果 game_a 不存在,自动尝试 game_cd=0 的默认配置
简单场景:启动时加载到全局字典
对于不需要热更新的简单配置,直接在 lifespan 中加载:
### 全局缓存
game_expire_dict = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
# 启动时从API加载
response = await client.get(f'{api_url}/config')
game_expire_dict.update(response.json())
yield
何时用哪种
| 场景 | 方案 |
|---|---|
| 简单、很少变的配置 | 全局字典,启动时加载 |
| 需要运行时更新、多表管理 | ConfigTable 模式 |
| 需要多进程同步 | ConfigTable + 文件锁 |
对话前端(Chainlit)
对话前端
方案选择
根据需求选择前端方案:
| 方案 | 适合场景 | 不适合 |
|---|---|---|
| 自定义 HTML + SSE | 需要定制 UI、混合业务组件、产品级交互 | — |
| Chainlit | 纯对话的快速原型,不需要自定义布局 | 需要改布局、加自定义组件 |
默认选自定义 HTML。 有 AI 辅助写前端的情况下,HTML + SSE 的开发速度不比 Chainlit 慢,但灵活性远超。 只有用户明确要求用 Chainlit 或只需要最简单的对话 demo 时才用 Chainlit。
自定义 HTML + SSE(推荐)
后端:FastAPI SSE 端点
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
app = FastAPI()
### 挂载静态文件(HTML/CSS/JS)
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/", response_class=HTMLResponse)
async def index():
with open("static/index.html") as f:
return f.read()
@app.post("/chat")
async def chat(request: Request):
body = await request.json()
message = body["message"]
async def generate():
# 替换成你的 LLM 调用(LangGraph / OpenAI / Anthropic 等)
async for chunk in your_llm_stream(message):
yield f"data: {chunk}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(generate(), media_type="text/event-stream")
前端:最小 SSE 聊天页面
让 AI 直接写 HTML + CSS + JS 即可,核心就是一个 fetch + EventSource / getReader():
// 流式读取 SSE 响应
const response = await fetch("/chat", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ message: userInput })
});
const reader = response.body.getReader();
const decoder = new TextDecoder();
while (true) {
const { done, value } = await reader.read();
if (done) break;
const text = decoder.decode(value);
// 解析 SSE data: 行,追加到聊天气泡里
appendToChat(text);
}
自定义 HTML 方案的优势
- 布局完全自由(侧边栏、多面板、数据看板、嵌入图表等)
- AI 写 HTML/CSS/JS 极其成熟,几乎不出错
- 可以直接用 Tailwind CDN、Alpine.js 等轻量工具
- WebSocket / SSE 任选,通信方式不受框架限制
Chainlit(仅快速原型)
仅当用户明确要求或只需最简对话 demo 时使用。
与 FastAPI 集成
from fastapi import FastAPI
from chainlit.utils import mount_chainlit
app = FastAPI()
@app.get("/api/health")
def health():
return {"status": "ok"}
mount_chainlit(app=app, target="my_cl_app.py", path="/chainlit")
Chainlit 应用文件
import chainlit as cl
@cl.on_chat_start
async def on_chat_start():
graph = ... # 编译好的 LangGraph
cl.user_session.set("graph", graph)
@cl.on_message
async def on_message(message: cl.Message):
graph = cl.user_session.get("graph")
msg = cl.Message(content="")
async for chunk in graph.astream(
{"messages": [{"role": "user", "content": message.content}]},
stream_mode="messages"
):
await msg.stream_token(chunk.content)
await msg.send()
Chainlit 注意事项
- 有自己的配置文件
.chainlit/config.toml - 开发时用
chainlit run app.py -w启用热重载 - 高度自定义 UI 布局时会很痛苦,这时应该切换到自定义 HTML 方案
运行时补丁(dowhen)
运行时补丁 (dowhen)
使用 dowhen 库在运行时对第三方库的方法进行精确补丁,无需 fork 或等待上游修复。
适用场景
- 第三方库的某个方法有 bug,但不想 fork 整个库
- 需要在特定条件下改变第三方库的行为
- 临时兼容性修复(等上游合并 PR 期间的过渡方案)
基本语法
from dowhen import when
### 在 SomeClass.some_method 中,找到 'target_line' 这行代码
### 当 condition 为真时,用 .do() 中的代码替换/插入
(
when(
SomeClass.some_method,
'target_line_of_code', # 定位到方法中的哪一行
condition='some_condition' # 可选:仅在条件成立时执行
)
.do("replacement_code")
)
实际案例
案例1:修改 Langfuse 回调的日志级别
当 LangGraph 流被主动中断(GeneratorExit)或前置检查异常时,Langfuse 会记录 ERROR 级别。用 dowhen 改为不记录:
from langfuse.langchain.CallbackHandler import LangchainCallbackHandler
from dowhen import when
(
when(
LangchainCallbackHandler.on_chain_error,
'observation = self._detach_observation(run_id)',
condition=(
'isinstance(error, GeneratorExit)'
'or "PreconsicousError" in str(type(error))'
'or "CancelledError" in str(type(error))'
)
)
.do("level = None") # 将 level 设为 None,跳过错误记录
)
案例2:修改工具错误的记录方式
将特定业务异常的工具错误从 ERROR 降级为 WARNING:
do2 = '''
if observation is not None:
observation.update(
status_message=str(error),
level="WARNING",
input=kwargs.get("inputs"),
cost_details={"total": 0},
).end()
observation = None
'''
(
when(
LangchainCallbackHandler.on_tool_error,
'if observation is not None:',
condition='(error.__class__.__name__ in ["ChangeToolError"])'
)
.do(do2)
)
案例3:兼容 DeepSeek reasoner 的多轮对话
DeepSeek reasoner 模型要求 assistant 消息必须包含 reasoning_content 字段,但 langchain-deepseek 没有正确传递:
from langchain_deepseek import ChatDeepSeek
do4 = '''
messages = self._convert_input(input_).to_messages()
reasoning_content_map = {}
for i, msg in enumerate(messages):
if hasattr(msg, "additional_kwargs"):
rc = msg.additional_kwargs.get("reasoning_content")
if rc is not None:
reasoning_content_map[i] = rc
for i, message in enumerate(payload["messages"]):
if message["role"] == "assistant":
if i in reasoning_content_map:
message["reasoning_content"] = reasoning_content_map[i]
elif "deepseek-reasoner" in self.model_name and "reasoning_content" not in message:
message["reasoning_content"] = " "
'''
(
when(
ChatDeepSeek._get_request_payload,
'for message in payload["messages"]:',
)
.do(do4)
)
使用建议
- 补丁代码集中放在一个文件中(如
utils/dowhen_man.py),在应用启动时 import - 补丁中添加注释说明:为什么需要这个补丁、对应的上游 issue
- 当上游修复后,及时移除补丁
condition参数可以精确控制触发条件,避免影响正常流程
日志管理
FastAPI 日志管理
基本配置
使用 loguru 作为日志库,输出到控制台,不写文件:
错误记录
对于异常,使用 logger.exception(e) 记录完整堆栈:
AsyncIO 调试模式
排查异步阻塞问题时,启用 AsyncIO debug 模式。当某个任务耗时超过 100ms 时会打印警告:
输出示例:
这能帮你快速定位哪个 async def 路由里写了阻塞调用。
请求日志中间件
如需记录每个请求的耗时,用纯 ASGI 中间件(不用 BaseHTTPMiddleware,性能更好):
import time
from loguru import logger
class RequestLogMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
if scope["type"] != "http":
await self.app(scope, receive, send)
return
start = time.perf_counter()
await self.app(scope, receive, send)
elapsed = time.perf_counter() - start
logger.info(f"{scope['method']} {scope['path']} {elapsed:.3f}s")
app.add_middleware(RequestLogMiddleware)
测试
FastAPI 测试
核心原则
从项目一开始就用异步测试客户端,不要等到后期再改——会遇到事件循环错误。
异步测试客户端
用 HTTPX 的 AsyncClient 替代 Starlette 的 TestClient:
import pytest
from httpx import AsyncClient, ASGITransport
from src.main import app
@pytest.fixture
async def client():
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
如果用了 lifespan 事件(startup/shutdown),需要 asgi-lifespan 包:
from asgi_lifespan import LifespanManager
from httpx import AsyncClient, ASGITransport
@pytest.fixture
async def client():
async with LifespanManager(app) as manager:
transport = ASGITransport(app=manager.app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
测试标记
用 pytest.mark.anyio 替代 pytest.mark.asyncio(anyio 已作为 Starlette 的依赖自动安装):
import pytest
@pytest.mark.anyio
async def test_create_post(client: AsyncClient):
resp = await client.post("/posts", json={"title": "Test", "content": "Hello"})
assert resp.status_code == 201
限定只跑 asyncio 后端(否则 anyio 默认还会跑一遍 trio):
跨模块测试
用 fixture 封装前置条件(如认证),测试本身只关注「发请求 → 验响应」:
@pytest.fixture
async def auth_headers(client: AsyncClient):
await client.post("/auth/register", json={...})
resp = await client.post("/auth/login", json={...})
token = resp.json()["access_token"]
return {"Authorization": f"Bearer {token}"}
@pytest.mark.anyio
async def test_create_post_requires_auth(client: AsyncClient):
resp = await client.post("/posts", json={"title": "No Auth"})
assert resp.status_code == 401
@pytest.mark.anyio
async def test_create_post_success(client: AsyncClient, auth_headers: dict):
resp = await client.post("/posts", json={"title": "OK"}, headers=auth_headers)
assert resp.status_code == 201
测试目录结构
与 src 目录一一镜像:
tests/
├── conftest.py # 共享 fixture(client, auth_headers)
├── auth/
│ └── test_login.py
├── posts/
│ └── test_create.py
└── users/
└── test_profile.py
关键原则
- 测接口行为,不测内部实现。跨模块调用对测试来说是透明的
- 尽量少 mock,走真实链路(集成测试)。只有调外部第三方 API 时才 mock
- 用 fixture 封装前置条件,测试用例本身只关注「发请求 → 验响应」
部署
FastAPI 部署
性能依赖
安装 uvloop 和 httptools 以获得更好的性能(Uvicorn 检测到会自动启用):
uvloop 不支持 Windows。可用环境标记处理:
uvloop; sys_platform != 'win32'
生产启动
或使用 gunicorn + uvicorn worker:
API 文档控制
非公开 API 在生产环境隐藏文档:
from pydantic_settings import BaseSettings
class Config(BaseSettings):
ENVIRONMENT: str = "production"
settings = Config()
SHOW_DOCS_ENVIRONMENTS = ("local", "staging")
app_configs = {"title": "My API"}
if settings.ENVIRONMENT not in SHOW_DOCS_ENVIRONMENTS:
app_configs["openapi_url"] = None
app = FastAPI(**app_configs)
线程池调整
默认线程池只有 40 个线程。高并发场景可在 lifespan 里调大:
import anyio
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
limiter = anyio.to_thread.current_default_thread_limiter()
limiter.total_tokens = 100
yield
app = FastAPI(lifespan=lifespan)
中间件性能
避免使用 BaseHTTPMiddleware(@app.middleware("http") 也是它的包装),有性能损耗。
需要中间件时优先实现纯 ASGI 中间件:
监控
FastAPI 监控
AsyncIO 调试模式
快速定位阻塞事件循环的端点:
任何耗时超过 100ms 的任务会打印警告,帮你找到 async def 里的阻塞调用。
线程池监控
实时监控线程池使用情况,排查线程耗尽问题:
import anyio
from anyio.to_thread import current_default_thread_limiter
async def monitor_thread_limiter():
limiter = current_default_thread_limiter()
threads_in_use = limiter.borrowed_tokens
while True:
if threads_in_use != limiter.borrowed_tokens:
print(f"Threads in use: {limiter.borrowed_tokens}")
threads_in_use = limiter.borrowed_tokens
await anyio.sleep(0)
可在 lifespan 里启动:
from contextlib import asynccontextmanager
import anyio
@asynccontextmanager
async def lifespan(app: FastAPI):
async with anyio.create_task_group() as tg:
tg.start_soon(monitor_thread_limiter)
yield
请求耗时追踪
用纯 ASGI 中间件记录每个请求的处理时间:
import time
from loguru import logger
class TimingMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
if scope["type"] != "http":
await self.app(scope, receive, send)
return
start = time.perf_counter()
await self.app(scope, receive, send)
elapsed = time.perf_counter() - start
if elapsed > 1.0:
logger.warning(f"Slow request: {scope['method']} {scope['path']} {elapsed:.3f}s")
else:
logger.info(f"{scope['method']} {scope['path']} {elapsed:.3f}s")
Sentry 集成
生产环境接入 Sentry 捕获异常:
事后检查(异步 / 依赖注入 / 安全 / 数据库)
FastAPI 事后检查清单
代码写完后逐条过一遍。这些是 AI 容易忽略的常见问题。
异步陷阱
- [ ]
async def路由里有没有阻塞调用?(time.sleep、同步 HTTP 请求、同步文件读写 → 会卡死事件循环) - [ ] 如果不确定是否阻塞,是否改用了普通
def?(FastAPI 自动放线程池) - [ ] 必须在 async 路由里调同步库时,是否用了
await run_in_threadpool(sync_func, ...)? - [ ] CPU 密集型任务(计算、转码)是否交给了 ProcessPoolExecutor 或任务队列,而不是线程池?
- [ ] 线程池默认只有 40 个线程,高并发场景是否够用?
依赖注入
- [ ] 数据库级校验(资源是否存在、权限检查)是否放在了
Depends()里,而不是写在 router 函数体内? - [ ] 多个端点共用的校验逻辑是否抽成了依赖,而不是复制粘贴?
- [ ] 依赖函数是否优先用了
async def?(非 async 的依赖也会被扔进线程池) - [ ] 链式依赖是否利用了 FastAPI 的缓存机制?(同一请求内相同依赖只执行一次)
安全与文档
- [ ] 非公开 API 是否隐藏了文档?(生产环境设置
openapi_url=None) - [ ] 有没有在 Pydantic 模型里暴露了不该返回的字段?(如 password_hash)
- [ ] 是否充分利用了 Pydantic 的校验能力?(Field 约束、正则、枚举,而不是在 service 里手写 if)
数据库
- [ ] SQLModel/SQLAlchemy 模型的命名是否统一?(小写蛇形、单数形式,如
post不是posts) - [ ] datetime 字段是否用
_at后缀,date 字段是否用_date后缀? - [ ] 数据库迁移文件(Alembic)是否有描述性名称?(如
2024-08-24_add_post_tags.py) - [ ] 复杂查询是否交给了 SQL 而不是在 Python 里循环处理?
- [ ] 嵌套对象的响应是否在 SQL 里用
json_build_object聚合,而不是 Python 里手动拼?
响应
- [ ] 所有接口是否都用了
Response[T]泛型包装? - [ ]
response_model是否正确设置?(注意:FastAPI 会用 response_model 再创建一次 Pydantic 对象做校验) - [ ] 是否设置了合适的
status_code?(POST 创建用 201,DELETE 用 204 等)
错误处理
- [ ] 是否定义了模块专属异常类?(
raise PostNotFound()比raise HTTPException(404, "xxx")更清晰) - [ ] Pydantic 里的
ValueError会自动变成 ValidationError 返回给用户,错误信息是否友好?