📢 1. 进度通知(Progress)
当Server执行耗时较长的任务时,进度通知机制可以让用户了解任务的执行状态,提升用户体验。
1.1 为什么需要进度通知?
- 大文件处理(上传、下载、分析)
- 复杂数据分析任务
- 批量操作(处理大量记录)
- 外部API调用(可能有延迟)
1.2 实现进度通知
MCP协议支持通过notifications/progress发送进度更新:
from mcp.server import Server
from mcp.types import ProgressNotification
import asyncio

app = Server("progress-demo")


@app.call_tool()
async def process_large_file(file_path: str, progress_token: str = None):
    """Process a large file, emitting a progress notification per step."""
    total_steps = 100
    for step in range(1, total_steps + 1):
        # Simulate one unit of real work.
        await asyncio.sleep(0.1)
        # Only notify when the client supplied a progress token.
        if progress_token:
            await app.request_context.session.send_notification(
                "notifications/progress",
                {
                    "progressToken": progress_token,
                    "progress": step,
                    "total": total_steps,
                    "message": f"正在处理第 {step}/{total_steps} 部分..."
                }
            )
    return {
        "content": [{
            "type": "text",
            "text": "✅ 文件处理完成!"
        }]
    }
💡 提示:progress_token由Client在调用时提供,Server需要检查是否存在再发送通知。
1.3 Client端接收进度
from mcp import ClientSession
import uuid


async def call_with_progress(session: ClientSession):
    """Call a tool with a progress token and print progress updates.

    Fix over the original: `uuid` was used but never imported, so the
    snippet raised NameError as written.
    """
    # A unique token lets the server correlate notifications with this call.
    progress_token = str(uuid.uuid4())

    # Callback invoked for each notifications/progress message.
    def on_progress(progress: int, total: int, message: str):
        percentage = (progress / total) * 100
        print(f"进度: {percentage:.1f}% - {message}")

    # Pass the token so the server knows where to send progress updates.
    result = await session.call_tool(
        "process_large_file",
        {"file_path": "/path/to/large/file.zip"},
        progress_token=progress_token
    )
    return result
🛑 2. 取消请求(Cancellation)
允许用户取消正在执行的长时间任务,避免资源浪费。
2.1 Server端支持取消
import asyncio
from mcp.server import Server
from mcp.types import CancelledNotification

app = Server("cancellable-demo")

# In-flight tasks, keyed by task id, each mapped to its cancel event.
running_tasks = {}


@app.call_tool()
async def long_running_task(task_id: str, duration: int):
    """Run for `duration` seconds, checking once per second for cancellation."""
    cancel_event = asyncio.Event()
    running_tasks[task_id] = cancel_event
    try:
        for elapsed in range(duration):
            if cancel_event.is_set():
                # Cooperative cancellation: report how far we got.
                return {
                    "content": [{
                        "type": "text",
                        "text": f"❌ 任务在 {elapsed} 秒时被用户取消"
                    }],
                    "isError": False
                }
            await asyncio.sleep(1)
        return {
            "content": [{
                "type": "text",
                "text": "✅ 任务成功完成!"
            }]
        }
    finally:
        # Always drop the registry entry, on success, error, or cancellation.
        running_tasks.pop(task_id, None)


@app.on_notification("cancelled")
async def handle_cancel(notification: CancelledNotification):
    """Set the matching task's cancel event when the client cancels a request."""
    # NOTE(review): this looks up the registry by the MCP request id, but tasks
    # are registered by task_id — confirm callers use the request id as task_id.
    request_id = notification.params["requestId"]
    if request_id in running_tasks:
        running_tasks[request_id].set()
        print(f"任务 {request_id} 已标记为取消")
2.2 使用asyncio.Task的取消机制
@app.call_tool()
async def cancellable_computation(data: list):
    """Process each item, yielding control so asyncio cancellation can land.

    Fix over the original: `asyncio.current_task().cancelled()` is never True
    from inside a still-running coroutine — it only becomes True after the task
    has finished cancelling — so that check was dead code. Cancellation is
    delivered as `asyncio.CancelledError` at await points; `sleep(0)` below is
    exactly such a point, so a cancelled task exits there automatically.
    """
    results = []
    for item in data:
        # The awaits below are the cooperative cancellation points.
        result = await expensive_operation(item)
        results.append(result)
        await asyncio.sleep(0)  # yield to the loop; CancelledError raises here
    return {"content": [{"type": "text", "text": str(results)}]}
⚠️ 警告:取消请求只是发送信号,Server需要主动检查并响应。确保在关键操作前检查取消状态。
🎲 3. 采样/补全(Sampling)
Sampling是MCP的一个强大特性,允许Server请求Client的LLM生成内容。这使得Server可以利用AI能力来增强功能。
3.1 Sampling使用场景
- 智能代码生成:根据上下文生成代码片段
- 内容优化:改进用户输入的文本质量
- 智能分析:分析数据并提供AI见解
- 对话增强:生成自然语言响应
3.2 Server请求Sampling
from mcp.server import Server
from mcp.types import SamplingMessage

app = Server("sampling-demo")


@app.call_tool()
async def generate_documentation(code: str):
    """Ask the client's LLM to write documentation comments for `code`."""
    # Conversation forwarded to the client-side model.
    prompt_messages = [
        SamplingMessage(
            role="system",
            content={
                "type": "text",
                "text": "你是一个专业的代码文档工程师。为给定的代码生成清晰、详细的文档注释。"
            }
        ),
        SamplingMessage(
            role="user",
            content={
                "type": "text",
                "text": f"请为以下代码生成文档注释:\n\n```python\n{code}\n```"
            }
        )
    ]
    # Delegate generation to the client via sampling/createMessage.
    response = await app.request_context.session.send_request(
        "sampling/createMessage",
        {
            "messages": prompt_messages,
            "modelPreferences": {
                "hints": ["claude-3-5-sonnet"],
                "intelligencePriority": 0.8,
                "speedPriority": 0.3
            },
            "maxTokens": 2000
        }
    )
    return {
        "content": [{
            "type": "text",
            "text": f"生成的文档:\n\n{response['content']['text']}"
        }]
    }
