Chapter 7

Advanced Features and Best Practices

A deeper look at MCP's advanced features and how to build a production-grade MCP Server

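📢 1. Progress Notifications

When a Server runs a long task, progress notifications let the user see how far the task has got instead of waiting on a silent request.

1.1 Why progress notifications?

  • Large-file processing (upload, download, analysis)
  • Complex data-analysis tasks
  • Batch operations (processing many records)
  • External API calls (which may be slow)

1.2 Implementing progress notifications

The MCP protocol carries progress updates in notifications/progress messages: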

from mcp.server import Server
import asyncio

app = Server("progress-demo")

@app.call_tool()
async def process_large_file(file_path: str):
    """Process a large file and report progress along the way."""
    ctx = app.request_context
    # The Client supplies the token in the request's _meta.progressToken field,
    # not as a tool argument; it is absent when the Client did not ask for progress
    progress_token = ctx.meta.progressToken if ctx.meta else None

    # Simulate processing a large file
    total_steps = 100

    for i in range(total_steps):
        # Do the actual work
        await asyncio.sleep(0.1)

        # Send a notifications/progress message only if a token was provided
        if progress_token is not None:
            await ctx.session.send_progress_notification(
                progress_token=progress_token,
                progress=i + 1,
                total=total_steps,
            )

    return {
        "content": [{
            "type": "text",
            "text": "✅ File processed!"
        }]
    }
💡 Tip: the progress token is supplied by the Client with the call; the Server must check that it exists before sending notifications.

1.3 Receiving progress on the Client

from mcp import ClientSession

async def call_with_progress(session: ClientSession):
    # Progress callback; total and message may be None if the Server omits them
    async def on_progress(progress: float, total: float | None, message: str | None):
        if total:
            print(f"Progress: {progress / total * 100:.1f}% - {message or ''}")
        else:
            print(f"Progress: {progress} - {message or ''}")

    # Passing a callback makes the session attach a progress token to the request
    # (progress_callback is available in recent versions of the Python SDK)
    result = await session.call_tool(
        "process_large_file",
        {"file_path": "/path/to/large/file.zip"},
        progress_callback=on_progress,
    )

    return result
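Sending a notification on every step can flood the Client when a task has thousands of steps. The sketch below shows one way to thin the updates out; `ProgressThrottle` and its `min_step_percent` parameter are hypothetical names for illustration, not part of MCP or the SDK.

```python
from typing import Optional


class ProgressThrottle:
    """Decide which progress updates are worth sending (hypothetical helper)."""

    def __init__(self, total: int, min_step_percent: int = 5):
        if total <= 0:
            raise ValueError("total must be positive")
        self.total = total
        self.min_step_percent = min_step_percent
        self._last_sent: Optional[int] = None

    def should_send(self, progress: int) -> bool:
        is_first = self._last_sent is None
        is_final = progress >= self.total
        # Integer arithmetic avoids floating-point surprises at the threshold
        far_enough = (
            not is_first
            and (progress - self._last_sent) * 100 >= self.min_step_percent * self.total
        )
        if is_first or is_final or far_enough:
            self._last_sent = progress
            return True
        return False


throttle = ProgressThrottle(total=100, min_step_percent=10)
sent = [step for step in range(1, 101) if throttle.should_send(step)]
print(sent)
```

The first and the final update are always sent, so the Client sees the task start and finish even with a coarse step.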

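🛑 2. Cancellation

Cancellation lets the user stop a long-running task that is still executing, so resources are not wasted on work nobody needs.

2.1 Supporting cancellation on the Server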

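import asyncio
from mcp.server import Server
from mcp.types import CancelledNotification

app = Server("cancellable-demo")

# Cancel events for in-flight requests, keyed by JSON-RPC request id
running_tasks = {}

@app.call_tool()
async def long_running_task(duration: int):
    """A long-running task that can be cancelled."""
    # Key by request id: that is the identifier a cancellation notification carries
    request_id = app.request_context.request_id
    cancel_event = asyncio.Event()
    running_tasks[request_id] = cancel_event

    try:
        for i in range(duration):
            # Check whether the request has been cancelled
            if cancel_event.is_set():
                return {
                    "content": [{
                        "type": "text",
                        "text": f"❌ Task cancelled by the user after {i} seconds"
                    }],
                    "isError": False
                }

            await asyncio.sleep(1)

        return {
            "content": [{
                "type": "text",
                "text": "✅ Task completed successfully!"
            }]
        }
    finally:
        # Drop the bookkeeping entry
        running_tasks.pop(request_id, None)

# The protocol method name is notifications/cancelled
@app.on_notification("notifications/cancelled")
async def handle_cancel(notification: CancelledNotification):
    """Handle a cancellation notification."""
    request_id = notification.params.requestId

    # Find the matching task and flag it
    if request_id in running_tasks:
        running_tasks[request_id].set()
        print(f"Request {request_id} marked as cancelled")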

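2.2 Using asyncio's native task cancellation

@app.call_tool()
async def cancellable_computation(data: list):
    """Rely on asyncio's native cancellation."""
    results = []
    try:
        for item in data:
            # Every await is a cancellation point: if the task is cancelled,
            # CancelledError is raised here (task.cancelled() stays False
            # while the task is still running, so polling it does not work)
            result = await expensive_operation(item)
            results.append(result)

            # Yield to the event loop so a pending cancellation is delivered
            await asyncio.sleep(0)
    except asyncio.CancelledError:
        # Release resources here, then re-raise so the cancellation propagates
        raise

    return {"content": [{"type": "text", "text": str(results)}]}
⚠️ Warning: a cancellation request is only a signal; the Server must notice it and respond. Check for cancellation before critical operations.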
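The cooperative nature of cancellation can be seen with plain asyncio, independent of MCP. In this minimal sketch (the `worker` and `main` names are illustrative) the task only stops at an `await`, runs its cleanup, and re-raises:

```python
import asyncio


async def worker(log: list) -> None:
    try:
        for step in range(1000):
            log.append(f"step {step}")
            await asyncio.sleep(0.01)  # cancellation is delivered at this await
    except asyncio.CancelledError:
        log.append("cleanup")  # release resources before giving up
        raise  # re-raise so the caller sees the task as cancelled


async def main() -> list:
    log: list = []
    task = asyncio.create_task(worker(log))
    await asyncio.sleep(0.05)  # let the worker make some progress
    task.cancel()  # only requests cancellation
    try:
        await task
    except asyncio.CancelledError:
        log.append("cancelled")
    return log


log = asyncio.run(main())
print(log[-2:])
```

Swallowing `CancelledError` without re-raising would make the task look as if it had finished normally, which is rarely what the caller expects.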

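🎲 3. Sampling

Sampling is a powerful MCP feature that lets a Server ask the Client's LLM to generate content, so the Server can use AI capabilities to extend what it does.

3.1 Use cases for Sampling

  • Code generation: produce code snippets from context
  • Content polishing: improve the quality of user-supplied text
  • Analysis: examine data and offer AI-generated insights
  • Dialogue: generate natural-language responses

3.2 Requesting Sampling from the Server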

from mcp.server import Server
from mcp.types import SamplingMessage, ModelPreferences, ModelHint

app = Server("sampling-demo")

@app.call_tool()
async def generate_documentation(code: str):
    """Generate documentation comments for a piece of code."""
    # Sampling messages only take the roles "user" and "assistant";
    # the system prompt is passed separately
    messages = [
        SamplingMessage(
            role="user",
            content={
                "type": "text",
                "text": f"Please write documentation comments for the following code:\n\n```python\n{code}\n```"
            }
        )
    ]

    # Ask the Client's LLM to generate content (sampling/createMessage)
    response = await app.request_context.session.create_message(
        messages=messages,
        system_prompt=(
            "You are a professional code documentation engineer. "
            "Write clear, detailed documentation comments for the given code."
        ),
        model_preferences=ModelPreferences(
            # Each hint is an object with a name, not a bare string
            hints=[ModelHint(name="claude-3-5-sonnet")],
            intelligencePriority=0.8,
            speedPriority=0.3,
        ),
        max_tokens=2000,
    )

    return {
        "content": [{
            "type": "text",
            "text": f"Generated documentation:\n\n{response.content.text}"
        }]
    }

3.3 Advanced Sampling example: data analysis

@app.call_tool()
async def analyze_sales_data(csv_data: str):
    """Analyze sales data and produce an AI-generated insight report."""
    # Basic processing first
    import pandas as pd
    from io import StringIO

    df = pd.read_csv(StringIO(csv_data))
    basic_stats = df.describe().to_string()

    # Ask the LLM for a deeper analysis; the system prompt is passed separately
    # because sampling messages only take the roles "user" and "assistant"
    messages = [
        SamplingMessage(
            role="user",
            content={
                "type": "text",
                "text": f"""Please analyze the following sales data and provide:
1. Key findings (3-5 items)
2. Trend analysis
3. Suggestions for improvement

Basic statistics:
{basic_stats}

Raw data summary:
- Total records: {len(df)}
- Date range: {df['date'].min()} to {df['date'].max()}
- Total sales: ${df['amount'].sum():,.2f}
"""
            }
        )
    ]

    analysis = await app.request_context.session.create_message(
        messages=messages,
        system_prompt="You are a data analyst who excels at finding trends and opportunities in sales data.",
        max_tokens=3000,
    )

    return {
        "content": [{
            "type": "text",
            "text": analysis.content.text
        }],
        "metadata": {
            "data_points": len(df),
            "total_revenue": float(df['amount'].sum())
        }
    }
✅ Best practice: treat Sampling as an enhancement, not a requirement. The Server should still work correctly when Sampling is unavailable.
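One way to follow this advice is to route every LLM call through an optional callable and keep a deterministic fallback. The sketch below is an assumption about how that could look: `Sampler`, `summarize`, and `rule_based_summary` are hypothetical names, and the fake samplers stand in for a real sampling request.

```python
import asyncio
from typing import Awaitable, Callable, Optional

# Hypothetical abstraction over "ask the Client's LLM": prompt in, text out
Sampler = Callable[[str], Awaitable[str]]


def rule_based_summary(text: str, max_words: int = 8) -> str:
    """Fallback that needs no LLM: keep the first few words."""
    words = text.split()
    summary = " ".join(words[:max_words])
    return summary + ("..." if len(words) > max_words else "")


async def summarize(text: str, sampler: Optional[Sampler] = None) -> str:
    if sampler is not None:
        try:
            return await sampler(f"Summarize: {text}")
        except Exception:
            pass  # sampling refused or unsupported: fall back below
    return rule_based_summary(text)


async def fake_sampler(prompt: str) -> str:
    return "LLM summary"


async def broken_sampler(prompt: str) -> str:
    raise RuntimeError("sampling not supported")


with_llm = asyncio.run(summarize("one two three", fake_sampler))
without_llm = asyncio.run(summarize("one two three"))
after_failure = asyncio.run(summarize("one two three", broken_sampler))
print(with_llm, "|", without_llm, "|", after_failure)
```

The tool keeps returning a useful, if plainer, result whether the Client lacks the sampling capability or rejects an individual request.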

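📁 4. Roots: Safe File-System Access

The Roots mechanism lets the Client tell the Server which parts of the file system it may work in, and is an important safety measure.

4.1 What are Roots?

Roots are a set of directory URIs declared by the Client; the Server should only access files inside those directories and their subdirectories: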

{
  "roots": [
    {
      "uri": "file:///home/user/projects",
      "name": "项目目录"
    },
    {
      "uri": "file:///home/user/documents",
      "name": "文档目录"
    }
  ]
}
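Because roots arrive as file:// URIs, a Server has to turn them into paths before comparing. The sketch below assumes POSIX-style paths; `file_uri_to_path` and `is_within_roots` are hypothetical helper names. It parses the URI instead of stripping the prefix and compares whole path components rather than string prefixes.

```python
from pathlib import Path
from urllib.parse import unquote, urlparse


def file_uri_to_path(uri: str) -> Path:
    """Convert a file:// URI to a Path, decoding percent-escapes (POSIX paths assumed)."""
    parsed = urlparse(uri)
    if parsed.scheme != "file":
        raise ValueError(f"not a file URI: {uri}")
    return Path(unquote(parsed.path))


def is_within_roots(target: str, root_uris: list[str]) -> bool:
    """True if target, after resolving '..' and symlinks, lies inside a root."""
    resolved = Path(target).resolve()
    return any(
        resolved.is_relative_to(file_uri_to_path(uri).resolve())
        for uri in root_uris
    )


roots = ["file:///home/user/projects", "file:///home/user/documents"]
print(is_within_roots("/home/user/projects/app/main.py", roots))
print(is_within_roots("/home/user/projects/../../../etc/passwd", roots))
print(is_within_roots("/home/user/projects-evil/x", roots))
```

Component-wise comparison matters for the last case: a plain `startswith` check would accept `/home/user/projects-evil` because it shares a string prefix with an allowed root.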

4.2 Reading Roots on the Server

from mcp.server import Server
from pathlib import Path
from urllib.parse import unquote, urlparse

app = Server("roots-demo")

async def get_roots() -> list[Path]:
    """Fetch the roots declared by the Client."""
    # roots/list is a request the Server sends to the Client
    result = await app.request_context.session.list_roots()
    # Parse each file:// URI rather than stripping the prefix, so that
    # percent-encoded characters are decoded
    return [
        Path(unquote(urlparse(str(root.uri)).path)).resolve()
        for root in result.roots
    ]

async def is_path_allowed(target_path: str) -> bool:
    """Check whether a path lies inside one of the allowed roots."""
    target = Path(target_path).resolve()
    allowed_roots = await get_roots()

    for root in allowed_roots:
        try:
            target.relative_to(root)
            return True
        except ValueError:
            continue
    return False

@app.call_tool()
async def safe_read_file(file_path: str):
    """Read a file safely (restricted by roots)."""
    # Security check
    if not await is_path_allowed(file_path):
        return {
            "content": [{
                "type": "text",
                "text": f"❌ Access denied: {file_path} is outside the allowed roots"
            }],
            "isError": True
        }

    # Read the file through the resolved path, the same one that was checked
    try:
        with open(Path(file_path).resolve(), 'r', encoding='utf-8') as f:
            content = f.read()
        return {
            "content": [{
                "type": "text",
                "text": content
            }]
        }
    except Exception as e:
        return {
            "content": [{
                "type": "text",
                "text": f"❌ Read failed: {str(e)}"
            }],
            "isError": True
        }

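4.3 Listening for Roots changes

# The protocol method name is notifications/roots/list_changed
@app.on_notification("notifications/roots/list_changed")
async def on_roots_changed():
    """Reload the roots when the Client reports a change."""
    result = await app.request_context.session.list_roots()
    print(f"Roots updated: {result.roots}")
    # Invalidate caches or re-validate paths here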

🔒 5. Security Best Practices

5.1 Input validation

from pathlib import Path
from pydantic import BaseModel, validator, Field
import re

class FileOperationInput(BaseModel):
    path: str = Field(..., description="File path")
    content: str = Field(default="", description="File content")

    @validator('path')
    def validate_path(cls, v):
        # Guard against path traversal
        if '..' in v or '~' in v:
            raise ValueError("Path contains illegal characters")

        # Allow only specific extensions
        allowed_extensions = {'.txt', '.md', '.py', '.json'}
        ext = Path(v).suffix.lower()
        if ext not in allowed_extensions:
            raise ValueError(f"Unsupported file type: {ext}")

        return v
    
    @validator('content')
    def validate_content(cls, v):
        # Limit the content size
        max_size = 1024 * 1024  # 1MB
        if len(v.encode('utf-8')) > max_size:
            raise ValueError("File content exceeds the 1MB limit")

        # Reject content that matches known-dangerous patterns
        dangerous_patterns = [
            r'eval\s*\(',
            r'exec\s*\(',
            r'subprocess\.call',
            r'os\.system'
        ]
        for pattern in dangerous_patterns:
            if re.search(pattern, v, re.IGNORECASE):
                raise ValueError("Content contains potentially dangerous code")

        return v

# Using the validation inside a tool
@app.call_tool()
async def write_file(arguments: dict):
    try:
        validated = FileOperationInput(**arguments)
        # Continue processing...
    except ValueError as e:
        return {
            "content": [{"type": "text", "text": f"Input validation failed: {e}"}],
            "isError": True
        }

5.2 Access control

from enum import Enum
from functools import wraps

class PermissionLevel(Enum):
    READ = "read"
    WRITE = "write"
    EXECUTE = "execute"
    ADMIN = "admin"

# 权限检查装饰器
def require_permission(level: PermissionLevel):
    def decorator(func):
        @wraps(func)
        async def wrapper(arguments: dict, **kwargs):
            # 获取用户权限(从Client的metadata或session)
            user_level = kwargs.get('permissions', {}).get('level', PermissionLevel.READ)
            
            # 权限检查逻辑
            permission_order = [PermissionLevel.READ, PermissionLevel.WRITE, 
                               PermissionLevel.EXECUTE, PermissionLevel.ADMIN]
            
            if permission_order.index(user_level) < permission_order.index(level):
                return {
                    "content": [{
                        "type": "text",
                        "text": f"❌ 权限不足,需要 {level.value} 权限"
                    }],
                    "isError": True
                }
            
            return await func(arguments, **kwargs)
        return wrapper
    return decorator

@app.call_tool()
@require_permission(PermissionLevel.WRITE)
async def delete_file(arguments: dict, **kwargs):
    """删除文件 - 需要写权限"""
    # 实现...

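5.3 Sandboxed execution

import asyncio
import subprocess
import sys
import tempfile
from pathlib import Path

async def sandbox_execute(code: str, timeout: int = 30) -> dict:
    """Run code in a restricted subprocess.

    This limits the environment but is not a real security boundary: the
    code can still read files and open sockets as the current user.
    """

    # Use a temporary directory as the working area
    with tempfile.TemporaryDirectory() as sandbox_dir:
        # Write the code to a file
        code_file = Path(sandbox_dir) / "script.py"
        code_file.write_text(code, encoding="utf-8")

        try:
            # Run the blocking call in a worker thread so the event loop stays responsive
            result = await asyncio.to_thread(
                subprocess.run,
                # -I: isolated mode (ignores PYTHON* variables and user site-packages).
                # Passing the file path directly avoids building source text from it.
                [sys.executable, "-I", str(code_file)],
                capture_output=True,
                text=True,
                timeout=timeout,
                cwd=sandbox_dir,
                # Minimal environment
                env={"PATH": "/usr/bin"}
            )

            return {
                "stdout": result.stdout,
                "stderr": result.stderr,
                "returncode": result.returncode
            }
        except subprocess.TimeoutExpired:
            return {"error": "Execution timed out"}
        except Exception as e:
            return {"error": str(e)}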

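# A safer option: use Docker
async def docker_execute(code: str, image: str = "python:3.11-slim") -> dict:
    """Run code inside a Docker container."""
    with tempfile.TemporaryDirectory() as tmpdir:
        code_file = Path(tmpdir) / "script.py"
        code_file.write_text(code)

        try:
            result = subprocess.run(
                [
                    "docker", "run", "--rm",
                    "-v", f"{tmpdir}:/workspace:ro",  # read-only mount
                    "--network", "none",  # no network access
                    "--memory", "128m",  # memory limit
                    "--cpus", "0.5",  # CPU limit
                    image,
                    "python", "/workspace/script.py"
                ],
                capture_output=True,
                text=True,
                # docker run has no --timeout flag, so the limit is enforced here.
                # This stops the docker CLI; the container itself may need a
                # separate "docker kill" if it is still running.
                timeout=35
            )
        except subprocess.TimeoutExpired:
            return {"error": "Execution timed out"}

        return {
            "stdout": result.stdout,
            "stderr": result.stderr,
            "returncode": result.returncode
        }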

5.4 Protecting sensitive data

import re
from typing import Any

class SensitiveDataFilter:
    """敏感数据过滤器"""
    
    # 正则模式
    PATTERNS = {
        'api_key': r'[a-zA-Z0-9_-]{32,}',  # API密钥
        'password': r'password[\s]*[=:]+[\s]*[^\s]+',
        'secret': r'secret[\s]*[=:]+[\s]*[^\s]+',
        'token': r'[a-zA-Z0-9_-]{20,}\.[a-zA-Z0-9_-]{20,}',  # JWT
        'email': r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}',
        'credit_card': r'\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}',
    }
    
    REPLACEMENTS = {
        'api_key': '',
        'password': '',
        'secret': '',
        'token': '',
        'email': lambda m: f"{m.group(0).split('@')[0][:2]}***@{m.group(0).split('@')[1]}",
        'credit_card': '',
    }
    
    @classmethod
    def filter_text(cls, text: str) -> str:
        """Filter sensitive information out of a string."""
        for key, pattern in cls.PATTERNS.items():
            # re.sub accepts either a replacement string or a callable
            text = re.sub(pattern, cls.REPLACEMENTS[key], text)
        return text

    @classmethod
    def filter_dict(cls, data: Any) -> Any:
        """Recursively filter sensitive information in dicts, lists, and strings."""
        if isinstance(data, dict):
            return {k: cls.filter_dict(v) for k, v in data.items()}
        if isinstance(data, list):
            # Strings nested inside lists are filtered as well
            return [cls.filter_dict(item) for item in data]
        if isinstance(data, str):
            return cls.filter_text(data)
        return data

# 在Server中使用
@app.call_tool()
async def process_data(data: dict):
    """处理数据(自动过滤敏感信息)"""
    # 处理前记录日志(已过滤)
    safe_data = SensitiveDataFilter.filter_dict(data)
    print(f"处理数据: {safe_data}")
    
    # 实际处理使用原始数据
    result = await actual_processing(data)
    
    # 返回结果也进行过滤
    return {
        "content": [{
            "type": "text",
            "text": SensitiveDataFilter.filter_text(str(result))
        }]
    }
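Pattern matching on values can miss secrets that do not look like secrets. A complementary approach, sketched here under the assumption that sensitive fields have recognizable key names, is to redact by key; `SENSITIVE_KEYS` and `redact` are hypothetical names chosen for illustration.

```python
from typing import Any

# Hypothetical list of key names treated as sensitive
SENSITIVE_KEYS = {"password", "secret", "token", "api_key", "authorization"}


def redact(value: Any, placeholder: str = "[REDACTED]") -> Any:
    """Return a copy of value with sensitive dict entries replaced."""
    if isinstance(value, dict):
        return {
            key: placeholder
            if str(key).lower() in SENSITIVE_KEYS
            else redact(item, placeholder)
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [redact(item, placeholder) for item in value]
    return value


request = {
    "user": "alice",
    "password": "hunter2",
    "services": [{"name": "db", "Token": "abc.def"}],
}
safe = redact(request)
print(safe)
```

The original dictionary is left untouched, so the tool can keep working with the real values while only the redacted copy reaches the log.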

⚡ 6. Performance Optimization

6.1 Connection reuse

import asyncio
from contextlib import asynccontextmanager
from mcp import ClientSession

class MCPConnectionPool:
    """MCP连接池"""
    
    def __init__(self, max_connections: int = 10):
        self.max_connections = max_connections
        self._pool = asyncio.Queue()
        self._used = set()
        self._lock = asyncio.Lock()
    
    async def initialize(self, server_params):
        """初始化连接池"""
        for _ in range(self.max_connections):
            session = await self._create_session(server_params)
            await self._pool.put(session)
    
    @asynccontextmanager
    async def acquire(self):
        """获取连接"""
        session = await self._pool.get()
        self._used.add(id(session))
        try:
            yield session
        finally:
            self._used.discard(id(session))
            await self._pool.put(session)
    
    async def _create_session(self, params):
        """创建新会话"""
        # 实际创建逻辑
        pass

# Using the pool (await is only valid inside a coroutine)
async def main(server_params):
    pool = MCPConnectionPool(max_connections=5)
    await pool.initialize(server_params)

    async with pool.acquire() as session:
        return await session.call_tool("some_tool", {})

6.2 Batch operations

from typing import List
import asyncio

@app.call_tool()
async def batch_process(items: List[str], batch_size: int = 10):
    """批量处理项目"""
    results = []
    
    # 分批处理
    for i in range(0, len(items), batch_size):
        batch = items[i:i + batch_size]
        
        # 并行处理批次内的项目
        batch_results = await asyncio.gather(
            *[process_single(item) for item in batch],
            return_exceptions=True
        )
        
        results.extend(batch_results)
        
        # 发送进度通知
        await send_progress(i + len(batch), len(items))
    
    return {
        "content": [{
            "type": "text",
            "text": f"处理完成: {len(results)}/{len(items)}"
        }],
        "metadata": {"results": results}
    }

@app.call_tool()
async def parallel_queries(queries: List[str]):
    """并行执行多个查询"""
    semaphore = asyncio.Semaphore(5)  # 限制并发数
    
    async def limited_query(query: str):
        async with semaphore:
            return await execute_query(query)
    
    results = await asyncio.gather(
        *[limited_query(q) for q in queries],
        return_exceptions=True
    )
    
    # 处理结果和异常
    successful = [r for r in results if not isinstance(r, Exception)]
    failed = [str(r) for r in results if isinstance(r, Exception)]
    
    return {
        "content": [{
            "type": "text",
            "text": f"成功: {len(successful)}, 失败: {len(failed)}"
        }]
    }
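The semaphore pattern can be checked in isolation. This sketch uses only asyncio; `run_bounded` and `square_or_fail` are illustrative names. It records the peak number of concurrent workers and separates successes from failures, as `parallel_queries` does.

```python
import asyncio


async def run_bounded(items, worker, limit: int = 5):
    """Run worker over items with at most `limit` calls in flight."""
    semaphore = asyncio.Semaphore(limit)
    in_flight = 0
    peak = 0

    async def guarded(item):
        nonlocal in_flight, peak
        async with semaphore:
            in_flight += 1
            peak = max(peak, in_flight)
            try:
                return await worker(item)
            finally:
                in_flight -= 1

    # return_exceptions=True keeps one failure from cancelling the rest
    results = await asyncio.gather(
        *(guarded(item) for item in items), return_exceptions=True
    )
    succeeded = [r for r in results if not isinstance(r, Exception)]
    failed = [r for r in results if isinstance(r, Exception)]
    return succeeded, failed, peak


async def square_or_fail(n: int) -> int:
    await asyncio.sleep(0.01)
    if n % 4 == 0:
        raise ValueError(f"bad input {n}")
    return n * n


succeeded, failed, peak = asyncio.run(
    run_bounded(range(1, 9), square_or_fail, limit=3)
)
print(succeeded, len(failed), peak)
```

`asyncio.gather` returns results in input order, so the successes stay aligned with the order of the submitted items even though they finish at different times.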

6.3 Caching strategy

import functools
import time
from typing import Any, Optional
import hashlib

class CacheManager:
    """缓存管理器"""
    
    def __init__(self, default_ttl: int = 300):
        self._cache = {}
        self._ttl = default_ttl
    
    def _make_key(self, *args, **kwargs) -> str:
        """生成缓存键"""
        key_data = f"{args}:{sorted(kwargs.items())}"
        return hashlib.md5(key_data.encode()).hexdigest()
    
    def get(self, key: str) -> Optional[Any]:
        """获取缓存值"""
        if key in self._cache:
            value, expire_time = self._cache[key]
            if time.time() < expire_time:
                return value
            else:
                del self._cache[key]
        return None
    
    def set(self, key: str, value: Any, ttl: Optional[int] = None):
        """设置缓存值"""
        expire_time = time.time() + (ttl or self._ttl)
        self._cache[key] = (value, expire_time)
    
    def cached(self, ttl: Optional[int] = None):
        """缓存装饰器"""
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                # 生成缓存键
                cache_key = f"{func.__name__}:{self._make_key(*args, **kwargs)}"
                
                # 尝试从缓存获取
                cached_value = self.get(cache_key)
                if cached_value is not None:
                    return cached_value
                
                # 执行函数
                result = await func(*args, **kwargs)
                
                # 缓存结果
                self.set(cache_key, result, ttl)
                return result
            return wrapper
        return decorator

# 使用缓存
cache = CacheManager(default_ttl=60)

@app.call_tool()
@cache.cached(ttl=300)  # 缓存5分钟
async def get_schema(database: str):
    """获取数据库结构(缓存结果)"""
    # 昂贵的数据库查询
    schema = await fetch_database_schema(database)
    return schema

@app.read_resource()
@cache.cached(ttl=60)
async def read_file_resource(uri: str):
    """读取文件资源(缓存1分钟)"""
    file_path = uri.replace("file://", "")
    with open(file_path, 'r') as f:
        return f.read()
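Two details of a TTL cache are easy to get wrong: telling a cached `None` apart from a miss, and testing expiry without sleeping. The following sketch handles both with a sentinel object and an injectable clock; `TTLCache` and `MISSING` are hypothetical names, not part of any library.

```python
import time
from typing import Any, Callable, Dict, Hashable, Tuple

MISSING = object()  # sentinel: distinguishes "no entry" from a cached None


class TTLCache:
    """Minimal time-to-live cache with an injectable clock."""

    def __init__(self, ttl: float, clock: Callable[[], float] = time.monotonic):
        self._ttl = ttl
        self._clock = clock
        self._store: Dict[Hashable, Tuple[Any, float]] = {}

    def get(self, key: Hashable, default: Any = MISSING) -> Any:
        entry = self._store.get(key)
        if entry is None:
            return default
        value, expires_at = entry
        if self._clock() >= expires_at:
            del self._store[key]  # expired: evict and report a miss
            return default
        return value

    def set(self, key: Hashable, value: Any) -> None:
        self._store[key] = (value, self._clock() + self._ttl)


now = [0.0]  # fake clock, advanced by hand
cache = TTLCache(ttl=60, clock=lambda: now[0])
cache.set("schema", None)  # None is a legitimate cached value
hit = cache.get("schema")
now[0] = 61.0
miss = cache.get("schema")
print(hit is None, miss is MISSING)
```

`time.monotonic` is the default clock because it cannot jump backwards when the system time is adjusted.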

🛠️ 7. Error-Handling Best Practices

7.1 Meaningful error messages

from enum import Enum
from typing import Optional
import json

class ErrorCode(Enum):
    """错误代码枚举"""
    # 系统错误 1xxx
    INTERNAL_ERROR = "1000"
    SERVICE_UNAVAILABLE = "1001"
    TIMEOUT = "1002"
    
    # 请求错误 2xxx
    INVALID_PARAMS = "2000"
    MISSING_REQUIRED_PARAM = "2001"
    INVALID_FORMAT = "2002"
    
    # 业务错误 3xxx
    RESOURCE_NOT_FOUND = "3000"
    PERMISSION_DENIED = "3001"
    RATE_LIMITED = "3002"
    
    # 外部错误 4xxx
    EXTERNAL_API_ERROR = "4000"
    DATABASE_ERROR = "4001"

class MCPError(Exception):
    """MCP标准错误"""
    
    def __init__(
        self, 
        code: ErrorCode, 
        message: str, 
        details: Optional[dict] = None,
        suggestion: Optional[str] = None
    ):
        self.code = code
        self.message = message
        self.details = details or {}
        self.suggestion = suggestion
        super().__init__(self.message)
    
    def to_dict(self) -> dict:
        """转换为错误响应格式"""
        error_obj = {
            "code": self.code.value,
            "message": self.message,
            "details": self.details
        }
        if self.suggestion:
            error_obj["suggestion"] = self.suggestion
        return error_obj

# 在工具中使用
@app.call_tool()
async def fetch_user_data(user_id: str):
    try:
        # 参数验证
        if not user_id:
            raise MCPError(
                code=ErrorCode.MISSING_REQUIRED_PARAM,
                message="用户ID不能为空",
                suggestion="请提供有效的用户ID"
            )
        
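        # Database query: pass user_id as a bound parameter so it is never
        # spliced into the SQL text (placeholder syntax depends on the driver)
        try:
            user = await db.query("SELECT * FROM users WHERE id = ?", (user_id,))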
        except DatabaseConnectionError as e:
            raise MCPError(
                code=ErrorCode.DATABASE_ERROR,
                message="数据库连接失败",
                details={"original_error": str(e)},
                suggestion="请稍后重试或联系管理员"
            )
        
        if not user:
            raise MCPError(
                code=ErrorCode.RESOURCE_NOT_FOUND,
                message=f"未找到用户: {user_id}",
                suggestion="请检查用户ID是否正确"
            )
        
        return {"content": [{"type": "text", "text": json.dumps(user)}]}
        
    except MCPError as e:
        return {
            "content": [{
                "type": "text",
                "text": f"错误 [{e.code.value}]: {e.message}"
            }],
            "isError": True,
            "error": e.to_dict()
        }
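    except Exception:
        # Unexpected error: log the details on the Server and return only
        # a generic message, so internals are not exposed to the Client
        return {
            "content": [{
                "type": "text",
                "text": "System error: an internal error occurred"
            }],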
            "isError": True,
            "error": {
                "code": ErrorCode.INTERNAL_ERROR.value,
                "message": "发生内部错误"
            }
        }

7.2 Retry mechanism

import asyncio
import random
from typing import Callable, TypeVar
from functools import wraps

T = TypeVar('T')

class RetryConfig:
    """重试配置"""
    def __init__(
        self,
        max_attempts: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 30.0,
        exponential_base: float = 2.0,
        retryable_exceptions: tuple = (Exception,)
    ):
        self.max_attempts = max_attempts
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.exponential_base = exponential_base
        self.retryable_exceptions = retryable_exceptions

def with_retry(config: RetryConfig = None):
    """重试装饰器"""
    if config is None:
        config = RetryConfig()
    
    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @wraps(func)
        async def wrapper(*args, **kwargs) -> T:
            last_exception = None
            
            for attempt in range(1, config.max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except config.retryable_exceptions as e:
                    last_exception = e
                    
                    if attempt == config.max_attempts:
                        break
                    
                    # 计算延迟(指数退避 + 抖动)
                    delay = min(
                        config.base_delay * (config.exponential_base ** (attempt - 1)),
                        config.max_delay
                    )
                    jitter = random.uniform(0, delay * 0.1)  # 10% 抖动
                    total_delay = delay + jitter
                    
                    print(f"尝试 {attempt} 失败,{total_delay:.1f}秒后重试...")
                    await asyncio.sleep(total_delay)
            
            # 所有尝试都失败
            raise last_exception
        
        return wrapper
    return decorator

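# Using the retry decorator
import aiohttp

@app.call_tool()
@with_retry(RetryConfig(
    max_attempts=3,
    base_delay=1.0,
    # aiohttp raises its own exception types rather than the built-in ConnectionError
    retryable_exceptions=(aiohttp.ClientConnectionError, asyncio.TimeoutError)
))
async def call_external_api(endpoint: str):
    """Call an external API with retries."""
    timeout = aiohttp.ClientTimeout(total=10)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get(endpoint) as response:
            response.raise_for_status()
            return await response.json()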

@app.call_tool()
@with_retry(RetryConfig(
    max_attempts=5,
    base_delay=0.5,
    exponential_base=1.5
))
async def database_query(sql: str):
    """数据库查询(带重试)"""
    # 实现...
💡 Tip: retry only idempotent operations, so that non-idempotent ones (charging a payment, sending a notification) are not executed twice.
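It helps to see the delays a given configuration produces before deploying it. This small sketch computes the exponential-backoff schedule without jitter, using the same formula as the decorator above; `backoff_delays` is a hypothetical helper name.

```python
from typing import List


def backoff_delays(
    max_attempts: int,
    base_delay: float = 1.0,
    exponential_base: float = 2.0,
    max_delay: float = 30.0,
) -> List[float]:
    """Delays slept between attempts; no delay follows the final attempt."""
    return [
        min(base_delay * exponential_base ** (attempt - 1), max_delay)
        for attempt in range(1, max_attempts)
    ]


print(backoff_delays(5))
print(backoff_delays(8))
print(sum(backoff_delays(8)))
```

With eight attempts the worst case sleeps 91 seconds in total before the last failure is raised, which may be longer than a Client is willing to wait for a single tool call.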

📦 8. Version Management

8.1 Declaring the Server version

from mcp.server import Server

# Version information
SERVER_VERSION = "1.2.3"
SERVER_NAME = "my-mcp-server"

# Name and version are reported to the Client as serverInfo during initialize
app = Server(SERVER_NAME, version=SERVER_VERSION)

# The SDK builds the initialize response itself: the negotiated protocolVersion
# (for example "2024-11-05"), serverInfo, and the capabilities derived from the
# handlers registered on app
init_options = app.create_initialization_options()

8.2 Backward-compatibility strategy

from packaging import version

# 支持的客户端版本范围
MIN_CLIENT_VERSION = "1.0.0"
MAX_CLIENT_VERSION = "2.0.0"

# 已弃用功能
DEPRECATED_FEATURES = {
    "old_tool_name": {
        "deprecated_in": "1.1.0",
        "removed_in": "2.0.0",
        "replacement": "new_tool_name"
    }
}

@app.on_initialize()
async def check_compatibility(params):
    """检查客户端兼容性"""
    client_version = params.get("clientInfo", {}).get("version", "0.0.0")
    
    # 版本检查
    if version.parse(client_version) < version.parse(MIN_CLIENT_VERSION):
        raise ValueError(
            f"客户端版本 {client_version} 过低,"
            f"最低要求: {MIN_CLIENT_VERSION}"
        )
    
    if version.parse(client_version) > version.parse(MAX_CLIENT_VERSION):
        # 警告但不阻止
        return {
            "warning": f"客户端版本 {client_version} 尚未测试,可能不兼容"
        }
    
    return {}

@app.call_tool()
async def deprecated_tool(arguments: dict):
    """旧版工具 - 重定向到新版本"""
    # 记录弃用警告
    print("警告: 使用已弃用的工具,请迁移到 new_tool_name")
    
    # 参数转换
    new_arguments = convert_legacy_params(arguments)
    
    # 调用新实现
    return await new_tool(new_arguments)

def convert_legacy_params(old_params: dict) -> dict:
    """转换旧版参数到新版"""
    mapping = {
        "old_field": "new_field",
        "legacy_param": "modern_param"
    }
    
    return {
        mapping.get(k, k): v 
        for k, v in old_params.items()
    }

8.3 API versioning

# 多版本API支持
@app.call_tool()
async def api_v2_get_data(arguments: dict):
    """V2版本API"""
    return await get_data_v2(arguments)

@app.call_tool()
async def api_v1_get_data(arguments: dict):
    """V1版本API(兼容旧客户端)"""
    # V1参数 -> V2参数
    v2_args = {
        "id": arguments.get("item_id"),
        "include_meta": arguments.get("with_metadata", False)
    }
    
    # 调用V2实现
    result = await get_data_v2(v2_args)
    
    # V2结果 -> V1结果格式
    return {
        "item_id": result["id"],
        "item_data": result["data"]
    }

🧪 9. Testing Strategy

9.1 Unit tests

# test_server.py
import pytest
from unittest.mock import Mock, AsyncMock, patch
from my_mcp_server import app

@pytest.fixture
def mock_session():
    """模拟MCP会话"""
    session = AsyncMock()
    session.roots = []
    return session

@pytest.fixture
def mock_context(mock_session):
    """模拟请求上下文"""
    context = Mock()
    context.session = mock_session
    return context

@pytest.mark.asyncio
async def test_calculator_add(mock_context):
    """测试计算器加法"""
    # 设置上下文
    app.request_context = mock_context
    
    # 调用工具
    result = await app.call_tool("calculator", {
        "operation": "add",
        "a": 5,
        "b": 3
    })
    
    # 验证结果
    assert result["content"][0]["text"] == "8"
    assert not result.get("isError")

@pytest.mark.asyncio
async def test_calculator_division_by_zero(mock_context):
    """测试除零错误处理"""
    app.request_context = mock_context
    
    result = await app.call_tool("calculator", {
        "operation": "divide",
        "a": 10,
        "b": 0
    })
    
    # 验证错误处理
    assert result["isError"] is True
    assert "不能除以零" in result["content"][0]["text"]

@pytest.mark.asyncio
async def test_file_read_with_roots(mock_context, mock_session):
    """测试Roots限制"""
    # 设置允许的roots
    from mcp.types import Root
    mock_session.roots = [
        Root(uri="file:///allowed/path", name="allowed")
    ]
    app.request_context = mock_context
    
    # 测试允许的路径
    result = await app.call_tool("read_file", {
        "path": "/allowed/path/file.txt"
    })
    # 验证成功...
    
    # 测试禁止的路径
    result = await app.call_tool("read_file", {
        "path": "/etc/passwd"
    })
    assert result["isError"] is True
    assert "访问被拒绝" in result["content"][0]["text"]

9.2 Integration tests

# test_integration.py
import pytest
import pytest_asyncio
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

# Async fixtures must be declared with pytest_asyncio.fixture; function scope
# keeps the session on the same event loop as the test that uses it
@pytest_asyncio.fixture
async def mcp_client():
    """Start the Server and open a client connection."""
    server_params = StdioServerParameters(
        command="python",
        args=["-m", "my_mcp_server"],
        env=None
    )

    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            yield session

@pytest.mark.asyncio
async def test_tool_list(mcp_client):
    """The Server exposes the expected tools."""
    result = await mcp_client.list_tools()
    # list_tools() returns a result object; the tools are in its .tools attribute
    tool_names = [tool.name for tool in result.tools]

    assert "calculator" in tool_names
    assert "read_file" in tool_names

@pytest.mark.asyncio
async def test_end_to_end_workflow(mcp_client):
    """端到端工作流测试"""
    # 1. 列出工具
    tools = await mcp_client.list_tools()
    
    # 2. 调用计算器
    result = await mcp_client.call_tool("calculator", {
        "operation": "multiply",
        "a": 6,
        "b": 7
    })
    assert "42" in result.content[0].text
    
    # 3. 读取资源
    resources = await mcp_client.list_resources()
    
    # 4. 调用prompt
    prompts = await mcp_client.list_prompts()

# 使用pytest-asyncio运行
# pytest test_integration.py -v

9.3 Performance tests

# test_performance.py
import pytest
import time
import asyncio
import statistics

@pytest.mark.asyncio
async def test_tool_performance(mcp_client):
    """Measure tool response time."""
    latencies = []

    for _ in range(100):
        # perf_counter is monotonic and high-resolution, unlike time.time()
        start = time.perf_counter()
        await mcp_client.call_tool("calculator", {
            "operation": "add",
            "a": 1,
            "b": 2
        })
        latency = (time.perf_counter() - start) * 1000  # convert to milliseconds
        latencies.append(latency)

    # Summarize the results
    avg_latency = statistics.mean(latencies)
    p95_latency = sorted(latencies)[int(len(latencies) * 0.95)]
    
    print(f"平均延迟: {avg_latency:.2f}ms")
    print(f"P95延迟: {p95_latency:.2f}ms")
    
    # 断言性能要求
    assert avg_latency < 100  # 平均延迟应小于100ms
    assert p95_latency < 200  # P95延迟应小于200ms

@pytest.mark.asyncio
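async def test_concurrent_load(mcp_client):
    """Concurrent load test."""
    async def make_request(i):
        start = time.perf_counter()
        await mcp_client.call_tool("calculator", {
            "operation": "add",
            "a": i,
            "b": i
        })
        return time.perf_counter() - start
    
    # Issue 50 requests concurrently
    start = time.perf_counter()
    results = await asyncio.gather(*[make_request(i) for i in range(50)])
    total_time = time.perf_counter() - start
    
    print(f"Total time for 50 concurrent requests: {total_time:.2f}s")
    # Average the measured per-request latencies; total_time / 50 would
    # understate them because the requests overlap
    print(f"Average latency per request: {statistics.mean(results) * 1000:.2f}ms")
    
    assert total_time < 5  # total time should be under 5 seconds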
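
The P95 figures in these tests come from indexing into a sorted list. As a minimal sketch, the same summary can be factored into a reusable helper that uses the nearest-rank definition of a percentile. The names `percentile_nearest_rank` and `summarize_latencies` are hypothetical and are not part of MCP or pytest.

```python
import math
import statistics


def percentile_nearest_rank(samples, pct):
    """Return the pct-th percentile of samples using the nearest-rank method."""
    if not samples:
        raise ValueError("samples must not be empty")
    if not 0 < pct <= 100:
        raise ValueError("pct must be in the range (0, 100]")
    ordered = sorted(samples)
    # Nearest rank: the smallest value with at least pct% of samples at or below it.
    rank = math.ceil(pct * len(ordered) / 100)
    return ordered[rank - 1]


def summarize_latencies(latencies_ms):
    """Summarize a list of latencies (in milliseconds) into common statistics."""
    return {
        "count": len(latencies_ms),
        "mean": statistics.mean(latencies_ms),
        "median": statistics.median(latencies_ms),
        "p95": percentile_nearest_rank(latencies_ms, 95),
        "max": max(latencies_ms),
    }
```

A test could then assert on `summarize_latencies(latencies)["p95"]` instead of repeating the index arithmetic. Note that nearest-rank picks the 95th of 100 sorted samples, whereas the index expression used in the tests picks the 96th, so the two can differ by one sample.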

📊 10. Monitoring and Logging

10.1 Structured Logging

import json
import logging
import time
import uuid
from datetime import datetime, timezone
from contextvars import ContextVar
from typing import Any, Optional

# Per-request context
request_id_var: ContextVar[str] = ContextVar('request_id')

class StructuredLogFormatter(logging.Formatter):
    """Formats log records as structured JSON."""
    
    def format(self, record: logging.LogRecord) -> str:
        log_obj = {
            # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC timestamp
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "request_id": request_id_var.get(None),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno
        }
        
        # Merge in extra fields
        if hasattr(record, 'extra_data'):
            log_obj.update(record.extra_data)
        
        # Attach exception information
        if record.exc_info:
            log_obj["exception"] = self.formatException(record.exc_info)
        
        return json.dumps(log_obj, ensure_ascii=False)

# Configure logging
handler = logging.StreamHandler()
handler.setFormatter(StructuredLogFormatter())
logger = logging.getLogger("mcp.server")
logger.addHandler(handler)
logger.setLevel(logging.INFO)

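class LoggerAdapter:
    """Logger adapter that attaches contextual fields."""
    
    def __init__(self, base_logger: logging.Logger):
        self._logger = base_logger
    
    def _log(self, level: int, msg: str, extra: Optional[dict] = None):
        extra_data = {"extra_data": extra} if extra else {}
        # stacklevel=3 (Python 3.8+) attributes module/function/line to the code
        # that called info()/warning()/error(), not to this adapter
        self._logger.log(level, msg, extra=extra_data, stacklevel=3)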
    
    def info(self, msg: str, **kwargs):
        self._log(logging.INFO, msg, kwargs)
    
    def warning(self, msg: str, **kwargs):
        self._log(logging.WARNING, msg, kwargs)
    
    def error(self, msg: str, **kwargs):
        self._log(logging.ERROR, msg, kwargs)

log = LoggerAdapter(logger)

# Usage inside a tool
@app.call_tool()
async def process_data(arguments: dict, request_context):
    # Set the request ID
    request_id = str(uuid.uuid4())
    request_id_var.set(request_id)
    
    start_time = time.time()
    
    log.info(
        "Tool call started",
        tool="process_data",
        request_id=request_id,
        arguments_size=len(str(arguments))
    )
    
    try:
        # Run the business logic (do_processing stands in for your own code)
        result = await do_processing(arguments)
        
        duration = time.time() - start_time
        log.info(
            "Tool call succeeded",
            tool="process_data",
            duration_ms=duration * 1000,
            result_size=len(str(result))
        )
        
        return result
        
    except Exception as e:
        duration = time.time() - start_time
        log.error(
            "Tool call failed",
            tool="process_data",
            duration_ms=duration * 1000,
            error_type=type(e).__name__,
            error_message=str(e)
        )
        raise
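
The request ID above is stored in a `ContextVar` rather than a global because each asyncio task runs in its own copy of the context, so concurrent requests do not overwrite each other's IDs. The following sketch demonstrates that isolation with plain asyncio; `current_request_id`, `handle_request`, and `main` are hypothetical names used only for this illustration.

```python
import asyncio
import uuid
from contextvars import ContextVar

# Hypothetical variable for this sketch; it plays the role of request_id_var.
current_request_id: ContextVar[str] = ContextVar("current_request_id", default="-")


async def handle_request(name: str) -> tuple:
    """Simulate one request: set an ID, yield to other tasks, then read it back."""
    request_id = f"{name}-{uuid.uuid4().hex[:8]}"
    current_request_id.set(request_id)
    # Yield control so the other simulated requests run in between.
    await asyncio.sleep(0.01)
    # Each task has its own context copy, so the value set above is still visible here.
    return request_id, current_request_id.get()


async def main() -> list:
    # gather() wraps each coroutine in its own task, and therefore its own context.
    return await asyncio.gather(*(handle_request(f"req{i}") for i in range(3)))


if __name__ == "__main__":
    for set_id, seen_id in asyncio.run(main()):
        print(set_id, seen_id)
```

Every pair printed should contain the same ID twice, even though the three simulated requests interleave at the `sleep` call.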

10.2 Collecting Performance Metrics

import json
import statistics
import time
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class MetricsCollector:
    """Collects tool and resource metrics."""
    
    # Tool call counts
    tool_calls: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
    
    # Recorded response times (milliseconds)
    response_times: Dict[str, List[float]] = field(default_factory=lambda: defaultdict(list))
    
    # Error counts
    errors: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
    
    # Resource access counts
    resource_access: Dict[str, int] = field(default_factory=lambda: defaultdict(int))
    
    def record_tool_call(self, tool_name: str, duration_ms: float, success: bool = True):
        """Record a tool call."""
        self.tool_calls[tool_name] += 1
        self.response_times[tool_name].append(duration_ms)
        if not success:
            self.errors[tool_name] += 1
    
    def record_resource_access(self, uri: str):
        """Record a resource access."""
        self.resource_access[uri] += 1
    
    def get_tool_stats(self, tool_name: str) -> dict:
        """Return statistics for a single tool."""
        times = self.response_times.get(tool_name, [])
        if not times:
            return {}
        
        return {
            "total_calls": self.tool_calls[tool_name],
            "error_count": self.errors[tool_name],
            "error_rate": self.errors[tool_name] / self.tool_calls[tool_name],
            "avg_response_time": statistics.mean(times),
            "p95_response_time": sorted(times)[int(len(times) * 0.95)],
            "max_response_time": max(times)
        }
    
    def get_summary(self) -> dict:
        """Return a summary report."""
        return {
            "total_tool_calls": sum(self.tool_calls.values()),
            "total_errors": sum(self.errors.values()),
            "top_tools": sorted(
                self.tool_calls.items(), 
                key=lambda x: x[1], 
                reverse=True
            )[:5],
            "slowest_tools": sorted(
                [(name, statistics.mean(times)) for name, times in self.response_times.items()],
                key=lambda x: x[1],
                reverse=True
            )[:5]
        }

# Global metrics collector
metrics = MetricsCollector()

# Usage inside a tool
@app.call_tool()
async def monitored_tool(arguments: dict):
    start = time.time()
    success = True
    
    try:
        result = await actual_work(arguments)
        return result
    except Exception as e:
        success = False
        raise
    finally:
        duration = (time.time() - start) * 1000
        metrics.record_tool_call("monitored_tool", duration, success)

# Health metrics endpoint
@app.call_tool()
async def get_health_metrics():
    """Return health metrics (for monitoring)."""
    return {
        "content": [{
            "type": "text",
            "text": json.dumps(metrics.get_summary(), indent=2)
        }]
    }
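
Repeating the try/finally timing pattern in every tool is easy to get wrong. One option, sketched below, is a decorator that records the duration and failure count for any async function. To keep the sketch self-contained it writes into two plain dictionaries instead of the `MetricsCollector` above; `timed`, `call_durations_ms`, and `error_counts` are hypothetical names, not part of MCP.

```python
import asyncio
import functools
import time
from collections import defaultdict

# Minimal stand-ins for a metrics collector (hypothetical, for illustration only).
call_durations_ms = defaultdict(list)
error_counts = defaultdict(int)


def timed(name):
    """Decorator that records the duration and failures of an async function."""
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return await func(*args, **kwargs)
            except Exception:
                error_counts[name] += 1
                raise
            finally:
                # Runs on success and on failure, so every call is timed.
                elapsed_ms = (time.perf_counter() - start) * 1000
                call_durations_ms[name].append(elapsed_ms)
        return wrapper
    return decorator


@timed("echo")
async def echo(value):
    await asyncio.sleep(0.01)
    return value


@timed("broken")
async def broken():
    raise RuntimeError("simulated failure")
```

In a real server the two dictionaries would be replaced by a call such as `metrics.record_tool_call(name, elapsed_ms, success)`, assuming a collector shaped like the one in this section.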

10.3 Health Checks

import asyncio
import json
from datetime import datetime, timezone
from enum import Enum
from dataclasses import dataclass
from typing import Dict, List, Optional
import psutil

class HealthStatus(Enum):
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"

@dataclass
class HealthCheck:
    name: str
    status: HealthStatus
    message: str
    details: Optional[Dict] = None

class HealthChecker:
    """Runs registered health checks and aggregates their results."""
    
    def __init__(self):
        self.checks = []
    
    def add_check(self, name: str, check_func):
        """Register a check."""
        self.checks.append((name, check_func))
    
    async def run_checks(self) -> List[HealthCheck]:
        """Run all registered checks."""
        results = []
        
        for name, check_func in self.checks:
            try:
                result = await check_func()
                results.append(HealthCheck(name, result["status"], result["message"], result.get("details")))
            except Exception as e:
                # A check that raises is reported as unhealthy
                results.append(HealthCheck(name, HealthStatus.UNHEALTHY, str(e)))
        
        return results
    
    async def get_overall_status(self) -> Dict:
        """Return the overall health status."""
        checks = await self.run_checks()
        
        # Determine the overall status: the worst individual status wins
        if any(c.status == HealthStatus.UNHEALTHY for c in checks):
            overall = HealthStatus.UNHEALTHY
        elif any(c.status == HealthStatus.DEGRADED for c in checks):
            overall = HealthStatus.DEGRADED
        else:
            overall = HealthStatus.HEALTHY
        
        return {
            "status": overall.value,
            # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC timestamp
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "checks": [
                {
                    "name": c.name,
                    "status": c.status.value,
                    "message": c.message,
                    "details": c.details
                }
                for c in checks
            ]
        }

# Create the health checker
health_checker = HealthChecker()

# System resource check
async def check_system_resources():
    """Check system resources."""
    # cpu_percent(interval=1) blocks for one second, so run it in a worker thread
    # to avoid stalling the event loop (asyncio.to_thread requires Python 3.9+)
    cpu_percent = await asyncio.to_thread(psutil.cpu_percent, interval=1)
    memory = psutil.virtual_memory()
    
    if cpu_percent > 90 or memory.percent > 90:
        status = HealthStatus.UNHEALTHY
        message = f"Resources critically low: CPU {cpu_percent}%, memory {memory.percent}%"
    elif cpu_percent > 70 or memory.percent > 70:
        status = HealthStatus.DEGRADED
        message = f"Resource usage is high: CPU {cpu_percent}%, memory {memory.percent}%"
    else:
        status = HealthStatus.HEALTHY
        message = f"Resources normal: CPU {cpu_percent}%, memory {memory.percent}%"
    
    return {
        "status": status,
        "message": message,
        "details": {"cpu": cpu_percent, "memory": memory.percent}
    }

# Database connection check
async def check_database():
    """Check the database connection."""
    try:
        # Try a trivial query (db is assumed to be your own async database client)
        await db.execute("SELECT 1")
        return {
            "status": HealthStatus.HEALTHY,
            "message": "Database connection OK"
        }
    except Exception as e:
        return {
            "status": HealthStatus.UNHEALTHY,
            "message": f"Database connection failed: {e}"
        }

# Register the checks
health_checker.add_check("system_resources", check_system_resources)
health_checker.add_check("database", check_database)

@app.call_tool()
async def health_check():
    """Health check endpoint."""
    status = await health_checker.get_overall_status()
    
    return {
        "content": [{
            "type": "text",
            "text": json.dumps(status, indent=2)
        }],
        "isError": status["status"] == "unhealthy"
    }
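
A health endpoint is only useful if it answers promptly, yet a check that waits on an unresponsive dependency can hang indefinitely. As a rough sketch, each check can be wrapped in `asyncio.wait_for` and all checks run concurrently. For brevity this version uses plain status strings rather than the `HealthStatus` enum; `run_check_with_timeout`, `run_all`, `fast_check`, and `hung_check` are hypothetical names for illustration.

```python
import asyncio


async def run_check_with_timeout(name, check_func, timeout_s=2.0):
    """Run one async health check; report 'unhealthy' if it times out or raises."""
    try:
        result = await asyncio.wait_for(check_func(), timeout=timeout_s)
        return {"name": name, **result}
    except asyncio.TimeoutError:
        return {"name": name, "status": "unhealthy",
                "message": f"timed out after {timeout_s}s"}
    except Exception as exc:
        return {"name": name, "status": "unhealthy", "message": str(exc)}


async def run_all(checks, timeout_s=2.0):
    """Run all checks concurrently; checks is a list of (name, async function) pairs."""
    return await asyncio.gather(
        *(run_check_with_timeout(name, func, timeout_s) for name, func in checks)
    )


async def fast_check():
    return {"status": "healthy", "message": "ok"}


async def hung_check():
    await asyncio.sleep(60)  # Simulates a dependency that never answers in time.
    return {"status": "healthy", "message": "unreachable"}


if __name__ == "__main__":
    checks = [("fast", fast_check), ("hung", hung_check)]
    for entry in asyncio.run(run_all(checks, timeout_s=0.1)):
        print(entry)
```

Running the checks concurrently bounds the total response time by the slowest allowed check rather than by the sum of all checks.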

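📋 Chapter Summary

This chapter covered MCP's advanced features and best practices for production environments:

  • Progress notifications: keep users informed about the status of long-running tasks
  • Cancellation: let users cancel a task on their own initiative
  • Sampling: the Server can ask the Client's LLM to generate content
  • Roots: restrict the Server's file system access
  • Security practices: input validation, access control, sandboxed execution, protection of sensitive information
  • Performance optimization: connection reuse, batch operations, caching strategies
  • Error handling: meaningful error messages, error classification, retry mechanisms
  • Version management: Server version declaration, backward compatibility
  • Testing strategy: unit tests, integration tests, performance tests
  • Monitoring and logging: structured logs, metrics collection, health checks
🎉 Congratulations! You have now covered all the core MCP concepts. In the next chapter, we will apply them in three complete hands-on projects.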