🛠️ Hands-On Project 3: AI Data Analysis Assistant

Complete the entire data analysis workflow through natural-language conversation — from data upload to insight report — and let AI act as your data analyst.

Difficulty: Intermediate · Estimated time: 60 minutes · Tech stack: Python + Vue3 + LangChain

1. Project Overview

1.1 Pain Points in Data Analysis

Traditional data analysis workflows face several challenges:

  • SQL is hard to write: complex syntax must be memorized, and multi-table joins and aggregations are error-prone
  • Chart building is tedious: choosing chart types, tuning styles, and configuring parameters takes time and effort
  • Report writing is slow: manually compiling results and writing narrative text is inefficient
  • High technical barrier: non-technical users struggle to complete analysis tasks on their own

1.2 Solution

This project builds an AI data analysis assistant that implements the following pipeline:
graph LR
    A[Natural-language question] --> B[LLM interprets intent]
    B --> C[Generate SQL/Python]
    C --> D[Execute query]
    D --> E[Auto-visualize]
    E --> F[Write analysis report]
    F --> G[Return results]

1.3 Core Features

| Module | Description |
| --- | --- |
| 📤 Data upload | Supports CSV, Excel, JSON and other formats; infers column types automatically |
| 💬 Smart query | Ask in natural language; SQL is generated and executed automatically |
| 📊 Auto visualization | Recommends the best chart type based on data characteristics |
| 📑 Report generation | Produces a complete analysis report with charts and conclusions |
| 🤖 Agent mode | Supports multi-step reasoning for complex analysis tasks |

1.4 System Architecture

graph TB
    subgraph Frontend["Frontend (Vue3)"]
        UI1[Data upload]
        UI2[Chat interface]
        UI3[Chart display]
        UI4[Report editor]
    end
    subgraph Backend["Backend (FastAPI)"]
        API1[Dataset management]
        API2[Query endpoint]
        API3[Report endpoint]
    end
    subgraph Services["Core services"]
        S1[Data loader]
        S2[Schema analyzer]
        S3[Query generator]
        S4[Visualization engine]
        S5[Report generator]
    end
    subgraph AI["AI capabilities"]
        LLM[Large language model]
        Agent[LangChain Agent]
    end
    subgraph Storage["Storage"]
        DB[(SQLite/PostgreSQL)]
        File[File storage]
    end
    UI1 --> API1
    UI2 --> API2
    UI3 --> API3
    API1 --> S1
    API2 --> S3
    API3 --> S5
    S1 --> S2
    S3 --> S4
    S3 --> LLM
    S5 --> Agent
    S1 --> File
    S2 --> DB

2. Technology Choices

| Layer | Technology | Notes |
| --- | --- | --- |
| Frontend | Vue3 + TypeScript + Vite | Reactive UI, component-based development |
| Charting | ECharts / Plotly | Rich chart types and interactivity |
| Backend | Python FastAPI | High-performance async API framework |
| Data analysis | Pandas + NumPy | Foundational data processing and analysis libraries |
| Visualization | Matplotlib + Plotly | Static and interactive charts |
| LLM framework | LangChain | Application framework for large language models |
| Database | SQLite (dev) / PostgreSQL (prod) | Lightweight or production-grade storage |
| Deployment | Docker + Nginx | Containerized deployment with a reverse proxy |

3. Core Capability Design

3.1 Data Understanding

The system automatically profiles each dataset and builds a data dictionary:

  • Type inference: automatically detects numeric, categorical, datetime, and text columns
  • Statistical features: computes mean, median, standard deviation, quantiles, and other statistics
  • Semantic detection: recognizes special columns such as timestamps, IDs, and geographic coordinates
  • Quality assessment: detects missing values, outliers, and duplicate rows

3.2 Query Generation

Natural language is translated into executable code:

  • SQL generation: produces standard SQL queries against structured data
  • Pandas code: produces Python code for DataFrame operations
  • Hybrid mode: complex tasks combine SQL filtering with Python analysis (see the sketch below)
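
To make the hybrid mode concrete, here is a minimal sketch (not from the project source; the order_date/category/sales columns are borrowed from the sample dataset used later): SQL does the coarse filtering, pandas does the part SQL is bad at.

# Hybrid-mode sketch: SQL narrows the data, pandas finishes the analysis.
import sqlite3
import pandas as pd

def hybrid_query(df: pd.DataFrame, sql_filter: str) -> pd.Series:
    """Run a SQL filter against an in-memory SQLite table, then analyze in pandas."""
    conn = sqlite3.connect(":memory:")
    try:
        df.to_sql("dataset", conn, index=False)          # stage the DataFrame as a table
        filtered = pd.read_sql_query(sql_filter, conn)   # cheap SQL pre-filtering
    finally:
        conn.close()
    # Follow-up analysis that is awkward in SQL: a 3-point rolling mean per group
    filtered["order_date"] = pd.to_datetime(filtered["order_date"])
    return (filtered.sort_values("order_date")
                    .groupby("category")["sales"]
                    .apply(lambda s: s.rolling(3, min_periods=1).mean()))

# Example: SQL keeps one region, pandas computes the rolling mean
# result = hybrid_query(df, "SELECT * FROM dataset WHERE region = '华东'")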

3.3 Visualization Recommendation

The chart type is chosen automatically from the characteristics of the data:

| Data type | Recommended charts | Typical scenarios |
| --- | --- | --- |
| Time series | Line chart, area chart | Trend analysis, seasonality detection |
| Category comparison | Column chart, bar chart | Rankings, comparisons |
| Share of total | Pie chart, donut chart, treemap | Composition analysis, market share |
| Numeric distribution | Histogram, box plot, violin plot | Distribution shape, outlier detection |
| Bivariate relationship | Scatter plot, bubble chart | Correlation analysis, clustering |
| Geographic data | Map, heatmap | Regional analysis, spatial distribution |

3.4 Insight Extraction

Key information is discovered automatically:

  • Trend detection: identifies growth/decline trends and computes rates of change (see the sketch below)
  • Anomaly detection: flags outliers and anomalous points using statistical methods
  • Correlation analysis: computes correlation coefficients to surface relationships between variables
  • Smart summaries: uses the LLM to write natural-language conclusions
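
As an example of the first bullet, a minimal trend-detection sketch (illustrative, not the project's actual implementation) fits a linear slope and computes an overall change rate:

# Trend-detection sketch: linear slope + total change rate for an ordered series.
import numpy as np
import pandas as pd

def detect_trend(series: pd.Series) -> dict:
    """Classify the trend of an ordered numeric series."""
    y = series.dropna().to_numpy(dtype=float)
    if len(y) < 2:
        return {"trend": "insufficient data"}
    x = np.arange(len(y))
    slope = np.polyfit(x, y, 1)[0]  # least-squares linear slope
    change_rate = (y[-1] - y[0]) / abs(y[0]) if y[0] != 0 else float("nan")
    return {
        "trend": "up" if slope > 0 else "down" if slope < 0 else "flat",
        "slope": round(float(slope), 4),
        "change_rate": f"{change_rate * 100:.1f}%",
    }

# Example: detect_trend(monthly_sales["sales"]) -> {"trend": "up", ...}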

4. Backend Development

4.1 Project Structure

data-analyst/
├── app/
│   ├── __init__.py
│   ├── main.py                 # FastAPI entry point
│   ├── config.py               # Configuration management
│   ├── database.py             # Database connection
│   ├── models/
│   │   ├── __init__.py
│   │   ├── dataset.py          # Dataset model
│   │   ├── query.py            # Query history
│   │   └── report.py           # Report model
│   ├── services/
│   │   ├── __init__.py
│   │   ├── data_loader.py      # Data loading (CSV/Excel/JSON)
│   │   ├── schema_analyzer.py  # Schema analysis
│   │   ├── query_generator.py  # SQL/Python generation
│   │   ├── executor.py         # Query execution
│   │   ├── visualizer.py       # Chart generation
│   │   └── reporter.py         # Report generation
│   ├── agents/
│   │   └── data_agent.py       # LangChain Agent
│   └── routers/
│       ├── dataset.py          # Dataset API
│       ├── query.py            # Query API
│       └── report.py           # Report API
├── uploads/                    # Uploaded files
├── static/charts/              # Generated charts
├── notebooks/                  # Jupyter debugging notebooks
├── tests/                      # Unit tests
├── sample_data/                # Sample data
│   └── sample_sales.csv
├── requirements.txt
└── Dockerfile

4.2 Core Source Code Walkthrough

4.2.1 data_loader.py - Data Loading Service

# app/services/data_loader.py
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, Any, Tuple
import logging

logger = logging.getLogger(__name__)

class DataLoader:
    """Data loader supporting multiple formats, with automatic cleaning and type inference"""
    
    def __init__(self, upload_dir: str = "uploads"):
        self.upload_dir = Path(upload_dir)
        self.upload_dir.mkdir(exist_ok=True)
    
    def load(self, file_path: str) -> Tuple[pd.DataFrame, Dict[str, Any]]:
        """Pick the loading strategy from the file extension"""
        path = Path(file_path)
        suffix = path.suffix.lower()
        
        try:
            if suffix == '.csv':
                df = self._load_csv(path)
            elif suffix in ['.xlsx', '.xls']:
                df = self._load_excel(path)
            elif suffix == '.json':
                df = self._load_json(path)
            elif suffix == '.parquet':
                df = pd.read_parquet(path)
            else:
                raise ValueError(f"Unsupported file format: {suffix}")
            
            # Clean and summarize
            df = self._clean_data(df)
            summary = self._generate_summary(df)
            
            logger.info(f"Loaded {path.name}: {len(df)} rows")
            return df, summary
            
        except Exception as e:
            logger.error(f"Failed to load file: {e}")
            raise
    
    def _load_csv(self, path: Path) -> pd.DataFrame:
        """Load CSV intelligently, probing encodings and the separator"""
        encodings = ['utf-8', 'gbk', 'gb2312', 'latin1']
        
        for encoding in encodings:
            try:
                sample = pd.read_csv(path, nrows=5, encoding=encoding)
                if len(sample.columns) == 1:
                    # A single column usually means a non-comma separator; let pandas sniff it
                    df = pd.read_csv(path, encoding=encoding, sep=None, engine='python')
                else:
                    df = pd.read_csv(path, encoding=encoding)
                return df
            except UnicodeDecodeError:
                continue
        raise ValueError("Could not detect the file encoding")
    
    def _load_excel(self, path: Path) -> pd.DataFrame:
        """Load an Excel file (first sheet if there are several)"""
        xl = pd.ExcelFile(path)
        return pd.read_excel(path, sheet_name=xl.sheet_names[0])
    
    def _load_json(self, path: Path) -> pd.DataFrame:
        """Load a JSON file into a DataFrame"""
        return pd.read_json(path)
    
    def _clean_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Basic cleaning: drop empty rows/columns, strip header whitespace"""
        df = df.dropna(how='all').dropna(axis=1, how='all')
        df.columns = df.columns.str.strip()
        df = self._infer_types(df)
        return df
    
    def _infer_types(self, df: pd.DataFrame) -> pd.DataFrame:
        """Heuristic type inference"""
        for col in df.columns:
            if df[col].dtype in ['int64', 'float64', 'bool', 'datetime64[ns]']:
                continue
            
            # Try numeric conversion: accept if more than 80% of values parse
            try:
                converted = pd.to_numeric(df[col], errors='coerce')
                if converted.notna().sum() / len(df) > 0.8:
                    df[col] = converted
                    continue
            except (ValueError, TypeError):
                pass
            
            # Try datetime conversion with the same 80% threshold
            try:
                converted = pd.to_datetime(df[col], errors='coerce')
                if converted.notna().sum() / len(df) > 0.8:
                    df[col] = converted
                    continue
            except (ValueError, TypeError):
                pass
            
            # Low-cardinality object columns become categoricals
            if df[col].dtype == 'object':
                unique_ratio = df[col].nunique() / len(df)
                if unique_ratio < 0.05 and df[col].nunique() < 50:
                    df[col] = df[col].astype('category')
        return df
    
    def _generate_summary(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Build a dataset summary"""
        return {
            'row_count': len(df),
            'column_count': len(df.columns),
            'columns': list(df.columns),
            'dtypes': {col: str(dtype) for col, dtype in df.dtypes.items()},
            'memory_usage': df.memory_usage(deep=True).sum(),
            'missing_values': df.isnull().sum().to_dict(),
            'missing_percentage': (df.isnull().sum() / len(df) * 100).round(2).to_dict()
        }
    
    def get_schema(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Return detailed schema information"""
        schema = {
            'columns': [],
            'total_rows': len(df),
            'total_columns': len(df.columns)
        }
        
        for col in df.columns:
            col_info = {
                'name': col,
                'dtype': str(df[col].dtype),
                'missing_count': int(df[col].isnull().sum()),
                'missing_pct': round(df[col].isnull().sum() / len(df) * 100, 2),
                'unique_count': int(df[col].nunique())
            }
            
            if pd.api.types.is_numeric_dtype(df[col]):
                col_info.update({
                    'type': 'numeric',
                    'min': float(df[col].min()) if not pd.isna(df[col].min()) else None,
                    'max': float(df[col].max()) if not pd.isna(df[col].max()) else None,
                    'mean': float(df[col].mean()) if not pd.isna(df[col].mean()) else None,
                    'std': float(df[col].std()) if not pd.isna(df[col].std()) else None,
                    'median': float(df[col].median()) if not pd.isna(df[col].median()) else None
                })
            elif pd.api.types.is_datetime64_any_dtype(df[col]):
                col_info.update({
                    'type': 'datetime',
                    'min': str(df[col].min()),
                    'max': str(df[col].max())
                })
            elif df[col].dtype.name == 'category':
                col_info.update({
                    'type': 'categorical',
                    'categories': list(df[col].cat.categories)[:20]
                })
            else:
                col_info.update({
                    'type': 'text',
                    'sample_values': df[col].dropna().head(5).tolist()
                })
            
            schema['columns'].append(col_info)
        
        return schema

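main.py below also calls a SchemaAnalyzer service that the project tree lists but this walkthrough does not reproduce. A minimal sketch consistent with section 3.1 (an assumption, not the original file) could look like:

# app/services/schema_analyzer.py — minimal sketch, not the original source.
import pandas as pd
from typing import Any, Dict

class SchemaAnalyzer:
    """Assign a semantic role to each column, per section 3.1"""

    def analyze(self, df: pd.DataFrame) -> Dict[str, Any]:
        columns = []
        for col in df.columns:
            role = 'value'
            if pd.api.types.is_datetime64_any_dtype(df[col]):
                role = 'time'                # timestamp column
            elif df[col].nunique() == len(df):
                role = 'id'                  # unique per row -> likely an identifier
            elif df[col].dtype.name in ('category', 'object'):
                role = 'dimension'           # candidate grouping dimension
            columns.append({'name': col, 'dtype': str(df[col].dtype), 'role': role})
        return {'columns': columns, 'row_count': len(df)}
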
4.2.2 query_generator.py - Natural Language to SQL

# app/services/query_generator.py
import re
import json
from typing import Dict, List, Any, Optional
from dataclasses import dataclass

@dataclass
class GeneratedQuery:
    """A generated query plus metadata"""
    query_type: str  # 'sql' or 'python'
    code: str
    explanation: str
    confidence: float
    suggested_viz: Optional[str] = None

class QueryGenerator:
    """Natural-language query generator"""
    
    # Chinese keywords in the question mapped to SQL aggregate functions
    AGGREGATIONS = {
        '总和': 'SUM', '总计': 'SUM', '平均': 'AVG', '平均值': 'AVG',
        '最大': 'MAX', '最高': 'MAX', '最小': 'MIN', '最低': 'MIN',
        '数量': 'COUNT', '计数': 'COUNT'
    }
    
    # Chinese keywords mapped to a suggested chart type
    VIZ_RECOMMENDATIONS = {
        '趋势': 'line', '变化': 'line', '增长': 'line',
        '对比': 'bar', '比较': 'bar', '排名': 'bar',
        '占比': 'pie', '构成': 'pie', '分布': 'histogram',
        '相关': 'scatter', '关系': 'scatter'
    }
    
    def __init__(self, llm_client=None):
        self.llm = llm_client
    
    def generate(self, question: str, schema: Dict[str, Any], 
                 dialect: str = 'sqlite') -> GeneratedQuery:
        """Generate a query from a natural-language question"""
        # Try rule-based matching first
        rule_result = self._try_rule_based(question, schema, dialect)
        if rule_result and rule_result.confidence > 0.8:
            return rule_result
        
        # Low-confidence rules: let the LLM take over if one is configured
        if self.llm:
            return self._generate_with_llm(question, schema, dialect)
        
        # No LLM available: a low-confidence rule result still beats the fallback
        if rule_result:
            return rule_result
        return self._generate_fallback(question, schema)
    
    def _try_rule_based(self, question: str, schema: Dict, dialect: str) -> Optional[GeneratedQuery]:
        """Rule-based query generation"""
        columns = schema.get('columns', [])
        selected_cols = self._extract_columns(question, columns)
        aggregation = self._extract_aggregation(question)
        group_by = self._extract_group_by(question, columns)
        
        # Build the SQL string
        if not selected_cols:
            selected_cols = ['*']
        
        select_parts = []
        for col in selected_cols:
            if aggregation and col != '*':
                select_parts.append(f"{aggregation}({col}) as {col}_{aggregation.lower()}")
            else:
                select_parts.append(col)
        
        sql_parts = ["SELECT", ", ".join(select_parts), "FROM dataset"]
        
        if group_by:
            sql_parts.append(f"GROUP BY {group_by}")
        
        sql = " ".join(sql_parts)
        viz = self._recommend_visualization(question)
        
        return GeneratedQuery(
            query_type='sql',
            code=sql,
            explanation=f"Rule-based query for: {question}",
            confidence=0.7 if aggregation else 0.6,
            suggested_viz=viz
        )
    
    def _extract_columns(self, question: str, columns: List[Dict]) -> List[str]:
        """Find the columns mentioned in the question"""
        selected = []
        col_names = [c['name'] for c in columns]
        
        for col in col_names:
            if col.lower() in question.lower():
                selected.append(col)
        
        return selected
    
    def _extract_aggregation(self, question: str) -> Optional[str]:
        """Detect an aggregation keyword"""
        for keyword, func in self.AGGREGATIONS.items():
            if keyword in question:
                return func
        return None
    
    def _extract_group_by(self, question: str, columns: List[Dict]) -> Optional[str]:
        """Detect the group-by column (patterns match Chinese phrasing like '按…分组')"""
        patterns = [
            r'按(.*?)(分组|统计|汇总|计算)',
            r'每个(.*?)的',
            r'各(.*?)的'
        ]
        
        for pattern in patterns:
            match = re.search(pattern, question)
            if match:
                group_keyword = match.group(1).strip()
                for c in columns:
                    if c['name'] == group_keyword or group_keyword in c['name']:
                        return c['name']
        return None
    
    def _recommend_visualization(self, question: str) -> str:
        """Suggest a chart type from question keywords"""
        for keyword, chart_type in self.VIZ_RECOMMENDATIONS.items():
            if keyword in question:
                return chart_type
        return 'table'
    
    def _generate_with_llm(self, question: str, schema: Dict, dialect: str) -> GeneratedQuery:
        """Generate the query with the LLM"""
        prompt = f"""You are a data analysis expert. Given the table schema and the user's question below, generate a {dialect} SQL query.

Table schema:
{json.dumps(schema, indent=2, ensure_ascii=False)}

User question: {question}

Respond in JSON:
{{
    "sql": "the generated SQL statement",
    "explanation": "what the query does",
    "visualization": "recommended chart type"
}}"""
        
        try:
            response = self.llm.generate(prompt)
            result = json.loads(response)
            
            return GeneratedQuery(
                query_type='sql',
                code=result['sql'],
                explanation=result['explanation'],
                confidence=0.9,
                suggested_viz=result.get('visualization', 'table')
            )
        except Exception:
            # Malformed LLM output (bad JSON, missing keys) falls back gracefully
            return self._generate_fallback(question, schema)
    
    def _generate_fallback(self, question: str, schema: Dict) -> GeneratedQuery:
        """Last-resort fallback query"""
        return GeneratedQuery(
            query_type='sql',
            code="SELECT * FROM dataset LIMIT 100",
            explanation="Could not interpret the question; returning the first 100 rows",
            confidence=0.3,
            suggested_viz='table'
        )
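
A quick usage sketch (the toy schema is illustrative):

# Illustrative usage of QueryGenerator without an LLM attached.
schema = {'columns': [{'name': 'category'}, {'name': 'sales'}]}
gen = QueryGenerator()  # no llm_client: rule-based matching with a SQL fallback

result = gen.generate("sales的总和是多少", schema)
print(result.code)        # SELECT SUM(sales) as sales_sum FROM dataset
print(result.confidence)  # 0.7 -- the rule result is returned since no LLM is configured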

4.2.3 visualizer.py - Automatic Visualization

# app/services/visualizer.py
import matplotlib
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from typing import Dict, List, Any
from pathlib import Path
import json

class AutoVisualizer:
    """Automatic visualization engine"""
    
    # Chart type keys mapped to their Chinese display names
    CHART_TYPES = {
        'line': '折线图', 'bar': '柱状图', 'horizontal_bar': '条形图',
        'pie': '饼图', 'donut': '环形图', 'scatter': '散点图',
        'histogram': '直方图', 'box': '箱线图', 'heatmap': '热力图',
        'area': '面积图', 'table': '数据表'
    }
    
    def __init__(self, output_dir: str = "static/charts"):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
    
    def recommend_chart(self, df: pd.DataFrame, columns: List[str] = None) -> Dict[str, Any]:
        """Recommend the best chart type from the data's characteristics"""
        if columns is None:
            columns = list(df.columns)
        
        if len(columns) == 0:
            return {'type': 'table', 'reason': 'no columns specified'}
        
        col_types = self._analyze_column_types(df, columns)
        
        # Univariate analysis
        if len(columns) == 1:
            col = columns[0]
            dtype = col_types[col]
            
            if dtype == 'numeric':
                if df[col].nunique() < 20:
                    return {'type': 'bar', 'x': col, 'reason': 'few distinct numeric values suit a bar chart'}
                return {'type': 'histogram', 'x': col, 'reason': 'continuous values suit a histogram'}
            elif dtype == 'datetime':
                return {'type': 'line', 'x': col, 'reason': 'time series suit a line chart'}
            elif dtype == 'categorical':
                if df[col].nunique() <= 10:
                    return {'type': 'pie', 'names': col, 'reason': 'few categories suit a pie chart'}
                return {'type': 'bar', 'x': col, 'reason': 'categorical data suit a bar chart'}
        
        # Bivariate analysis
        if len(columns) == 2:
            col1, col2 = columns
            type1, type2 = col_types[col1], col_types[col2]
            
            if (type1 == 'datetime' and type2 == 'numeric') or (type2 == 'datetime' and type1 == 'numeric'):
                time_col = col1 if type1 == 'datetime' else col2
                val_col = col2 if type1 == 'datetime' else col1
                return {'type': 'line', 'x': time_col, 'y': val_col, 'reason': 'time + numeric suit a line chart'}
            
            if (type1 == 'categorical' and type2 == 'numeric') or (type2 == 'categorical' and type1 == 'numeric'):
                cat_col = col1 if type1 == 'categorical' else col2
                val_col = col2 if type1 == 'categorical' else col1
                return {'type': 'bar', 'x': cat_col, 'y': val_col, 'reason': 'category + numeric suit a bar chart'}
            
            if type1 == 'numeric' and type2 == 'numeric':
                return {'type': 'scatter', 'x': col1, 'y': col2, 'reason': 'two numeric columns suit a scatter plot'}
        
        return {'type': 'table', 'columns': columns, 'reason': 'defaulting to a table'}
    
    def _analyze_column_types(self, df: pd.DataFrame, columns: List[str]) -> Dict[str, str]:
        """Classify column types"""
        types = {}
        for col in columns:
            if col not in df.columns:
                types[col] = 'unknown'
            elif pd.api.types.is_datetime64_any_dtype(df[col]):
                types[col] = 'datetime'
            elif pd.api.types.is_numeric_dtype(df[col]):
                types[col] = 'numeric'
            elif df[col].dtype.name == 'category':
                types[col] = 'categorical'
            else:
                unique_ratio = df[col].nunique() / len(df)
                if unique_ratio < 0.1 or df[col].nunique() < 20:
                    types[col] = 'categorical'
                else:
                    types[col] = 'text'
        return types
    
    def generate(self, df: pd.DataFrame, chart_config: Dict[str, Any], 
                 title: str = None) -> Dict[str, Any]:
        """Render a chart from a config produced by recommend_chart (or built by hand)"""
        chart_type = chart_config.get('type', 'table')
        
        try:
            if chart_type == 'table':
                return self._generate_table(df, chart_config, title)
            elif chart_type == 'line':
                return self._generate_line(df, chart_config, title)
            elif chart_type == 'bar':
                return self._generate_bar(df, chart_config, title)
            elif chart_type == 'pie':
                return self._generate_pie(df, chart_config, title)
            elif chart_type == 'scatter':
                return self._generate_scatter(df, chart_config, title)
            elif chart_type == 'histogram':
                return self._generate_histogram(df, chart_config, title)
            else:
                return self._generate_table(df, chart_config, title)
        except Exception as e:
            return {'error': str(e), 'type': 'error'}
    
    def _generate_line(self, df: pd.DataFrame, config: Dict, title: str) -> Dict:
        """Line chart"""
        x = config.get('x', df.columns[0])
        y = config.get('y', df.columns[1] if len(df.columns) > 1 else df.columns[0])
        
        fig = px.line(df, x=x, y=y, title=title or f"{y} trend", width=800, height=500)
        fig.update_traces(mode='lines+markers')
        return self._save_plotly_chart(fig, 'line')
    
    def _generate_bar(self, df: pd.DataFrame, config: Dict, title: str) -> Dict:
        """Bar chart"""
        x = config.get('x', df.columns[0])
        y = config.get('y', df.columns[1] if len(df.columns) > 1 else None)
        
        if y is None:
            # No value column: fall back to counting rows per category
            df = df.groupby(x).size().reset_index(name='count')
            y = 'count'
        
        fig = px.bar(df, x=x, y=y, title=title or f"{y} by {x}", width=800, height=500)
        return self._save_plotly_chart(fig, 'bar')
    
    def _generate_pie(self, df: pd.DataFrame, config: Dict, title: str) -> Dict:
        """Pie chart"""
        names = config.get('names', df.columns[0])
        values = config.get('values', df.columns[1] if len(df.columns) > 1 else None)
        
        if values is None:
            df = df.groupby(names).size().reset_index(name='count')
            values = 'count'
        
        fig = px.pie(df, names=names, values=values, title=title or f"{names} breakdown", width=800, height=500)
        fig.update_traces(textposition='inside', textinfo='percent+label')
        return self._save_plotly_chart(fig, 'pie')
    
    def _generate_scatter(self, df: pd.DataFrame, config: Dict, title: str) -> Dict:
        """Scatter plot"""
        x = config.get('x', df.columns[0])
        y = config.get('y', df.columns[1] if len(df.columns) > 1 else df.columns[0])
        
        fig = px.scatter(df, x=x, y=y, title=title or f"{x} vs {y}", width=800, height=500, opacity=0.7)
        return self._save_plotly_chart(fig, 'scatter')
    
    def _generate_histogram(self, df: pd.DataFrame, config: Dict, title: str) -> Dict:
        """Histogram with a marginal box plot"""
        x = config.get('x', df.columns[0])
        fig = px.histogram(df, x=x, title=title or f"{x} distribution", width=800, height=500, marginal='box')
        return self._save_plotly_chart(fig, 'histogram')
    
    def _generate_table(self, df: pd.DataFrame, config: Dict, title: str) -> Dict:
        """Data table"""
        columns = config.get('columns', list(df.columns))
        max_rows = config.get('max_rows', 100)
        display_df = df[columns].head(max_rows)
        
        # Alternate row colors: one color per row, repeated for every column
        row_colors = (['#ecf0f1', '#bdc3c7'] * (len(display_df) // 2 + 1))[:len(display_df)]
        
        fig = go.Figure(data=[go.Table(
            header=dict(
                values=list(display_df.columns),
                fill_color='#2c3e50',
                align='left',
                font=dict(color='white', size=12)
            ),
            cells=dict(
                values=[display_df[col] for col in display_df.columns],
                fill_color=[row_colors] * len(display_df.columns),
                align='left',
                font=dict(size=11)
            )
        )])
        
        fig.update_layout(title=title or 'Data preview', width=800, height=min(500, 100 + len(display_df) * 30))
        return self._save_plotly_chart(fig, 'table')
    
    def _save_plotly_chart(self, fig, chart_type: str) -> Dict[str, Any]:
        """Write the chart to a standalone HTML file and return its metadata"""
        import uuid
        chart_id = str(uuid.uuid4())[:8]
        filename = f"{chart_type}_{chart_id}.html"
        filepath = self.output_dir / filename
        
        fig.write_html(str(filepath), include_plotlyjs='cdn')
        
        return {
            'type': chart_type,
            'html_path': f"/static/charts/{filename}",
            'title': fig.layout.title.text if fig.layout.title else None,
            'width': 800,
            'height': 500
        }
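
Putting recommend_chart and generate together (an illustrative snippet, not from the original):

# Illustrative end-to-end use of AutoVisualizer.
import pandas as pd

viz = AutoVisualizer(output_dir="static/charts")
df = pd.DataFrame({"category": ["A", "B", "A", "C"], "sales": [10, 20, 15, 5]})

config = viz.recommend_chart(df, ["category", "sales"])  # -> bar: category + numeric
chart = viz.generate(df, config, title="Sales by category")
print(chart["html_path"])  # e.g. /static/charts/bar_1a2b3c4d.html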

4.2.4 main.py - FastAPI Application

# app/main.py
"""AI数据分析助手 - FastAPI主程序"""
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from typing import List, Dict, Any
from datetime import datetime
import shutil
import pandas as pd
from pathlib import Path

from app.services.data_loader import DataLoader
from app.services.schema_analyzer import SchemaAnalyzer
from app.services.query_generator import QueryGenerator
from app.services.visualizer import AutoVisualizer

app = FastAPI(
    title="AI Data Analysis Assistant API",
    description="Data analysis platform with natural-language queries, automatic visualization, and report generation",
    version="1.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.mount("/static", StaticFiles(directory="static"), name="static")

# Service instances
data_loader = DataLoader(upload_dir="uploads")
schema_analyzer = SchemaAnalyzer()
query_generator = QueryGenerator()
visualizer = AutoVisualizer(output_dir="static/charts")

# In-memory dataset registry (swap for a database in production)
datasets: Dict[str, Dict] = {}

@app.post("/api/datasets/upload")
async def upload_dataset(file: UploadFile = File(...)):
    """上传数据集文件(支持CSV、Excel、JSON、Parquet)"""
    try:
        file_ext = Path(file.filename).suffix.lower()
        allowed_exts = ['.csv', '.xlsx', '.xls', '.json', '.parquet']
        
        if file_ext not in allowed_exts:
            raise HTTPException(400, f"不支持的文件格式: {file_ext}")
        
        dataset_id = f"ds_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        file_path = Path("uploads") / f"{dataset_id}{file_ext}"
        
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        
        # 加载并分析数据
        df, summary = data_loader.load(str(file_path))
        schema = data_loader.get_schema(df)
        analysis = schema_analyzer.analyze(df)
        
        datasets[dataset_id] = {
            'id': dataset_id,
            'filename': file.filename,
            'row_count': len(df),
            'column_count': len(df.columns),
            'columns': schema['columns'],
            'schema_analysis': analysis,
            'created_at': datetime.now().isoformat()
        }
        
        # 保存DataFrame供后续使用
        df.to_pickle(f"uploads/{dataset_id}.pkl")
        
        return {
            'success': True,
            'dataset_id': dataset_id,
            'name': file.filename,
            'row_count': len(df),
            'column_count': len(df.columns),
            'columns': [c['name'] for c in schema['columns']]
        }
        
    except Exception as e:
        raise HTTPException(500, str(e))

@app.get("/api/datasets")
async def list_datasets():
    """获取所有数据集列表"""
    return {
        'datasets': [
            {
                'id': ds['id'],
                'name': ds['filename'],
                'row_count': ds['row_count'],
                'column_count': ds['column_count'],
                'created_at': ds['created_at']
            }
            for ds in datasets.values()
        ]
    }

@app.get("/api/datasets/{dataset_id}")
async def get_dataset(dataset_id: str):
    """获取数据集详细信息"""
    if dataset_id not in datasets:
        raise HTTPException(404, "数据集不存在")
    return datasets[dataset_id]

@app.get("/api/datasets/{dataset_id}/preview")
async def preview_dataset(dataset_id: str, rows: int = 10):
    """预览数据集前N行"""
    if dataset_id not in datasets:
        raise HTTPException(404, "数据集不存在")
    
    try:
        df = pd.read_pickle(f"uploads/{dataset_id}.pkl")
        preview = df.head(rows)
        
        return {
            'columns': list(df.columns),
            'data': preview.to_dict('records'),
            'total_rows': len(df)
        }
    except Exception as e:
        raise HTTPException(500, str(e))

@app.post("/api/query")
async def execute_query(query_request: Dict[str, Any]):
    """执行自然语言查询"""
    dataset_id = query_request.get('dataset_id')
    question = query_request.get('question')
    generate_chart = query_request.get('generate_chart', True)
    
    if not dataset_id or not question:
        raise HTTPException(400, "缺少必要参数: dataset_id 或 question")
    
    if dataset_id not in datasets:
        raise HTTPException(404, "数据集不存在")
    
    try:
        df = pd.read_pickle(f"uploads/{dataset_id}.pkl")
        schema = datasets[dataset_id]['schema_analysis']
        
        # 生成查询
        generated = query_generator.generate(question, schema)
        
        # 执行查询(简化实现)
        result_df = df.head(1000)
        
        # 生成图表
        chart_info = None
        if generate_chart and not result_df.empty:
            chart_config = visualizer.recommend_chart(result_df)
            chart_config.update({'type': generated.suggested_viz or chart_config['type']})
            chart_info = visualizer.generate(result_df, chart_config, title=question[:30])
        
        return {
            'success': True,
            'question': question,
            'explanation': generated.explanation,
            'generated_code': generated.code,
            'result': {
                'columns': list(result_df.columns),
                'data': result_df.head(100).to_dict('records'),
                'total_rows': len(result_df)
            },
            'chart': chart_info,
            'suggested_followups': [
                '数据的整体统计信息',
                '各字段的分布情况',
                '相关性分析'
            ]
        }
        
    except Exception as e:
        raise HTTPException(500, str(e))

@app.post("/api/visualize")
async def create_visualization(request: Dict[str, Any]):
    """创建可视化图表"""
    dataset_id = request.get('dataset_id')
    chart_type = request.get('chart_type', 'auto')
    config = request.get('config', {})
    
    if dataset_id not in datasets:
        raise HTTPException(404, "数据集不存在")
    
    try:
        df = pd.read_pickle(f"uploads/{dataset_id}.pkl")
        
        if chart_type == 'auto':
            chart_config = visualizer.recommend_chart(df)
        else:
            chart_config = {'type': chart_type, **config}
        
        chart_info = visualizer.generate(df, chart_config)
        
        return {'success': True, 'chart': chart_info}
        
    except Exception as e:
        raise HTTPException(500, str(e))

@app.get("/api/health")
async def health_check():
    """健康检查端点"""
    return {
        'status': 'healthy',
        'version': '1.0.0',
        'datasets_count': len(datasets)
    }
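
The query endpoint above stubs out execution with df.head(1000). The project tree lists an executor.py whose source is not reproduced here; a minimal SQLite-backed sketch, assuming the generated SQL targets a table named dataset, could be:

# app/services/executor.py — minimal sketch, not the original source.
import sqlite3
import pandas as pd

class QueryExecutor:
    """Run generated SQL against an in-memory SQLite copy of the DataFrame."""

    def execute_sql(self, df: pd.DataFrame, sql: str, max_rows: int = 10000) -> pd.DataFrame:
        conn = sqlite3.connect(":memory:")
        try:
            df.to_sql("dataset", conn, index=False)  # generated queries reference `dataset`
            result = pd.read_sql_query(sql, conn)
            return result.head(max_rows)             # cap the result size defensively
        finally:
            conn.close()

# In execute_query above, the stub would then become something like:
#   result_df = QueryExecutor().execute_sql(df, generated.code)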

4.2.5 requirements.txt

# Core framework
fastapi==0.104.1
uvicorn[standard]==0.24.0
python-multipart==0.0.6

# Data processing
pandas==2.1.3
numpy==1.26.2
openpyxl==3.1.2
pyarrow==14.0.1

# Visualization
matplotlib==3.8.2
plotly==5.18.0

# LLM / Agent stack used by app/agents and query_generator
# (versions indicative; pin to what you test against)
langchain>=0.1,<0.2
langchain-experimental>=0.0.54
openai>=1.0

# Optional cache backend (section 8.4)
redis>=5.0

# Report generation
jinja2==3.1.2

# Utilities
python-dotenv==1.0.0
pydantic==2.5.2
pydantic-settings==2.1.0

# Testing
pytest==7.4.3
pytest-asyncio==0.21.1

5. Frontend Development

5.1 Project Structure

frontend/
├── public/
│   └── sample_data/
│       └── sample_sales.csv
├── src/
│   ├── components/
│   │   ├── DataUploader.vue      # Data upload component
│   │   ├── SchemaViewer.vue      # Schema display
│   │   ├── QueryBuilder.vue      # Visual query builder
│   │   ├── ChatInterface.vue     # Conversational query UI
│   │   ├── ChartViewer.vue       # Chart display
│   │   └── DataTable.vue         # Data table
│   ├── views/
│   │   ├── Home.vue              # Home page
│   │   ├── Dataset.vue           # Dataset management
│   │   └── Analysis.vue          # Analysis page
│   ├── api/
│   │   └── index.js              # API client wrapper
│   ├── App.vue
│   └── main.js
├── index.html
├── package.json
└── vite.config.js

5.2 Core Component Source

5.2.1 ChatInterface.vue - Conversational Analysis UI

<template>
  <div class="chat-interface">
    <div class="chat-messages" ref="messagesContainer">
      <div v-for="(msg, index) in messages" :key="index" :class="['message', msg.type]">
        <template v-if="msg.type === 'user'">
          <div class="message-content">
            <div class="message-bubble">{{ msg.content }}</div>
          </div>
          <div class="message-avatar">👤</div>
        </template>
        <template v-else>
          <div class="message-avatar">🤖</div>
          <div class="message-content">
            <div class="message-bubble">
              <p v-if="msg.explanation">{{ msg.explanation }}</p>
              <div v-if="msg.chart" class="chart-container">
                <iframe :src="msg.chart.html_path" frameborder="0" :style="{ width: '100%', height: msg.chart.height + 'px' }"></iframe>
              </div>
              <DataTable v-if="msg.result && msg.result.data" :columns="msg.result.columns" :data="msg.result.data" />
              <div v-if="msg.suggestions" class="suggestions">
                <span class="suggestions-label">💡 你可能还想问:</span>
                <button v-for="(suggestion, i) in msg.suggestions" :key="i" class="suggestion-chip" @click="sendMessage(suggestion)">{{ suggestion }}</button>
              </div>
            </div>
          </div>
        </template>
      </div>
    </div>
    
    <div class="chat-input-area">
      <div class="input-wrapper">
        <input v-model="inputMessage" type="text" placeholder="输入你的问题,例如:'各产品类别的销售额占比'" @keyup.enter="sendMessage()" />
        <button class="send-btn" @click="sendMessage()" :disabled="!inputMessage.trim()">发送</button>
      </div>
    </div>
  </div>
</template>

<script setup>
import { ref } from 'vue'
import axios from 'axios'
import DataTable from './DataTable.vue'

const props = defineProps({ datasetId: { type: String, required: true } })

const messages = ref([
  {
    type: 'ai',
    explanation: "Hi! I'm the AI data analysis assistant. Ask me questions in natural language and I'll analyze your data and generate charts.",
    suggestions: ['各产品类别的销售额占比', '销售额的趋势如何', '按月份统计销售额']
  }
])
const inputMessage = ref('')

const sendMessage = async (content = null) => {
  const message = content || inputMessage.value.trim()
  if (!message) return
  
  messages.value.push({ type: 'user', content: message })
  inputMessage.value = ''
  
  try {
    const response = await axios.post('/api/query', {
      dataset_id: props.datasetId,
      question: message,
      generate_chart: true
    })
    
    messages.value.push({
      type: 'ai',
      explanation: response.data.explanation,
      result: response.data.result,
      chart: response.data.chart,
      suggestions: response.data.suggested_followups
    })
  } catch (error) {
    messages.value.push({
      type: 'ai',
      explanation: `Sorry, the query failed: ${error.message}`
    })
  }
}
</script>

<style scoped>
.chat-interface {
  display: flex;
  flex-direction: column;
  height: 100%;
  background: #f5f5f5;
  border-radius: 12px;
  overflow: hidden;
}
.chat-messages {
  flex: 1;
  overflow-y: auto;
  padding: 20px;
  display: flex;
  flex-direction: column;
  gap: 16px;
}
.message {
  display: flex;
  gap: 12px;
  max-width: 85%;
}
.message.user {
  align-self: flex-end;
  flex-direction: row-reverse;
}
.message.ai {
  align-self: flex-start;
}
.message-avatar {
  width: 36px;
  height: 36px;
  border-radius: 50%;
  display: flex;
  align-items: center;
  justify-content: center;
  background: white;
  font-size: 18px;
  flex-shrink: 0;
}
.message-bubble {
  padding: 12px 16px;
  border-radius: 16px;
  background: white;
  box-shadow: 0 1px 2px rgba(0,0,0,0.1);
}
.message.user .message-bubble {
  background: #4a90d9;
  color: white;
}
.chat-input-area {
  padding: 16px 20px;
  background: white;
  border-top: 1px solid #e0e0e0;
}
.input-wrapper {
  display: flex;
  gap: 8px;
}
.input-wrapper input {
  flex: 1;
  padding: 12px 16px;
  border: 1px solid #ddd;
  border-radius: 24px;
  font-size: 14px;
  outline: none;
}
.send-btn {
  padding: 12px 24px;
  background: #4a90d9;
  color: white;
  border: none;
  border-radius: 24px;
  font-size: 14px;
  cursor: pointer;
}
.suggestions {
  margin-top: 12px;
  padding-top: 12px;
  border-top: 1px dashed #ddd;
}
.suggestion-chip {
  display: inline-block;
  margin: 4px;
  padding: 6px 12px;
  background: #e3f2fd;
  border: none;
  border-radius: 16px;
  font-size: 13px;
  color: #1976d2;
  cursor: pointer;
}
</style>

5.2.2 package.json

{
  "name": "ai-data-analyst-frontend",
  "version": "1.0.0",
  "type": "module",
  "scripts": {
    "dev": "vite",
    "build": "vite build",
    "preview": "vite preview"
  },
  "dependencies": {
    "vue": "^3.3.8",
    "vue-router": "^4.2.5",
    "axios": "^1.6.2",
    "plotly.js-dist": "^2.27.0"
  },
  "devDependencies": {
    "@vitejs/plugin-vue": "^4.5.0",
    "vite": "^5.0.0"
  }
}

6. LangChain Agent Implementation

6.1 Data Analysis Agent

# app/agents/data_agent.py
"""LangChain-based data analysis agent"""
from langchain.agents import Tool, AgentExecutor, create_react_agent
from langchain.prompts import PromptTemplate
import pandas as pd
import json
from typing import Dict, List

class DataAnalysisAgent:
    """Intelligent data analysis agent"""
    
    def __init__(self, llm, df: pd.DataFrame):
        self.llm = llm
        self.df = df
        self.schema = self._get_schema()
        self.tools = self._create_tools()
        self.agent = self._create_agent()
    
    def _get_schema(self) -> str:
        """Describe the dataset schema for the prompt"""
        info = [f"The dataset has {len(self.df)} rows and {len(self.df.columns)} columns.\nColumns:"]
        for col in self.df.columns:
            dtype = str(self.df[col].dtype)
            unique = self.df[col].nunique()
            if pd.api.types.is_numeric_dtype(self.df[col]):
                stats = f"mean={self.df[col].mean():.2f}, range=[{self.df[col].min():.2f}, {self.df[col].max():.2f}]"
            else:
                stats = f"unique values={unique}"
            info.append(f"- {col} ({dtype}): {stats}")
        return "\n".join(info)
    
    def _create_tools(self) -> List[Tool]:
        """Build the tools available to the agent"""
        
        def query_data(condition: str) -> str:
            """Filter the DataFrame with a pandas query expression"""
            try:
                result = self.df.query(condition) if condition else self.df.head(100)
                return f"Query result (first 20 rows):\n{result.head(20).to_string()}"
            except Exception as e:
                return f"Query failed: {str(e)}"
        
        def calculate_statistics(column: str) -> str:
            """Compute summary statistics for a column"""
            if column not in self.df.columns:
                return f"Error: column '{column}' does not exist"
            series = self.df[column]
            stats = {
                'count': len(series),
                'missing': series.isnull().sum(),
                'unique': series.nunique()
            }
            if pd.api.types.is_numeric_dtype(series):
                stats.update({
                    'mean': series.mean(),
                    'std': series.std(),
                    'min': series.min(),
                    'max': series.max(),
                    'median': series.median()
                })
            # default=str keeps numpy scalars from breaking JSON serialization
            return json.dumps(stats, ensure_ascii=False, indent=2, default=str)
        
        def find_correlations() -> str:
            """Find strongly correlated numeric column pairs"""
            numeric_df = self.df.select_dtypes(include=['float64', 'int64'])
            if numeric_df.empty:
                return "No numeric columns available for correlation analysis"
            corr = numeric_df.corr()
            strong_corr = []
            for i in range(len(corr.columns)):
                for j in range(i+1, len(corr.columns)):
                    val = corr.iloc[i, j]
                    if abs(val) > 0.5:
                        strong_corr.append({
                            'col1': corr.columns[i],
                            'col2': corr.columns[j],
                            'correlation': round(float(val), 3)
                        })
            strong_corr.sort(key=lambda x: abs(x['correlation']), reverse=True)
            return json.dumps(strong_corr[:10], ensure_ascii=False, indent=2)
        
        def detect_anomalies(column: str) -> str:
            """Detect outliers with the 1.5×IQR rule"""
            if column not in self.df.columns:
                return f"Error: column '{column}' does not exist"
            series = self.df[column]
            if not pd.api.types.is_numeric_dtype(series):
                return f"Error: '{column}' is not a numeric column"
            Q1 = series.quantile(0.25)
            Q3 = series.quantile(0.75)
            IQR = Q3 - Q1
            lower_bound = Q1 - 1.5 * IQR
            upper_bound = Q3 + 1.5 * IQR
            anomalies = series[(series < lower_bound) | (series > upper_bound)]
            return json.dumps({
                'total_anomalies': len(anomalies),
                'anomaly_rate': f"{len(anomalies)/len(series)*100:.2f}%"
            }, ensure_ascii=False, indent=2)
        
        return [
            Tool(name="query_data", func=query_data, description="Filter rows with a pandas query expression string"),
            Tool(name="calculate_statistics", func=calculate_statistics, description="Compute summary statistics for a given column"),
            Tool(name="find_correlations", func=find_correlations, description="Find correlations between numeric columns"),
            Tool(name="detect_anomalies", func=detect_anomalies, description="Detect outliers in a given column")
        ]
    
    def _create_agent(self):
        """Create the ReAct agent"""
        # Note: the Thought/Action/Observation keywords must stay in English,
        # because LangChain's ReAct output parser matches those exact tokens.
        template = """You are a professional data analyst. Answer the user's question with the tools below.

Available tools:
{tools}

Tool names:
{tool_names}

Dataset schema:
{schema}

Use the following format:

Question: the question to answer
Thought: how should I approach this
Action: the tool to use
Action Input: the input to the tool
Observation: the result of the tool
... (Thought/Action steps can repeat)
Thought: I now know the final answer
Final Answer: a complete answer to the user's question

Begin!

Question: {input}
{agent_scratchpad}"""
        
        prompt = PromptTemplate(
            template=template,
            input_variables=["input", "agent_scratchpad"],
            partial_variables={
                "schema": self.schema,
                "tools": "\n".join([f"{t.name}: {t.description}" for t in self.tools]),
                "tool_names": ", ".join([t.name for t in self.tools])
            }
        )
        
        agent = create_react_agent(self.llm, self.tools, prompt)
        return AgentExecutor(agent=agent, tools=self.tools, verbose=True)
    
    def run(self, question: str) -> str:
        """Run the agent on a question"""
        try:
            result = self.agent.invoke({"input": question})
            return result['output']
        except Exception as e:
            return f"Analysis failed: {str(e)}"


# Usage example
"""
from langchain.llms import OpenAI
import pandas as pd

df = pd.read_csv('sales_data.csv')
llm = OpenAI(temperature=0)

agent = DataAnalysisAgent(llm, df)
result = agent.run("分析销售额的趋势,找出异常月份并解释可能的原因")
print(result)
"""

6.2 Pandas DataFrame Agent (Simplified)

# Use the Pandas agent that ships with LangChain
# (in recent versions it lives in langchain_experimental, not langchain.agents)
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain.llms import OpenAI

def create_simple_agent(df, api_key=None):
    """Create a simple Pandas analysis agent"""
    llm = OpenAI(temperature=0, openai_api_key=api_key)
    
    agent = create_pandas_dataframe_agent(
        llm,
        df,
        verbose=True,
        allow_dangerous_code=True  # executes generated Python -- mind the security implications
    )
    
    return agent

# Example conversations
"""
agent = create_simple_agent(df)

# Example 1: basic statistics
agent.run("计算每个月的销售额总和")

# Example 2: complex analysis
agent.run("找出复购率最高的用户群体特征,并分析其消费行为模式")

# Example 3: anomaly detection
agent.run("检测销售额中的异常值,并分析异常原因")

# Example 4: forecasting
agent.run("基于历史数据,预测下个月的销售额并给出置信区间")
"""

7. End-to-End Demo

7.1 Scenario 1: Sales Data Analysis

Data: sample_sales.csv (order date, product category, sales amount, region, and other fields)

Step 1: Upload the data

# Example API call
curl -X POST "http://localhost:8000/api/datasets/upload" \
  -F "file=@sample_sales.csv"

# Response
{
  "success": true,
  "dataset_id": "ds_20240115_143022",
  "name": "sample_sales.csv",
  "row_count": 5000,
  "column_count": 8
}

Step 2: Ask in natural language

# Question: sales share by product category
curl -X POST "http://localhost:8000/api/query" \
  -H "Content-Type: application/json" \
  -d '{
    "dataset_id": "ds_20240115_143022",
    "question": "各产品类别的销售额占比",
    "generate_chart": true
  }'

# What the system does:
# 1. Interpret intent -> group by category, sum sales, compute shares
# 2. Generate SQL -> SELECT category, SUM(sales) as total FROM dataset GROUP BY category
# 3. Execute the query -> fetch results
# 4. Recommend a chart -> pie (share-of-total data suits a pie chart)
# 5. Render the chart -> interactive pie chart
# 6. Write the insight -> "Electronics account for 35% of total sales, the largest category..."

Step 3: Follow-up questions

# User: which regions sell the most electronics?
# System: automatically filters category='电子产品' and aggregates by region

# User: how are sales trending? Is there any seasonality?
# System: aggregates by month, renders a line chart, analyzes seasonal patterns
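
The same two-step flow from a Python client (a sketch; assumes the API is running on localhost:8000):

# Minimal Python client for the upload + query flow shown above.
import requests

BASE = "http://localhost:8000"

# Step 1: upload the dataset
with open("sample_sales.csv", "rb") as f:
    upload = requests.post(f"{BASE}/api/datasets/upload", files={"file": f}).json()
dataset_id = upload["dataset_id"]

# Step 2: ask a question against it
answer = requests.post(f"{BASE}/api/query", json={
    "dataset_id": dataset_id,
    "question": "各产品类别的销售额占比",
    "generate_chart": True,
}).json()

print(answer["generated_code"])      # the SQL the system produced
print(answer["chart"]["html_path"])  # where the rendered chart lives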

7.2 Scenario 2: User Behavior Analysis

# Advanced analysis: user segmentation and RFM analysis
"""
User question: what characterizes the high-value user segment?

Agent steps:
1. Compute each user's spend, purchase frequency, and recency
2. Score users on RFM
3. Cluster with K-Means (3-5 segments)
4. Profile each segment
5. Generate visualizations (scatter plot, radar chart)
6. Output segment personas and marketing recommendations
"""

# The agent's internal trace (illustrative)
analysis_steps = """
Thought: the user wants high-value user characteristics; I should run RFM analysis and clustering
Action: calculate_statistics
Action Input: user_id, order_amount
Observation: average spend per user is...

Thought: I need per-user RFM metrics
Action: query_data
Action Input: SELECT user_id, COUNT(*) as frequency, SUM(amount) as monetary, MAX(order_date) as recency FROM orders GROUP BY user_id
Observation: RFM base data retrieved...

Thought: now run the clustering
Action: python_analysis
Action Input: run KMeans clustering with n_clusters=4
Observation: 4 segments found; segment 0 averages...

Final Answer: the high-value user segment is characterized by...
"""

8. Security and Performance

8.1 SQL Injection Protection

# Never concatenate user input into SQL; only run whitelisted operations
def safe_query(df: pd.DataFrame, filters: Dict) -> pd.DataFrame:
    """Execute filters safely against a DataFrame"""
    result = df.copy()
    
    for col, condition in filters.items():
        if col not in df.columns:
            continue  # skip unknown columns
        
        operator = condition.get('op')
        value = condition.get('value')
        
        # Whitelist the operators
        allowed_ops = ['eq', 'gt', 'gte', 'lt', 'lte', 'ne', 'in']
        if operator not in allowed_ops:
            continue
        
        # Type-safe application
        if operator == 'in':
            if not isinstance(value, list):
                continue
            result = result[result[col].isin(value)]
        else:
            op_map = {
                'eq': '==', 'gt': '>', 'gte': '>=',
                'lt': '<', 'lte': '<=', 'ne': '!='
            }
            try:
                result = result.query(f"`{col}` {op_map[operator]} @value", local_dict={'value': value})
            except (ValueError, TypeError):
                pass
    
    return result

# Don't: execute user-supplied SQL directly (dangerous!)
# pd.read_sql(user_input_sql, conn)

# Do: only run predefined safe operations
result = safe_query(df, {'sales': {'op': 'gt', 'value': 1000}})

8.2 Query Timeout Control

import signal
from contextlib import contextmanager
from functools import wraps

class TimeoutException(Exception):
    pass

@contextmanager
def timeout(seconds):
    """Timeout context manager (SIGALRM: Unix-only, main thread only)"""
    def signal_handler(signum, frame):
        raise TimeoutException(f"Query exceeded {seconds} seconds")
    
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)

def with_timeout(seconds=30):
    """Timeout decorator"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                with timeout(seconds):
                    return func(*args, **kwargs)
            except TimeoutException:
                return {'error': f'Query timed out (>{seconds}s); please narrow the query'}
        return wrapper
    return decorator

@with_timeout(seconds=30)
def execute_analysis(df: pd.DataFrame, operation: str):
    """Run an analysis under timeout protection"""
    result = ...  # the expensive operation goes here
    return result
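
signal.SIGALRM only works on Unix and only in the main thread, which rules it out inside most async web workers. A portable alternative (a sketch) runs the work in a thread pool and bounds how long the caller waits:

# Portable timeout sketch using a worker thread instead of signals.
# Caveat: the worker thread keeps running after the timeout; this bounds
# the caller's wait, it does not kill the computation.
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeout

_pool = ThreadPoolExecutor(max_workers=4)

def run_with_timeout(func, *args, seconds: int = 30, **kwargs):
    future = _pool.submit(func, *args, **kwargs)
    try:
        return future.result(timeout=seconds)
    except FutureTimeout:
        return {'error': f'Query timed out (>{seconds}s); please narrow the query'}

# Example: run_with_timeout(execute_analysis, df, "describe", seconds=10)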

8.3 Handling Large Datasets

class BigDataHandler:
    """Strategies for large datasets"""
    
    def __init__(self, max_rows: int = 100000, sample_size: int = 10000):
        self.max_rows = max_rows
        self.sample_size = sample_size
    
    def load_safe(self, file_path: str) -> pd.DataFrame:
        """Load a large file defensively"""
        # Check the file size first
        file_size = Path(file_path).stat().st_size
        
        if file_size > 100 * 1024 * 1024:  # >100MB
            # Chunked read: keep only the first chunk as a working sample
            # (whole-file aggregation is sketched below)
            chunks = pd.read_csv(file_path, chunksize=self.max_rows)
            df = next(chunks)
            return df
        
        return pd.read_csv(file_path)
    
    def sample_for_preview(self, df: pd.DataFrame) -> pd.DataFrame:
        """Sample rows for previews"""
        if len(df) > self.sample_size:
            return df.sample(self.sample_size, random_state=42)
        return df
    
    def paginate_results(self, df: pd.DataFrame, page: int = 1, page_size: int = 100):
        """Return one page of results"""
        start = (page - 1) * page_size
        end = start + page_size
        return {
            'data': df.iloc[start:end].to_dict('records'),
            'page': page,
            'page_size': page_size,
            'total': len(df),
            'total_pages': (len(df) + page_size - 1) // page_size
        }
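
When the analysis genuinely needs the whole file rather than just the first chunk, aggregations can be computed chunk by chunk without ever holding the full dataset in memory (a sketch; the column names are illustrative):

# Chunked aggregation sketch: sum a value column per group across a file
# that is too large to load at once.
import pandas as pd

def chunked_group_sum(file_path: str, by: str, value: str, chunksize: int = 100_000) -> pd.Series:
    partials = []
    for chunk in pd.read_csv(file_path, chunksize=chunksize):
        partials.append(chunk.groupby(by)[value].sum())  # partial sums per chunk
    # Combine the partial sums into final per-group totals
    return pd.concat(partials).groupby(level=0).sum()

# Example: chunked_group_sum("sample_sales.csv", by="category", value="sales")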

8.4 Query Result Caching

import hashlib
import json
import redis

class QueryCache:
    """Query result cache (Redis with an in-memory fallback)"""
    
    def __init__(self, redis_client=None):
        self.redis = redis_client
        self.memory_cache = {}
        self.ttl = 3600  # 1 hour
    
    def _make_key(self, dataset_id: str, question: str) -> str:
        """Build the cache key; keep dataset_id as a plain prefix so
        per-dataset keys can be pattern-matched later"""
        question_hash = hashlib.md5(question.encode()).hexdigest()
        return f"{dataset_id}:{question_hash}"
    
    def get(self, dataset_id: str, question: str):
        """Read from the cache"""
        key = self._make_key(dataset_id, question)
        
        if self.redis:
            cached = self.redis.get(key)
            if cached:
                return json.loads(cached)
        
        return self.memory_cache.get(key)
    
    def set(self, dataset_id: str, question: str, result: dict):
        """Write to the cache"""
        key = self._make_key(dataset_id, question)
        
        if self.redis:
            self.redis.setex(key, self.ttl, json.dumps(result))
        else:
            self.memory_cache[key] = result
    
    def clear_dataset_cache(self, dataset_id: str):
        """Drop all cached entries for one dataset"""
        if self.redis:
            # Pattern matching works because keys are prefixed with the dataset_id
            pattern = f"{dataset_id}:*"
            for key in self.redis.scan_iter(match=pattern):
                self.redis.delete(key)
        else:
            keys_to_remove = [k for k in self.memory_cache if k.startswith(f"{dataset_id}:")]
            for k in keys_to_remove:
                del self.memory_cache[k]
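
Wired into the query flow, the cache wraps the expensive generate-and-execute path. A sketch against the endpoint from section 4 (query_generator and datasets refer to the objects defined in main.py):

# Sketch: consult the cache before generating, populate it afterwards.
cache = QueryCache()  # pass redis.Redis(host="localhost") for a shared cache

def answer_question(dataset_id: str, question: str) -> dict:
    cached = cache.get(dataset_id, question)
    if cached is not None:
        return cached  # cache hit: skip LLM calls and query execution

    generated = query_generator.generate(question, datasets[dataset_id]['schema_analysis'])
    response = {'generated_code': generated.code, 'explanation': generated.explanation}
    cache.set(dataset_id, question, response)
    return response

# Call cache.clear_dataset_cache(dataset_id) whenever a dataset is re-uploaded.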

9. Deployment

9.1 Docker Configuration

# Dockerfile
FROM python:3.11-slim

WORKDIR /app

# System dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    g++ \
    && rm -rf /var/lib/apt/lists/*

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code
COPY app/ ./app/
COPY static/ ./static/

# Create runtime directories
RUN mkdir -p uploads static/charts static/reports

EXPOSE 8000

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

# docker-compose.yml
version: '3.8'

services:
  api:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - ./uploads:/app/uploads
      - ./static:/app/static
    environment:
      - OPENAI_API_KEY=${OPENAI_API_KEY}
      - DATABASE_URL=${DATABASE_URL}
    restart: unless-stopped
  
  frontend:
    build: ./frontend
    ports:
      - "80:80"
    depends_on:
      - api
    restart: unless-stopped

9.2 Production Configuration

# app/config.py
from pydantic_settings import BaseSettings
from functools import lru_cache

class Settings(BaseSettings):
    """Application settings"""
    # Basics
    APP_NAME: str = "AI Data Analysis Assistant"
    DEBUG: bool = False
    
    # Security
    SECRET_KEY: str = "your-secret-key-change-in-production"
    ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
    
    # Database
    DATABASE_URL: str = "sqlite:///./data_analyst.db"
    
    # LLM
    OPENAI_API_KEY: str = ""
    OPENAI_MODEL: str = "gpt-3.5-turbo"
    
    # File uploads
    MAX_UPLOAD_SIZE: int = 100 * 1024 * 1024  # 100MB
    UPLOAD_DIR: str = "uploads"
    ALLOWED_EXTENSIONS: list = [".csv", ".xlsx", ".json", ".parquet"]
    
    # Performance
    MAX_QUERY_TIME: int = 30  # seconds
    MAX_RESULT_ROWS: int = 10000
    CACHE_TTL: int = 3600  # seconds
    
    class Config:
        env_file = ".env"

@lru_cache()
def get_settings():
    return Settings()

9.3 Nginx Configuration

# nginx.conf
server {
    listen 80;
    server_name data-analyst.example.com;
    
    # Frontend static files
    location / {
        root /var/www/frontend;
        try_files $uri $uri/ /index.html;
    }
    
    # API proxy
    location /api/ {
        proxy_pass http://api:8000/api/;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        
        # Timeouts
        proxy_connect_timeout 30s;
        proxy_send_timeout 30s;
        proxy_read_timeout 30s;
    }
    
    # Static file proxy
    location /static/ {
        proxy_pass http://api:8000/static/;
    }
    
    # Upload size limit
    client_max_body_size 100M;
}

10. Complete Code Download

📦 Project source

  • Full backend code (FastAPI + core services)
  • Full frontend code (Vue3 components)
  • Sample dataset (sample_sales.csv)
  • Jupyter notebook debugging tools
  • Unit test suite

🗂️ Sample data: sample_sales.csv

order_id,customer_id,order_date,category,product,quantity,unit_price,sales,region
ORD001,CUS001,2024-01-15,电子产品,手机,2,2999,5998,华东
ORD002,CUS002,2024-01-16,服装,T恤,3,99,297,华北
ORD003,CUS001,2024-01-17,电子产品,耳机,1,599,599,华东
ORD004,CUS003,2024-01-18,食品,零食,5,29,145,华南
ORD005,CUS002,2024-01-19,家居,台灯,2,199,398,华北
...

🧪 Unit test examples

# tests/test_data_loader.py
import pytest
import pandas as pd
from app.services.data_loader import DataLoader

def test_load_csv():
    loader = DataLoader()
    df, summary = loader.load("tests/sample.csv")
    
    assert isinstance(df, pd.DataFrame)
    assert summary['row_count'] > 0
    assert len(summary['columns']) > 0

def test_type_inference():
    loader = DataLoader()
    df = pd.DataFrame({
        'date': ['2024-01-01', '2024-01-02'],
        'amount': ['100.5', '200.5'],
        'category': ['A', 'B']
    })
    
    result = loader._infer_types(df)
    
    assert pd.api.types.is_datetime64_any_dtype(result['date'])
    assert pd.api.types.is_numeric_dtype(result['amount'])

def test_schema_analysis():
    loader = DataLoader()
    df = pd.DataFrame({
        'id': range(100),
        'sales': [100 + i * 10 for i in range(100)]
    })
    
    schema = loader.get_schema(df)
    
    assert schema['total_rows'] == 100
    assert len(schema['columns']) == 2
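
The rule-based query generator can be tested the same way (a sketch, not part of the original test suite):

# tests/test_query_generator.py — illustrative tests for the rule-based path.
from app.services.query_generator import QueryGenerator

SCHEMA = {'columns': [{'name': 'category'}, {'name': 'sales'}]}

def test_rule_based_aggregation():
    gen = QueryGenerator()  # no LLM configured
    result = gen.generate("sales的总和", SCHEMA)
    assert result.query_type == 'sql'
    assert "SUM(sales)" in result.code

def test_low_confidence_without_keywords():
    gen = QueryGenerator()
    result = gen.generate("随便看看", SCHEMA)
    assert result.confidence < 0.8  # no aggregation keyword matched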

11. Summary

This project implements a complete AI data analysis assistant. Its main features:

  • Zero barrier to entry: users analyze data in natural language, with no SQL or programming required
  • Automated: the full pipeline from data understanding to visualization to reporting runs automatically
  • Intelligent: the LLM interprets intent and recommends the best analysis approach
  • Extensible: the modular design makes new features and data sources easy to add
  • Safe and reliable: multiple safety mechanisms protect the data and the system

Working through this project, you will learn:

  1. How to process data and infer types with Pandas
  2. How to design a natural-language-to-code translation system
  3. How to recommend visualizations automatically from data characteristics
  4. How to build a data analysis agent with LangChain
  5. How to assemble a complete AI application and deploy it to production

🎯 Exercises

  1. Add more data sources (databases, APIs, cloud storage)
  2. Implement data cleaning and preprocessing features
  3. Add predictive analytics (time-series forecasting)
  4. Implement multi-turn conversation with context memory
  5. Improve the visualization recommendation algorithm