3.3 高级Sampling示例:智能数据分析
@app.call_tool()
async def analyze_sales_data(csv_data: str):
    """Run basic stats on CSV sales data, then ask the LLM for deeper insight."""
    # Local import keeps pandas optional for servers that never call this tool.
    import pandas as pd
    from io import StringIO

    frame = pd.read_csv(StringIO(csv_data))
    basic_stats = frame.describe().to_string()
    # Hand the numbers to the client-side model for qualitative analysis.
    prompt_messages = [
        SamplingMessage(
            role="system",
            content={
                "type": "text",
                "text": "你是数据分析师,擅长从销售数据中发现趋势和机会。"
            }
        ),
        SamplingMessage(
            role="user",
            content={
                "type": "text",
                "text": f"""请分析以下销售数据,提供:
1. 关键发现(3-5条)
2. 趋势分析
3. 改进建议
基础统计数据:
{basic_stats}
原始数据摘要:
- 总记录数: {len(frame)}
- 时间范围: {frame['date'].min()} 至 {frame['date'].max()}
- 总销售额: ${frame['amount'].sum():,.2f}
"""
            }
        )
    ]
    analysis = await app.request_context.session.send_request(
        "sampling/createMessage",
        {
            "messages": prompt_messages,
            "maxTokens": 3000
        }
    )
    return {
        "content": [{
            "type": "text",
            "text": analysis["content"]["text"]
        }],
        "metadata": {
            "data_points": len(frame),
            "total_revenue": float(frame['amount'].sum())
        }
    }
✅ 最佳实践:Sampling应该作为增强功能,而不是必需功能。Server应该在没有Sampling的情况下也能正常工作。
📁 4. Roots - 文件系统安全访问
Roots机制允许Client限制Server可以访问的文件系统范围,是重要的安全防护措施。
4.1 什么是Roots?
Roots是Client声明的一组目录路径,Server只能访问这些目录及其子目录内的文件:
{
"roots": [
{
"uri": "file:///home/user/projects",
"name": "项目目录"
},
{
"uri": "file:///home/user/documents",
"name": "文档目录"
}
]
}
4.2 Server端读取Roots
from mcp.server import Server
from pathlib import Path
import os

app = Server("roots-demo")


def get_roots() -> list[Path]:
    """Return the client-declared roots as filesystem paths.

    Fix over the original: `str.replace("file://", "")` would strip the scheme
    anywhere it appeared in the URI; `removeprefix` only strips a leading one.
    """
    roots = app.request_context.session.roots
    return [Path(root.uri.removeprefix("file://")) for root in roots]


def is_path_allowed(target_path: str) -> bool:
    """True if `target_path` resolves to somewhere inside an allowed root."""
    target = Path(target_path).resolve()
    for root in get_roots():
        try:
            # relative_to raises ValueError when target is outside this root.
            target.relative_to(root)
            return True
        except ValueError:
            continue
    return False
@app.call_tool()
async def safe_read_file(file_path: str):
    """Read a file, but only if it lies inside a client-declared root."""
    # Authorisation check before any filesystem access.
    if not is_path_allowed(file_path):
        return {
            "content": [{
                "type": "text",
                "text": f"❌ 访问被拒绝:{file_path} 不在允许的roots范围内"
            }],
            "isError": True
        }
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            file_text = handle.read()
    except Exception as exc:
        return {
            "content": [{
                "type": "text",
                "text": f"❌ 读取失败: {str(exc)}"
            }],
            "isError": True
        }
    return {
        "content": [{
            "type": "text",
            "text": file_text
        }]
    }
4.3 监听Roots变化
@app.on_notification("roots/list_changed")
async def on_roots_changed():
    """Re-fetch the roots list whenever the client announces a change."""
    new_roots = await app.request_context.session.send_request("roots/list", {})
    print(f"Roots已更新: {new_roots}")
    # Invalidate any cached paths / re-validate outstanding requests here.
🔒 5. 安全最佳实践
5.1 输入验证
from pydantic import BaseModel, validator, Field
import re


class FileOperationInput(BaseModel):
    """Validated input for file-write operations."""

    path: str = Field(..., description="文件路径")
    content: str = Field(default="", description="文件内容")

    @validator('path')
    def validate_path(cls, value):
        # Reject path-traversal and home-dir expansion attempts outright.
        if '..' in value or '~' in value:
            raise ValueError("路径包含非法字符")
        # Only a small whitelist of file types may be written.
        allowed_extensions = {'.txt', '.md', '.py', '.json'}
        ext = Path(value).suffix.lower()
        if ext not in allowed_extensions:
            raise ValueError(f"不支持的文件类型: {ext}")
        return value

    @validator('content')
    def validate_content(cls, value):
        # Cap payload at 1 MB, measured in UTF-8 bytes, not characters.
        max_size = 1024 * 1024  # 1MB
        if len(value.encode('utf-8')) > max_size:
            raise ValueError("文件内容超过最大限制(1MB)")
        # Best-effort screen for obviously dangerous code patterns.
        dangerous_patterns = [
            r'eval\s*\(',
            r'exec\s*\(',
            r'subprocess\.call',
            r'os\.system'
        ]
        if any(re.search(p, value, re.IGNORECASE) for p in dangerous_patterns):
            raise ValueError("内容包含潜在危险代码")
        return value
# Using the validation model inside a tool
@app.call_tool()
async def write_file(arguments: dict):
    """Write a file after validating the arguments with FileOperationInput."""
    try:
        validated = FileOperationInput(**arguments)
        # Continue processing...
    except ValueError as e:
        return {
            "content": [{"type": "text", "text": f"输入验证失败: {e}"}],
            "isError": True
        }
5.2 权限控制
from enum import Enum
from functools import wraps


class PermissionLevel(Enum):
    READ = "read"
    WRITE = "write"
    EXECUTE = "execute"
    ADMIN = "admin"


# Rank of each level in the hierarchy (higher = more privileged).
_PERMISSION_RANK = {
    PermissionLevel.READ: 0,
    PermissionLevel.WRITE: 1,
    PermissionLevel.EXECUTE: 2,
    PermissionLevel.ADMIN: 3,
}


def require_permission(level: PermissionLevel):
    """Decorator that rejects calls whose caller lacks `level` permission.

    Fix over the original: the level found in kwargs may arrive as a plain
    string (e.g. "write") from client metadata, and `list.index` on the enum
    list raised ValueError for it. Strings are now coerced to the enum, and
    unknown values fall back to READ (least privilege).
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(arguments: dict, **kwargs):
            # Permission granted by the client (session/metadata); default READ.
            user_level = kwargs.get('permissions', {}).get('level', PermissionLevel.READ)
            if not isinstance(user_level, PermissionLevel):
                try:
                    user_level = PermissionLevel(user_level)
                except ValueError:
                    user_level = PermissionLevel.READ  # least privilege
            if _PERMISSION_RANK[user_level] < _PERMISSION_RANK[level]:
                return {
                    "content": [{
                        "type": "text",
                        "text": f"❌ 权限不足,需要 {level.value} 权限"
                    }],
                    "isError": True
                }
            return await func(arguments, **kwargs)
        return wrapper
    return decorator
@app.call_tool()
@require_permission(PermissionLevel.WRITE)
async def delete_file(arguments: dict, **kwargs):
    """Delete a file — requires WRITE permission."""
    # Implementation...
5.3 沙箱执行
import subprocess
import sys
import tempfile
import os
from pathlib import Path


async def sandbox_execute(code: str, timeout: int = 30) -> dict:
    """Execute untrusted code in a throwaway directory with a restricted env.

    Fixes over the original:
    - The script path was interpolated into a `python -c "... exec(open('...'))"`
      wrapper, which breaks (or allows injection) if the path contains quotes;
      the script file is now executed directly.
    - `sys.executable` replaces a bare "python", which may not resolve once the
      child's PATH is restricted below.
    - `-I` (isolated mode) replaces the manual `sys.path = []` trick: it skips
      the user site directory and environment-variable path hooks.

    Returns a dict with stdout/stderr/returncode, or an "error" key on failure.
    NOTE: this is best-effort isolation only — use containers for real sandboxing.
    """
    # Temporary directory doubles as the sandbox working directory.
    with tempfile.TemporaryDirectory() as sandbox_dir:
        code_file = Path(sandbox_dir) / "script.py"
        code_file.write_text(code)
        try:
            result = subprocess.run(
                [sys.executable, "-I", str(code_file)],
                capture_output=True,
                text=True,
                timeout=timeout,
                cwd=sandbox_dir,
                # Minimal environment for the child process.
                env={"PYTHONPATH": "", "PATH": "/usr/bin"}
            )
            return {
                "stdout": result.stdout,
                "stderr": result.stderr,
                "returncode": result.returncode
            }
        except subprocess.TimeoutExpired:
            return {"error": "执行超时"}
        except Exception as e:
            return {"error": str(e)}
# Safer alternative: run inside a Docker container
async def docker_execute(code: str, image: str = "python:3.11-slim") -> dict:
    """Execute code inside a locked-down Docker container.

    Fix over the original: `docker run` has no `--timeout` flag (the command
    failed with "unknown flag"); the wall-clock limit is enforced by the
    subprocess `timeout` instead.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        code_file = Path(tmpdir) / "script.py"
        code_file.write_text(code)
        result = subprocess.run(
            [
                "docker", "run", "--rm",
                "-v", f"{tmpdir}:/workspace:ro",  # read-only mount
                "--network", "none",              # no network access
                "--memory", "128m",               # memory cap
                "--cpus", "0.5",                  # CPU cap
                image,
                "python", "/workspace/script.py"
            ],
            capture_output=True,
            text=True,
            timeout=35  # hard wall-clock limit for the whole run
        )
        return {
            "stdout": result.stdout,
            "stderr": result.stderr,
            "returncode": result.returncode
        }
5.4 敏感信息保护
import re
from typing import Any


class SensitiveDataFilter:
    """Scrubs sensitive values (keys, tokens, emails, card numbers) from text.

    Patterns are applied in declaration order; replacements may be literal
    strings or callables (for partial masking).
    """

    # Detection patterns.
    PATTERNS = {
        'api_key': r'[a-zA-Z0-9_-]{32,}',  # API keys (NOTE: broad — any 32+ char token matches)
        'password': r'password[\s]*[=:]+[\s]*[^\s]+',
        'secret': r'secret[\s]*[=:]+[\s]*[^\s]+',
        'token': r'[a-zA-Z0-9_-]{20,}\.[a-zA-Z0-9_-]{20,}',  # JWT
        'email': r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}',
        'credit_card': r'\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}',
    }

    # Replacement per pattern key.
    REPLACEMENTS = {
        'api_key': '',
        'password': '',
        'secret': '',
        'token': '',
        'email': lambda m: f"{m.group(0).split('@')[0][:2]}***@{m.group(0).split('@')[1]}",
        'credit_card': '',
    }

    @classmethod
    def filter_text(cls, text: str) -> str:
        """Apply every pattern's replacement to `text`.

        Simplification over the original: `re.sub` accepts callables directly,
        so the separate callable/string branches were identical and are merged.
        """
        for key, pattern in cls.PATTERNS.items():
            text = re.sub(pattern, cls.REPLACEMENTS[key], text)
        return text

    @classmethod
    def filter_dict(cls, data: dict) -> dict:
        """Recursively filter strings inside nested dicts and lists."""
        if isinstance(data, dict):
            return {
                k: cls.filter_dict(v) if isinstance(v, (dict, list))
                else cls.filter_text(str(v)) if isinstance(v, str) else v
                for k, v in data.items()
            }
        elif isinstance(data, list):
            return [cls.filter_dict(item) for item in data]
        return data
# Server-side usage
@app.call_tool()
async def process_data(data: dict):
    """Process `data`, scrubbing sensitive values from logs and output."""
    # Log a sanitised copy; the real payload never hits the logs.
    safe_data = SensitiveDataFilter.filter_dict(data)
    print(f"处理数据: {safe_data}")
    # The actual processing still sees the unfiltered input.
    result = await actual_processing(data)
    # Scrub the outgoing text as well.
    return {
        "content": [{
            "type": "text",
            "text": SensitiveDataFilter.filter_text(str(result))
        }]
    }
⚡ 6. 性能优化
6.1 连接复用
import asyncio
from contextlib import asynccontextmanager
from mcp.client import ClientSession
class MCPConnectionPool:
    """Fixed-size pool of reusable MCP client sessions."""

    def __init__(self, max_connections: int = 10):
        self.max_connections = max_connections
        self._pool = asyncio.Queue()   # idle sessions, ready to hand out
        self._used = set()             # ids of sessions currently checked out
        self._lock = asyncio.Lock()

    async def initialize(self, server_params):
        """Pre-create `max_connections` sessions and park them in the queue."""
        for _ in range(self.max_connections):
            await self._pool.put(await self._create_session(server_params))

    @asynccontextmanager
    async def acquire(self):
        """Check a session out of the pool; it is returned on context exit."""
        session = await self._pool.get()
        self._used.add(id(session))
        try:
            yield session
        finally:
            self._used.discard(id(session))
            await self._pool.put(session)

    async def _create_session(self, params):
        """Create one new session (placeholder for the real connect logic)."""
        pass
# Using the connection pool.
# Fix over the original: `await` is only legal inside a coroutine, so the
# top-level snippet is wrapped in an async entry point.
async def main():
    pool = MCPConnectionPool(max_connections=5)
    await pool.initialize(server_params)
    async with pool.acquire() as session:
        result = await session.call_tool("some_tool", {})
6.2 批量操作
from typing import List
import asyncio


@app.call_tool()
async def batch_process(items: List[str], batch_size: int = 10):
    """Process `items` in batches, running each batch concurrently."""
    results = []
    total = len(items)
    for offset in range(0, total, batch_size):
        chunk = items[offset:offset + batch_size]
        # Fan the batch out; exceptions come back as values, not raises.
        chunk_results = await asyncio.gather(
            *(process_single(item) for item in chunk),
            return_exceptions=True
        )
        results.extend(chunk_results)
        # Keep the client informed after every batch.
        await send_progress(offset + len(chunk), total)
    return {
        "content": [{
            "type": "text",
            "text": f"处理完成: {len(results)}/{len(items)}"
        }],
        "metadata": {"results": results}
    }
@app.call_tool()
async def parallel_queries(queries: List[str]):
    """Run all queries concurrently, capped at five in flight at once."""
    gate = asyncio.Semaphore(5)  # concurrency limit

    async def run_one(query: str):
        async with gate:
            return await execute_query(query)

    outcomes = await asyncio.gather(
        *(run_one(q) for q in queries),
        return_exceptions=True
    )
    # Partition into successes and stringified failures.
    successful = [r for r in outcomes if not isinstance(r, Exception)]
    failed = [str(r) for r in outcomes if isinstance(r, Exception)]
    return {
        "content": [{
            "type": "text",
            "text": f"成功: {len(successful)}, 失败: {len(failed)}"
        }]
    }
6.3 缓存策略
import functools
import time
from typing import Any, Optional
import hashlib


class CacheManager:
    """Tiny in-memory TTL cache with an async function decorator."""

    def __init__(self, default_ttl: int = 300):
        self._cache = {}          # key -> (value, absolute expiry timestamp)
        self._ttl = default_ttl   # fallback TTL in seconds

    def _make_key(self, *args, **kwargs) -> str:
        """Derive a stable key from positional and keyword arguments."""
        raw = f"{args}:{sorted(kwargs.items())}"
        return hashlib.md5(raw.encode()).hexdigest()

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value, or None if absent or expired."""
        entry = self._cache.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if time.time() < expires_at:
            return value
        del self._cache[key]  # lazily evict the expired entry
        return None

    def set(self, key: str, value: Any, ttl: Optional[int] = None):
        """Store `value` under `key` for `ttl` seconds (default TTL if None)."""
        self._cache[key] = (value, time.time() + (ttl or self._ttl))

    def cached(self, ttl: Optional[int] = None):
        """Decorator memoizing an async function's results for `ttl` seconds.

        NOTE: a result of None is indistinguishable from a cache miss and is
        therefore recomputed on every call.
        """
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                cache_key = f"{func.__name__}:{self._make_key(*args, **kwargs)}"
                hit = self.get(cache_key)
                if hit is not None:
                    return hit
                result = await func(*args, **kwargs)
                self.set(cache_key, result, ttl)
                return result
            return wrapper
        return decorator
# Using the cache
cache = CacheManager(default_ttl=60)


@app.call_tool()
@cache.cached(ttl=300)  # cache for 5 minutes
async def get_schema(database: str):
    """Fetch the database schema (result cached to avoid repeated queries)."""
    schema = await fetch_database_schema(database)
    return schema


@app.read_resource()
@cache.cached(ttl=60)
async def read_file_resource(uri: str):
    """Read a file resource (cached for one minute).

    Fixes over the original: the file is opened with an explicit UTF-8
    encoding (matching the file reads elsewhere in this module) instead of
    the platform default, and `removeprefix` strips only a *leading*
    "file://" scheme rather than any occurrence.
    """
    file_path = uri.removeprefix("file://")
    with open(file_path, 'r', encoding='utf-8') as f:
        return f.read()
🛠️ 7. 错误处理最佳实践
7.1 有意义的错误信息
from enum import Enum
from typing import Optional
import json


class ErrorCode(Enum):
    """Stable error codes, grouped by category."""
    # System errors 1xxx
    INTERNAL_ERROR = "1000"
    SERVICE_UNAVAILABLE = "1001"
    TIMEOUT = "1002"
    # Request errors 2xxx
    INVALID_PARAMS = "2000"
    MISSING_REQUIRED_PARAM = "2001"
    INVALID_FORMAT = "2002"
    # Business errors 3xxx
    RESOURCE_NOT_FOUND = "3000"
    PERMISSION_DENIED = "3001"
    RATE_LIMITED = "3002"
    # External errors 4xxx
    EXTERNAL_API_ERROR = "4000"
    DATABASE_ERROR = "4001"


class MCPError(Exception):
    """Structured MCP error carrying a code, details, and an optional hint."""

    def __init__(
        self,
        code: ErrorCode,
        message: str,
        details: Optional[dict] = None,
        suggestion: Optional[str] = None
    ):
        self.code = code
        self.message = message
        self.details = details or {}
        self.suggestion = suggestion
        super().__init__(self.message)

    def to_dict(self) -> dict:
        """Serialize to the error-response payload shape."""
        payload = {
            "code": self.code.value,
            "message": self.message,
            "details": self.details
        }
        # The suggestion key is only present when one was provided.
        if self.suggestion:
            payload["suggestion"] = self.suggestion
        return payload
# Tool-side usage
@app.call_tool()
async def fetch_user_data(user_id: str):
    """Look up a user by id, mapping each failure mode to a structured error."""
    try:
        # Parameter validation.
        if not user_id:
            raise MCPError(
                code=ErrorCode.MISSING_REQUIRED_PARAM,
                message="用户ID不能为空",
                suggestion="请提供有效的用户ID"
            )
        # Database lookup.
        # SECURITY FIX: the original interpolated user_id straight into the SQL
        # (f"... WHERE id = {user_id}") — a textbook SQL injection. Use a bound
        # parameter instead; adjust the placeholder style ("?", "%s", "$1") to
        # whatever driver `db` actually is.
        try:
            user = await db.query("SELECT * FROM users WHERE id = ?", (user_id,))
        except DatabaseConnectionError as e:
            raise MCPError(
                code=ErrorCode.DATABASE_ERROR,
                message="数据库连接失败",
                details={"original_error": str(e)},
                suggestion="请稍后重试或联系管理员"
            )
        if not user:
            raise MCPError(
                code=ErrorCode.RESOURCE_NOT_FOUND,
                message=f"未找到用户: {user_id}",
                suggestion="请检查用户ID是否正确"
            )
        return {"content": [{"type": "text", "text": json.dumps(user)}]}
    except MCPError as e:
        return {
            "content": [{
                "type": "text",
                "text": f"错误 [{e.code.value}]: {e.message}"
            }],
            "isError": True,
            "error": e.to_dict()
        }
    except Exception as e:
        # Anything unexpected becomes a generic internal error.
        return {
            "content": [{
                "type": "text",
                "text": f"系统错误: {str(e)}"
            }],
            "isError": True,
            "error": {
                "code": ErrorCode.INTERNAL_ERROR.value,
                "message": "发生内部错误"
            }
        }
7.2 重试机制
import asyncio
import random
from typing import Callable, TypeVar
from functools import wraps

T = TypeVar('T')


class RetryConfig:
    """Tunable knobs for the retry decorator (exponential backoff + jitter)."""

    def __init__(
        self,
        max_attempts: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 30.0,
        exponential_base: float = 2.0,
        retryable_exceptions: tuple = (Exception,)
    ):
        self.max_attempts = max_attempts
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.exponential_base = exponential_base
        self.retryable_exceptions = retryable_exceptions


def with_retry(config: RetryConfig = None):
    """Decorator: retry an async function per `config`, re-raising on exhaustion."""
    cfg = config if config is not None else RetryConfig()

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            last_exception = None
            for attempt in range(1, cfg.max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except cfg.retryable_exceptions as exc:
                    last_exception = exc
                    if attempt == cfg.max_attempts:
                        break
                    # Exponential backoff, capped at max_delay, plus 10% jitter.
                    delay = min(
                        cfg.base_delay * (cfg.exponential_base ** (attempt - 1)),
                        cfg.max_delay
                    )
                    total_delay = delay + random.uniform(0, delay * 0.1)
                    print(f"尝试 {attempt} 失败,{total_delay:.1f}秒后重试...")
                    await asyncio.sleep(total_delay)
            # Every attempt failed — surface the last error to the caller.
            raise last_exception
        return wrapper
    return decorator
# Retrying usage
@app.call_tool()
@with_retry(RetryConfig(
    max_attempts=3,
    base_delay=1.0,
    retryable_exceptions=(ConnectionError, TimeoutError)
))
async def call_external_api(endpoint: str):
    """Call an external HTTP API, retrying transient connection failures."""
    import aiohttp
    async with aiohttp.ClientSession() as http:
        async with http.get(endpoint, timeout=10) as response:
            response.raise_for_status()
            return await response.json()
@app.call_tool()
@with_retry(RetryConfig(
    max_attempts=5,
    base_delay=0.5,
    exponential_base=1.5
))
async def database_query(sql: str):
    """Run a database query with gentle, frequent retries."""
    # Implementation...
💡 提示:只对幂等操作使用重试,避免非幂等操作(如扣款、发送通知)重复执行。
📦 8. 版本管理
8.1 Server版本声明
from mcp.server import Server
from mcp.types import Implementation, ServerCapabilities

# Version metadata advertised to clients.
SERVER_VERSION = "1.2.3"
SERVER_NAME = "my-mcp-server"

app = Server(
    SERVER_NAME,
    capabilities=ServerCapabilities(
        tools={},
        resources={},
        prompts={}
    )
)


@app.on_initialize()
async def on_initialize(params):
    """Report protocol version, server identity, and capabilities."""
    return {
        "protocolVersion": "2024-11-05",
        "serverInfo": {
            "name": SERVER_NAME,
            "version": SERVER_VERSION
        },
        "capabilities": app.capabilities
    }
8.2 向后兼容策略
from packaging import version

# Supported client version window.
MIN_CLIENT_VERSION = "1.0.0"
MAX_CLIENT_VERSION = "2.0.0"

# Features scheduled for removal, with migration targets.
DEPRECATED_FEATURES = {
    "old_tool_name": {
        "deprecated_in": "1.1.0",
        "removed_in": "2.0.0",
        "replacement": "new_tool_name"
    }
}


@app.on_initialize()
async def check_compatibility(params):
    """Reject clients below the floor; warn (but allow) untested newer ones."""
    client_version = params.get("clientInfo", {}).get("version", "0.0.0")
    parsed = version.parse(client_version)
    if parsed < version.parse(MIN_CLIENT_VERSION):
        raise ValueError(
            f"客户端版本 {client_version} 过低,"
            f"最低要求: {MIN_CLIENT_VERSION}"
        )
    if parsed > version.parse(MAX_CLIENT_VERSION):
        # Newer than anything we've tested: warn, don't block.
        return {
            "warning": f"客户端版本 {client_version} 尚未测试,可能不兼容"
        }
    return {}
@app.call_tool()
async def deprecated_tool(arguments: dict):
    """Legacy tool kept for old clients; forwards to the new implementation."""
    # Surface a deprecation warning in the server logs.
    print("警告: 使用已弃用的工具,请迁移到 new_tool_name")
    # Translate legacy arguments, then delegate.
    new_arguments = convert_legacy_params(arguments)
    return await new_tool(new_arguments)
def convert_legacy_params(old_params: dict) -> dict:
    """Rename legacy keys to their modern equivalents; unknown keys pass through."""
    renames = {
        "old_field": "new_field",
        "legacy_param": "modern_param"
    }
    return {renames.get(key, key): value for key, value in old_params.items()}
8.3 API版本控制
# Multi-version API support
@app.call_tool()
async def api_v2_get_data(arguments: dict):
    """Current (V2) data API."""
    return await get_data_v2(arguments)


@app.call_tool()
async def api_v1_get_data(arguments: dict):
    """Legacy (V1) data API: adapts old clients onto the V2 implementation."""
    # V1 argument names -> V2 argument names.
    v2_args = {
        "id": arguments.get("item_id"),
        "include_meta": arguments.get("with_metadata", False)
    }
    result = await get_data_v2(v2_args)
    # V2 result shape -> V1 result shape.
    return {
        "item_id": result["id"],
        "item_data": result["data"]
    }
🧪 9. 测试策略
9.1 单元测试
# test_server.py
import pytest
from unittest.mock import Mock, AsyncMock, patch
from my_mcp_server import app


@pytest.fixture
def mock_session():
    """A stand-in MCP session with no declared roots."""
    session = AsyncMock()
    session.roots = []
    return session


@pytest.fixture
def mock_context(mock_session):
    """A stand-in request context wrapping the mock session."""
    context = Mock()
    context.session = mock_session
    return context


@pytest.mark.asyncio
async def test_calculator_add(mock_context):
    """Addition returns the correct sum with no error flag."""
    app.request_context = mock_context
    result = await app.call_tool("calculator", {
        "operation": "add",
        "a": 5,
        "b": 3
    })
    assert result["content"][0]["text"] == "8"
    assert not result.get("isError")


@pytest.mark.asyncio
async def test_calculator_division_by_zero(mock_context):
    """Division by zero is reported as a tool error, not an exception."""
    app.request_context = mock_context
    result = await app.call_tool("calculator", {
        "operation": "divide",
        "a": 10,
        "b": 0
    })
    assert result["isError"] is True
    assert "不能除以零" in result["content"][0]["text"]


@pytest.mark.asyncio
async def test_file_read_with_roots(mock_context, mock_session):
    """Reads are allowed inside declared roots and rejected outside them."""
    from mcp.types import Root
    mock_session.roots = [
        Root(uri="file:///allowed/path", name="allowed")
    ]
    app.request_context = mock_context
    # Path inside the declared root.
    result = await app.call_tool("read_file", {
        "path": "/allowed/path/file.txt"
    })
    # ...assert success here...
    # Path outside every root must be refused.
    result = await app.call_tool("read_file", {
        "path": "/etc/passwd"
    })
    assert result["isError"] is True
    assert "访问被拒绝" in result["content"][0]["text"]
9.2 集成测试
# test_integration.py
import pytest
import asyncio
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


@pytest.fixture(scope="module")
async def mcp_client():
    """Spawn the server over stdio and yield an initialized client session."""
    server_params = StdioServerParameters(
        command="python",
        args=["-m", "my_mcp_server"],
        env=None
    )
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            yield session


@pytest.mark.asyncio
async def test_tool_list(mcp_client):
    """The server advertises the expected tools."""
    tools = await mcp_client.list_tools()
    tool_names = [tool.name for tool in tools]
    assert "calculator" in tool_names
    assert "read_file" in tool_names


@pytest.mark.asyncio
async def test_end_to_end_workflow(mcp_client):
    """Walk the full surface: tools, one call, resources, prompts."""
    # 1. Enumerate tools.
    tools = await mcp_client.list_tools()
    # 2. Invoke the calculator.
    result = await mcp_client.call_tool("calculator", {
        "operation": "multiply",
        "a": 6,
        "b": 7
    })
    assert "42" in result.content[0].text
    # 3. Enumerate resources.
    resources = await mcp_client.list_resources()
    # 4. Enumerate prompts.
    prompts = await mcp_client.list_prompts()

# Run with pytest-asyncio:
# pytest test_integration.py -v
9.3 性能测试
# test_performance.py
import pytest
import time
import asyncio
import statistics


@pytest.mark.asyncio
async def test_tool_performance(mcp_client):
    """Single-call latency stays within the agreed budget."""
    latencies = []
    for _ in range(100):
        started = time.time()
        await mcp_client.call_tool("calculator", {
            "operation": "add",
            "a": 1,
            "b": 2
        })
        latencies.append((time.time() - started) * 1000)  # milliseconds
    avg_latency = statistics.mean(latencies)
    p95_latency = sorted(latencies)[int(len(latencies) * 0.95)]
    print(f"平均延迟: {avg_latency:.2f}ms")
    print(f"P95延迟: {p95_latency:.2f}ms")
    # Latency budget assertions.
    assert avg_latency < 100   # average must stay under 100ms
    assert p95_latency < 200   # P95 must stay under 200ms


@pytest.mark.asyncio
async def test_concurrent_load(mcp_client):
    """Fifty concurrent calls complete within five seconds."""
    async def make_request(i):
        started = time.time()
        await mcp_client.call_tool("calculator", {
            "operation": "add",
            "a": i,
            "b": i
        })
        return time.time() - started

    started = time.time()
    results = await asyncio.gather(*[make_request(i) for i in range(50)])
    total_time = time.time() - started
    print(f"50个并发请求总耗时: {total_time:.2f}s")
    print(f"平均每个请求: {total_time/50*1000:.2f}ms")
    assert total_time < 5
📊 10. 监控与日志
10.1 结构化日志
import json
import logging
import time
from datetime import datetime
from contextvars import ContextVar
from typing import Any, Optional

# Per-request correlation id, carried across awaits via contextvars.
request_id_var: ContextVar[str] = ContextVar('request_id')


class StructuredLogFormatter(logging.Formatter):
    """Render each log record as a single JSON object."""

    def format(self, record: logging.LogRecord) -> str:
        entry = {
            "timestamp": datetime.utcnow().isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "request_id": request_id_var.get(None),  # None outside a request
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno
        }
        # Merge any structured fields attached to the record by callers.
        if hasattr(record, 'extra_data'):
            entry.update(record.extra_data)
        # Include the traceback text when the record carries an exception.
        if record.exc_info:
            entry["exception"] = self.formatException(record.exc_info)
        return json.dumps(entry, ensure_ascii=False)
# Wire up logging: emit structured JSON records.
handler = logging.StreamHandler()
handler.setFormatter(StructuredLogFormatter())
logger = logging.getLogger("mcp.server")
logger.addHandler(handler)
logger.setLevel(logging.INFO)


class LoggerAdapter:
    """Thin wrapper that passes keyword fields through as structured data."""

    def __init__(self, base_logger: logging.Logger):
        self._logger = base_logger

    def _log(self, level: int, msg: str, extra: Optional[dict] = None):
        # Stash fields under "extra_data", where the formatter looks for them.
        self._logger.log(level, msg, extra={"extra_data": extra} if extra else {})

    def info(self, msg: str, **kwargs):
        self._log(logging.INFO, msg, kwargs)

    def warning(self, msg: str, **kwargs):
        self._log(logging.WARNING, msg, kwargs)

    def error(self, msg: str, **kwargs):
        self._log(logging.ERROR, msg, kwargs)


log = LoggerAdapter(logger)
# Tool-side usage
@app.call_tool()
async def process_data(arguments: dict, request_context):
    """Process a request with structured start/success/failure logging."""
    # Correlation id for every log line emitted by this call.
    # NOTE(review): `uuid` is used but not imported in this snippet — confirm
    # the module imports it at file level.
    request_id = str(uuid.uuid4())
    request_id_var.set(request_id)
    start_time = time.time()
    log.info(
        "工具调用开始",
        tool="process_data",
        request_id=request_id,
        arguments_size=len(str(arguments))
    )
    try:
        result = await do_processing(arguments)
    except Exception as e:
        log.error(
            "工具调用失败",
            tool="process_data",
            duration_ms=(time.time() - start_time) * 1000,
            error_type=type(e).__name__,
            error_message=str(e)
        )
        raise
    log.info(
        "工具调用成功",
        tool="process_data",
        duration_ms=(time.time() - start_time) * 1000,
        result_size=len(str(result))
    )
    return result
10.2 性能指标收集
from dataclasses import dataclass, field
from typing import Dict, List
from collections import defaultdict
import statistics


@dataclass
class MetricsCollector:
    """In-process counters and timings for tool calls and resource access."""

    # Number of calls per tool.
    tool_calls: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
    # Observed latencies per tool, in milliseconds.
    response_times: Dict[str, List[float]] = field(default_factory=lambda: defaultdict(list))
    # Failed-call count per tool.
    errors: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
    # Access count per resource URI.
    resource_access: Dict[str, int] = field(default_factory=lambda: defaultdict(int))

    def record_tool_call(self, tool_name: str, duration_ms: float, success: bool = True):
        """Record one call's latency and outcome."""
        self.tool_calls[tool_name] += 1
        self.response_times[tool_name].append(duration_ms)
        if not success:
            self.errors[tool_name] += 1

    def record_resource_access(self, uri: str):
        """Count one access to `uri`."""
        self.resource_access[uri] += 1

    def get_tool_stats(self, tool_name: str) -> dict:
        """Summarize one tool's counts and latency distribution ({} if unseen)."""
        samples = self.response_times.get(tool_name, [])
        if not samples:
            return {}
        return {
            "total_calls": self.tool_calls[tool_name],
            "error_count": self.errors[tool_name],
            "error_rate": self.errors[tool_name] / self.tool_calls[tool_name],
            "avg_response_time": statistics.mean(samples),
            "p95_response_time": sorted(samples)[int(len(samples) * 0.95)],
            "max_response_time": max(samples)
        }

    def get_summary(self) -> dict:
        """Aggregate view: totals plus the busiest and slowest tools."""
        return {
            "total_tool_calls": sum(self.tool_calls.values()),
            "total_errors": sum(self.errors.values()),
            "top_tools": sorted(
                self.tool_calls.items(),
                key=lambda x: x[1],
                reverse=True
            )[:5],
            "slowest_tools": sorted(
                [(name, statistics.mean(times)) for name, times in self.response_times.items()],
                key=lambda x: x[1],
                reverse=True
            )[:5]
        }
# Process-wide metrics instance.
metrics = MetricsCollector()


@app.call_tool()
async def monitored_tool(arguments: dict):
    """Do the work while always recording duration and success/failure."""
    start = time.time()
    success = True
    try:
        return await actual_work(arguments)
    except Exception:
        success = False
        raise
    finally:
        # Runs on both paths, so every call is counted exactly once.
        metrics.record_tool_call("monitored_tool", (time.time() - start) * 1000, success)


# Health-check endpoint
@app.call_tool()
async def get_health_metrics():
    """Expose the metrics summary for external monitoring."""
    return {
        "content": [{
            "type": "text",
            "text": json.dumps(metrics.get_summary(), indent=2)
        }]
    }
10.3 健康检查
import asyncio
from enum import Enum
from dataclasses import dataclass
from typing import List, Dict
import psutil
class HealthStatus(Enum):
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"


@dataclass
class HealthCheck:
    # Outcome of one named check.
    name: str
    status: HealthStatus
    message: str
    details: Dict = None


class HealthChecker:
    """Runs registered async checks and aggregates an overall status."""

    def __init__(self):
        self.checks = []  # list of (name, async check function)

    def add_check(self, name: str, check_func):
        """Register a check; `check_func` returns {status, message[, details]}."""
        self.checks.append((name, check_func))

    async def run_checks(self) -> List[HealthCheck]:
        """Run every check; a raising check is reported as UNHEALTHY."""
        outcomes = []
        for name, check_func in self.checks:
            try:
                verdict = await check_func()
                outcomes.append(HealthCheck(name, verdict["status"], verdict["message"], verdict.get("details")))
            except Exception as exc:
                outcomes.append(HealthCheck(name, HealthStatus.UNHEALTHY, str(exc)))
        return outcomes

    async def get_overall_status(self) -> Dict:
        """Overall status: the worst individual result wins."""
        checks = await self.run_checks()
        if any(c.status == HealthStatus.UNHEALTHY for c in checks):
            overall = HealthStatus.UNHEALTHY
        elif any(c.status == HealthStatus.DEGRADED for c in checks):
            overall = HealthStatus.DEGRADED
        else:
            overall = HealthStatus.HEALTHY
        return {
            "status": overall.value,
            "timestamp": datetime.utcnow().isoformat(),
            "checks": [
                {
                    "name": c.name,
                    "status": c.status.value,
                    "message": c.message,
                    "details": c.details
                }
                for c in checks
            ]
        }
# Shared health checker instance.
health_checker = HealthChecker()


# System resource check
async def check_system_resources():
    """Grade CPU and memory pressure into healthy/degraded/unhealthy."""
    cpu_percent = psutil.cpu_percent(interval=1)
    memory = psutil.virtual_memory()
    if cpu_percent > 90 or memory.percent > 90:
        status = HealthStatus.UNHEALTHY
        message = f"资源紧张: CPU {cpu_percent}%, 内存 {memory.percent}%"
    elif cpu_percent > 70 or memory.percent > 70:
        status = HealthStatus.DEGRADED
        message = f"资源使用较高: CPU {cpu_percent}%, 内存 {memory.percent}%"
    else:
        status = HealthStatus.HEALTHY
        message = f"资源正常: CPU {cpu_percent}%, 内存 {memory.percent}%"
    return {
        "status": status,
        "message": message,
        "details": {"cpu": cpu_percent, "memory": memory.percent}
    }
# Database connectivity check
async def check_database():
    """Probe the database with a trivial query."""
    try:
        await db.execute("SELECT 1")
        return {
            "status": HealthStatus.HEALTHY,
            "message": "数据库连接正常"
        }
    except Exception as e:
        return {
            "status": HealthStatus.UNHEALTHY,
            "message": f"数据库连接失败: {e}"
        }
# Register the checks
health_checker.add_check("system_resources", check_system_resources)
health_checker.add_check("database", check_database)


@app.call_tool()
async def health_check():
    """Health endpoint: full report, flagged as an error when unhealthy."""
    status = await health_checker.get_overall_status()
    return {
        "content": [{
            "type": "text",
            "text": json.dumps(status, indent=2)
        }],
        "isError": status["status"] == "unhealthy"
    }
📋 本章总结
本章深入介绍了MCP的高级特性和生产环境最佳实践:
- 进度通知:让用户了解长时间任务的执行状态
- 取消请求:支持用户主动取消任务
- Sampling:Server可以请求Client的LLM生成内容
- Roots:限制Server的文件系统访问范围
- 安全实践:输入验证、权限控制、沙箱执行、敏感信息保护
- 性能优化:连接复用、批量操作、缓存策略
- 错误处理:有意义的错误信息、错误分类、重试机制
- 版本管理:Server版本声明、向后兼容
- 测试策略:单元测试、集成测试、性能测试
- 监控日志:结构化日志、指标收集、健康检查
🎉 恭喜!你已经掌握了MCP的所有核心知识。在下一章中,我们将通过3个完整的实战项目来综合运用这些知识。