云计算百科
云计算领域专业知识百科平台

米哈游Java面试被问:机器学习模型的在线服务和A/B测试

一、机器学习在线服务核心架构

1.1 统一模型服务架构

python

复制

下载

"""
机器学习模型在线服务架构
"""
# Standard library
import asyncio
import json
import logging
import os
import pickle
import threading
import time
import uuid
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass, asdict
from datetime import datetime
from enum import Enum
from typing import Dict, Any, List, Optional, Union

# Third-party
import numpy as np
import pandas as pd
from prometheus_client import Counter, Histogram, Gauge
from pydantic import BaseModel, Field, validator

# ==================== Monitoring metrics ====================
class ModelMetrics:
    """Prometheus metrics for model serving.

    All metrics are class attributes so each Counter/Histogram/Gauge is
    registered with the default registry exactly once, no matter how many
    model instances create a ModelMetrics().
    """

    # Request-level metrics (count, latency, errors by type)
    request_total = Counter('model_request_total', 'Total requests', ['model_name', 'version'])
    request_duration = Histogram('model_request_duration_seconds', 'Request duration', ['model_name', 'version'])
    request_errors = Counter('model_request_errors', 'Request errors', ['model_name', 'version', 'error_type'])

    # Prediction-level metrics
    prediction_latency = Histogram('prediction_latency_seconds', 'Prediction latency', ['model_name', 'version'])
    prediction_accuracy = Gauge('prediction_accuracy', 'Prediction accuracy', ['model_name', 'version'])

    # Resource metrics
    model_memory_usage = Gauge('model_memory_usage_bytes', 'Model memory usage', ['model_name', 'version'])
    model_load_time = Gauge('model_load_time_seconds', 'Model load time', ['model_name', 'version'])

@dataclass
class ModelMetadata:
    """Descriptive metadata shipped alongside a trained model artifact."""
    name: str
    version: str
    created_at: datetime
    framework: str  # e.g. tensorflow, pytorch, sklearn, xgboost
    input_schema: Dict[str, Any]   # expected input fields / types
    output_schema: Dict[str, Any]  # produced output fields / types
    performance_metrics: Dict[str, float]  # offline evaluation metrics
    feature_importance: Optional[List[float]] = None
    dependencies: Optional[List[str]] = None
    description: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialise to a plain dict (note: datetime fields stay datetime objects)."""
        return asdict(self)

class ModelStatus(Enum):
    """Serving lifecycle state of a model instance."""
    LOADING = "loading"      # load in progress
    READY = "ready"          # serving traffic
    UNHEALTHY = "unhealthy"  # load/predict failures observed
    OFFLINE = "offline"      # deliberately taken out of service

class ModelType(Enum):
    """Broad task category of a model."""
    CLASSIFICATION = "classification"
    REGRESSION = "regression"
    CLUSTERING = "clustering"
    RECOMMENDATION = "recommendation"
    NLP = "nlp"
    CV = "computer_vision"

# ==================== Base model interface ====================
class BaseModelInterface(ABC):
    """Abstract base class for servable ML models.

    Subclasses implement framework-specific loading and prediction; this
    class supplies shared state (status, metadata, load time), a health
    snapshot and Prometheus request tracking.
    """

    def __init__(self, model_name: str, model_version: str):
        self.model_name = model_name
        self.model_version = model_version
        self.metadata: Optional[ModelMetadata] = None
        self.status: ModelStatus = ModelStatus.LOADING
        self.loaded_at: Optional[datetime] = None
        self.metrics = ModelMetrics()
        self.logger = logging.getLogger(f"{model_name}.{model_version}")

    @abstractmethod
    async def load(self, model_path: str) -> bool:
        """Load the model from *model_path*; return True on success."""

    @abstractmethod
    async def predict(self, inputs: Union[Dict, List, np.ndarray]) -> Dict[str, Any]:
        """Run a single prediction and return a result payload."""

    @abstractmethod
    async def batch_predict(self, inputs: List) -> List[Dict[str, Any]]:
        """Run predictions for a list of inputs."""

    @abstractmethod
    def get_feature_names(self) -> List[str]:
        """Return the ordered feature names the model expects."""

    async def health_check(self) -> Dict[str, Any]:
        """Return a status snapshot used by the health endpoints."""
        return {
            "status": self.status.value,
            "model_name": self.model_name,
            "model_version": self.model_version,
            "loaded_at": self.loaded_at.isoformat() if self.loaded_at else None,
            "memory_usage": self._get_memory_usage(),
            "uptime": self._get_uptime()
        }

    def _get_memory_usage(self) -> Optional[int]:
        """Memory usage in bytes; framework-specific, None when unknown."""
        return None

    def _get_uptime(self) -> Optional[float]:
        """Seconds since the model finished loading, or None if not loaded."""
        if self.loaded_at:
            # BUG FIX: the original used an en-dash (U+2013) instead of the
            # minus operator, which is a SyntaxError.
            return (datetime.now() - self.loaded_at).total_seconds()
        return None

    @contextmanager
    def track_request(self):
        """Context manager recording request count, errors and duration."""
        start_time = time.time()
        try:
            self.metrics.request_total.labels(
                model_name=self.model_name,
                version=self.model_version
            ).inc()
            yield
        except Exception as e:
            # Count the error by exception class, then propagate.
            self.metrics.request_errors.labels(
                model_name=self.model_name,
                version=self.model_version,
                error_type=type(e).__name__
            ).inc()
            raise
        finally:
            # BUG FIX: en-dash replaced with the minus operator.
            duration = time.time() - start_time
            self.metrics.request_duration.labels(
                model_name=self.model_name,
                version=self.model_version
            ).observe(duration)

# ==================== Concrete model implementations ====================
class ScikitLearnModel(BaseModelInterface):
    """Scikit-learn (pickle-bundle) model implementation.

    Expects a pickle file containing a dict with keys 'model',
    'feature_names' (optional) and 'metadata' (optional).
    """

    def __init__(self, model_name: str, model_version: str):
        super().__init__(model_name, model_version)
        self.model = None
        self.feature_names: List[str] = []

    async def load(self, model_path: str) -> bool:
        """Load the pickled model bundle; return True on success."""
        with self.track_request():
            load_start = time.time()
            try:
                # SECURITY NOTE: pickle.load executes arbitrary code; only
                # load model files from trusted storage.
                with open(model_path, 'rb') as f:
                    data = pickle.load(f)

                self.model = data['model']
                self.feature_names = data.get('feature_names', [])
                self.metadata = data.get('metadata')

                self.status = ModelStatus.READY
                self.loaded_at = datetime.now()

                # BUG FIX: record the actual load duration. The original
                # computed now - loaded_at (always ~0) with an en-dash to boot.
                self.metrics.model_load_time.labels(
                    model_name=self.model_name,
                    version=self.model_version
                ).set(time.time() - load_start)

                self.logger.info(f"模型加载成功: {self.model_name}:{self.model_version}")
                return True

            except Exception as e:
                self.status = ModelStatus.UNHEALTHY
                self.logger.error(f"模型加载失败: {e}")
                return False

    async def predict(self, inputs: Union[Dict, List, np.ndarray]) -> Dict[str, Any]:
        """Single prediction with optional class probabilities."""
        with self.track_request():
            try:
                start_time = time.time()

                processed_inputs = self._preprocess(inputs)

                prediction = self.model.predict(processed_inputs)
                probabilities = None

                # Classifiers expose predict_proba; regressors do not.
                if hasattr(self.model, 'predict_proba'):
                    probabilities = self.model.predict_proba(processed_inputs)

                # BUG FIX: en-dash replaced with the minus operator.
                latency = time.time() - start_time
                self.metrics.prediction_latency.labels(
                    model_name=self.model_name,
                    version=self.model_version
                ).observe(latency)

                return {
                    "prediction": prediction.tolist() if hasattr(prediction, 'tolist') else prediction,
                    "probabilities": probabilities.tolist() if probabilities is not None else None,
                    "model_name": self.model_name,
                    "model_version": self.model_version,
                    "latency_ms": latency * 1000,
                    "timestamp": datetime.now().isoformat()
                }

            except Exception as e:
                self.logger.error(f"预测失败: {e}")
                raise

    async def batch_predict(self, inputs: List) -> List[Dict[str, Any]]:
        """Predict each input independently; failures become error entries."""
        results = []
        for input_data in inputs:
            try:
                result = await self.predict(input_data)
                results.append(result)
            except Exception as e:
                # A bad row must not abort the whole batch.
                results.append({
                    "error": str(e),
                    "model_name": self.model_name,
                    "model_version": self.model_version,
                    "timestamp": datetime.now().isoformat()
                })
        return results

    def get_feature_names(self) -> List[str]:
        """Feature names loaded from the model bundle (may be empty)."""
        return self.feature_names

    def _preprocess(self, inputs: Union[Dict, List, np.ndarray]) -> np.ndarray:
        """Normalise a single sample into a 2-D numpy array.

        Dicts are ordered by feature_names (missing features default to 0);
        lists/1-D arrays become a single-row matrix.
        """
        if isinstance(inputs, dict):
            if self.feature_names:
                return np.array([inputs.get(feat, 0) for feat in self.feature_names]).reshape(1, -1)
            else:
                # No known ordering: rely on dict insertion order.
                return np.array(list(inputs.values())).reshape(1, -1)
        elif isinstance(inputs, list):
            return np.array(inputs).reshape(1, -1)
        elif isinstance(inputs, np.ndarray):
            return inputs.reshape(1, -1) if inputs.ndim == 1 else inputs
        else:
            raise ValueError(f"不支持的输入类型: {type(inputs)}")

class TensorFlowModel(BaseModelInterface):
    """TensorFlow/Keras model implementation.

    NOTE(review): batch_predict, get_feature_names and _preprocess are not
    implemented in this excerpt; see ScikitLearnModel for the analogous
    pattern. As written, instantiation would fail the ABC check until the
    remaining abstract methods are implemented.
    """

    def __init__(self, model_name: str, model_version: str):
        super().__init__(model_name, model_version)
        try:
            # Imported lazily so the service runs without TF installed.
            import tensorflow as tf
            self.tf = tf
            self.model = None
        except ImportError:
            raise ImportError("TensorFlow未安装")

    async def load(self, model_path: str) -> bool:
        """Load a Keras model and its optional *_metadata.json sidecar."""
        with self.track_request():
            try:
                self.model = self.tf.keras.models.load_model(model_path)

                # Metadata sidecar sits next to the model file.
                # (Requires `import os` at module level — the original file
                # used os without importing it.)
                metadata_path = model_path.replace('.h5', '_metadata.json').replace('.keras', '_metadata.json')
                if os.path.exists(metadata_path):
                    with open(metadata_path, 'r') as f:
                        metadata_data = json.load(f)
                        self.metadata = ModelMetadata(**metadata_data)

                self.status = ModelStatus.READY
                self.loaded_at = datetime.now()

                self.logger.info(f"TensorFlow模型加载成功: {self.model_name}:{self.model_version}")
                return True

            except Exception as e:
                self.status = ModelStatus.UNHEALTHY
                self.logger.error(f"TensorFlow模型加载失败: {e}")
                return False

    async def predict(self, inputs: Union[Dict, List, np.ndarray]) -> Dict[str, Any]:
        """Run a single prediction through the Keras model."""
        with self.track_request():
            try:
                start_time = time.time()

                processed_inputs = self._preprocess(inputs)

                # verbose=0 suppresses per-call progress output.
                prediction = self.model.predict(processed_inputs, verbose=0)

                # BUG FIX: en-dash replaced with the minus operator.
                latency = time.time() - start_time
                self.metrics.prediction_latency.labels(
                    model_name=self.model_name,
                    version=self.model_version
                ).observe(latency)

                return {
                    "prediction": prediction.tolist(),
                    "model_name": self.model_name,
                    "model_version": self.model_version,
                    "latency_ms": latency * 1000,
                    "timestamp": datetime.now().isoformat()
                }

            except Exception as e:
                self.logger.error(f"TensorFlow预测失败: {e}")
                raise

    # Other methods (batch_predict, get_feature_names, _preprocess) follow
    # the same pattern as ScikitLearnModel.

class PyTorchModel(BaseModelInterface):
    """PyTorch model implementation (full-module checkpoint).

    NOTE(review): batch_predict, get_feature_names and _preprocess are not
    implemented in this excerpt; see ScikitLearnModel for the analogous
    pattern.
    """

    def __init__(self, model_name: str, model_version: str):
        super().__init__(model_name, model_version)
        try:
            # Imported lazily so the service runs without PyTorch installed.
            import torch
            self.torch = torch
            self.model = None
            self.device = self.torch.device('cuda' if self.torch.cuda.is_available() else 'cpu')
        except ImportError:
            raise ImportError("PyTorch未安装")

    async def load(self, model_path: str) -> bool:
        """Load a serialized module and its optional *_metadata.json sidecar."""
        with self.track_request():
            try:
                # SECURITY NOTE: torch.load unpickles arbitrary objects — only
                # load checkpoints from trusted storage. Recent PyTorch
                # releases default to weights_only=True, which rejects
                # full-module pickles like this one; confirm the intended
                # checkpoint format before upgrading.
                self.model = self.torch.load(model_path, map_location=self.device)
                self.model.to(self.device)
                self.model.eval()

                # Metadata sidecar sits next to the checkpoint.
                # (Requires `import os` at module level — the original file
                # used os without importing it.)
                metadata_path = model_path.replace('.pt', '_metadata.json').replace('.pth', '_metadata.json')
                if os.path.exists(metadata_path):
                    with open(metadata_path, 'r') as f:
                        metadata_data = json.load(f)
                        self.metadata = ModelMetadata(**metadata_data)

                self.status = ModelStatus.READY
                self.loaded_at = datetime.now()

                self.logger.info(f"PyTorch模型加载成功: {self.model_name}:{self.model_version}")
                return True

            except Exception as e:
                self.status = ModelStatus.UNHEALTHY
                self.logger.error(f"PyTorch模型加载失败: {e}")
                return False

    async def predict(self, inputs: Union[Dict, List, np.ndarray]) -> Dict[str, Any]:
        """Run a single forward pass without gradient tracking."""
        with self.track_request():
            try:
                start_time = time.time()

                processed_inputs = self._preprocess(inputs)

                inputs_tensor = self.torch.from_numpy(processed_inputs).float().to(self.device)

                # Inference mode: no autograd bookkeeping.
                with self.torch.no_grad():
                    prediction = self.model(inputs_tensor)

                prediction_np = prediction.cpu().numpy()

                # BUG FIX: en-dash replaced with the minus operator.
                latency = time.time() - start_time
                self.metrics.prediction_latency.labels(
                    model_name=self.model_name,
                    version=self.model_version
                ).observe(latency)

                return {
                    "prediction": prediction_np.tolist(),
                    "model_name": self.model_name,
                    "model_version": self.model_version,
                    "latency_ms": latency * 1000,
                    "timestamp": datetime.now().isoformat()
                }

            except Exception as e:
                self.logger.error(f"PyTorch预测失败: {e}")
                raise

    # Other methods (batch_predict, get_feature_names, _preprocess) follow
    # the same pattern as ScikitLearnModel.

# ==================== Model factory ====================
class ModelFactory:
    """Creates model instances keyed by framework name."""

    # Framework name -> implementation class. Tree-based libraries expose the
    # sklearn estimator API, so they share ScikitLearnModel.
    _model_registry = {
        'sklearn': ScikitLearnModel,
        'tensorflow': TensorFlowModel,
        'pytorch': PyTorchModel,
        'xgboost': ScikitLearnModel,
        'lightgbm': ScikitLearnModel,
    }

    @classmethod
    def create_model(cls, framework: str, model_name: str, model_version: str) -> BaseModelInterface:
        """Instantiate the model class registered for *framework*."""
        model_class = cls._model_registry.get(framework)
        if model_class is None:
            raise ValueError(f"不支持的框架: {framework}")
        return model_class(model_name, model_version)

    @classmethod
    def register_framework(cls, framework: str, model_class):
        """Register an additional framework implementation at runtime."""
        if not issubclass(model_class, BaseModelInterface):
            raise TypeError("模型类必须继承自BaseModelInterface")
        cls._model_registry[framework] = model_class

# ==================== Model manager ====================
class ModelManager:
    """Registry and lifecycle manager for served models.

    Keeps a registry of known model configurations keyed by
    "name:version", loads/unloads live model instances on demand, and
    routes predict calls to the right instance.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        # model_key ("name:version") -> live model instance
        self.models: Dict[str, BaseModelInterface] = {}
        # model_key -> static configuration (model_path, framework, ...)
        self.model_registry: Dict[str, Dict[str, Any]] = {}
        self.logger = logging.getLogger("ModelManager")
        self.executor = ThreadPoolExecutor(max_workers=config.get('max_workers', 10))

        self._init_model_registry()

    def _init_model_registry(self):
        """Populate the registry from config (could also come from a DB)."""
        default_models = self.config.get('models', [])
        for model_config in default_models:
            model_key = f"{model_config['name']}:{model_config['version']}"
            self.model_registry[model_key] = model_config

    async def load_model(self, model_name: str, model_version: str) -> bool:
        """Load a registered model; True when loaded (or already loaded)."""
        model_key = f"{model_name}:{model_version}"

        if model_key in self.models:
            self.logger.info(f"模型已加载: {model_key}")
            return True

        if model_key not in self.model_registry:
            self.logger.error(f"模型未注册: {model_key}")
            return False

        model_config = self.model_registry[model_key]
        framework = model_config.get('framework', 'sklearn')
        model_path = model_config.get('model_path')

        if not model_path:
            self.logger.error(f"模型路径未配置: {model_key}")
            return False

        try:
            model = ModelFactory.create_model(framework, model_name, model_version)
            success = await model.load(model_path)

            if success:
                self.models[model_key] = model
                self.logger.info(f"模型加载成功: {model_key}")
                return True
            self.logger.error(f"模型加载失败: {model_key}")
            return False

        except Exception as e:
            self.logger.error(f"模型加载异常: {model_key}, 错误: {e}")
            return False

    async def unload_model(self, model_name: str, model_version: str) -> bool:
        """Unload a model and release its memory; False when not loaded."""
        model_key = f"{model_name}:{model_version}"

        if model_key not in self.models:
            self.logger.warning(f"模型未加载: {model_key}")
            return False

        del self.models[model_key]

        # Encourage prompt release of large model objects.
        import gc
        gc.collect()

        self.logger.info(f"模型卸载成功: {model_key}")
        return True

    async def _ensure_loaded(self, model_key: str, model_name: str, model_version: str) -> 'BaseModelInterface':
        """Lazily load the model on first use; raise when unavailable.

        Extracted helper: predict and batch_predict previously duplicated
        this load-on-demand logic.
        """
        if model_key not in self.models:
            loaded = await self.load_model(model_name, model_version)
            if not loaded:
                raise ValueError(f"模型未加载且加载失败: {model_key}")
        return self.models[model_key]

    async def predict(
        self,
        model_name: str,
        model_version: str,
        inputs: Union[Dict, List, np.ndarray]
    ) -> Dict[str, Any]:
        """Predict with the given model, lazily loading it when needed."""
        model_key = f"{model_name}:{model_version}"
        model = await self._ensure_loaded(model_key, model_name, model_version)
        return await model.predict(inputs)

    async def batch_predict(
        self,
        model_name: str,
        model_version: str,
        inputs: List
    ) -> List[Dict[str, Any]]:
        """Batch predict with the given model, lazily loading it when needed."""
        model_key = f"{model_name}:{model_version}"
        model = await self._ensure_loaded(model_key, model_name, model_version)
        return await model.batch_predict(inputs)

    def list_models(self) -> List[Dict[str, Any]]:
        """Snapshot of loaded models: key, status, load time, metadata."""
        result = []
        for model_key, model in self.models.items():
            result.append({
                'model_key': model_key,
                'status': model.status.value,
                'loaded_at': model.loaded_at.isoformat() if model.loaded_at else None,
                'metadata': model.metadata.to_dict() if model.metadata else None
            })
        return result

    async def health_check_all(self) -> Dict[str, Any]:
        """Aggregate health report across all loaded models."""
        health_status = {}

        for model_key, model in self.models.items():
            try:
                health_status[model_key] = await model.health_check()
            except Exception as e:
                # A failing health check must not take down the report.
                health_status[model_key] = {
                    'status': 'error',
                    'error': str(e)
                }

        return {
            'timestamp': datetime.now().isoformat(),
            'total_models': len(self.models),
            'healthy_models': sum(1 for h in health_status.values()
                                  if h.get('status') == 'ready'),
            'details': health_status
        }

# ==================== Feature engineering service ====================
class FeatureEngineeringService:
    """Applies named feature pipelines to raw input data."""

    def __init__(self):
        # pipeline name -> pipeline object exposing .transform(raw_data)
        self.feature_pipelines: Dict[str, Any] = {}
        # Simple in-memory usage log; production should use Redis or similar.
        self.feature_store = {}
        self.logger = logging.getLogger("FeatureEngineeringService")

    async def transform(self, pipeline_name: str, raw_data: Dict[str, Any]) -> Dict[str, Any]:
        """Transform *raw_data* through the named pipeline and log usage."""
        if pipeline_name not in self.feature_pipelines:
            raise ValueError(f"特征管道未找到: {pipeline_name}")

        pipeline = self.feature_pipelines[pipeline_name]

        try:
            transformed_features = pipeline.transform(raw_data)
        except Exception as e:
            self.logger.error(f"特征转换失败: {pipeline_name}, 错误: {e}")
            raise

        await self._log_feature_usage(pipeline_name, transformed_features)
        return transformed_features

    async def batch_transform(self, pipeline_name: str, raw_data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Transform each record; a failing record yields an error entry."""
        results = []
        for record in raw_data_list:
            try:
                results.append(await self.transform(pipeline_name, record))
            except Exception as e:
                results.append({"error": str(e)})
        return results

    def register_pipeline(self, pipeline_name: str, pipeline: Any):
        """Register (or replace) a feature pipeline under *pipeline_name*."""
        self.feature_pipelines[pipeline_name] = pipeline
        self.logger.info(f"特征管道注册成功: {pipeline_name}")

    async def _log_feature_usage(self, pipeline_name: str, features: Dict[str, Any]):
        """Append a usage record keyed by day and pipeline name."""
        feature_stats = {
            'pipeline_name': pipeline_name,
            'feature_count': len(features),
            'timestamp': datetime.now().isoformat(),
            'features': list(features.keys())
        }

        # In-memory only; swap for persistent storage in production.
        key = f"feature_usage:{datetime.now().strftime('%Y%m%d')}:{pipeline_name}"
        self.feature_store.setdefault(key, []).append(feature_stats)

篇幅限制下面就只能给大家展示小册部分内容了。整理了一份核心面试笔记包括了:Java面试、Spring、JVM、MyBatis、Redis、MySQL、并发编程、微服务、Linux、Springboot、SpringCloud、MQ、Kafka

需要全套面试笔记及答案 【点击此处即可/免费获取】​​​

1.2 高性能模型服务API

python

复制

下载

"""
FastAPI模型服务API
"""
from fastapi import FastAPI, HTTPException, Depends, BackgroundTasks, Query, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import uvicorn
from typing import List, Optional
import json
import asyncio
from datetime import datetime

# Pydantic request/response schemas
class PredictionRequest(BaseModel):
    """Body of POST /predict."""
    # BUG FIX: the original used the U+2026 ellipsis character in Field(…),
    # which is a SyntaxError; Python's required-field marker is `...`.
    model_name: str = Field(..., description="模型名称")
    model_version: str = Field(..., description="模型版本")
    data: Union[Dict[str, Any], List[Any]] = Field(..., description="输入数据")
    request_id: Optional[str] = Field(None, description="请求ID")
    features: Optional[Dict[str, Any]] = Field(None, description="特征数据")

    @validator('data')
    def validate_data(cls, v):
        # Reject empty dict/list payloads up front.
        if not v:
            raise ValueError('数据不能为空')
        return v

class BatchPredictionRequest(BaseModel):
    """Body of POST /batch_predict: one model, many input rows."""
    model_name: str
    model_version: str
    data: List[Union[Dict[str, Any], List[Any]]]
    request_id: Optional[str] = None

class ModelLoadRequest(BaseModel):
    """Body of POST /models/load."""
    model_name: str
    model_version: str
    model_path: str
    framework: str = "sklearn"  # one of the ModelFactory-registered frameworks

class ABTestRequest(BaseModel):
    """Body of POST /ab_test/predict."""
    user_id: str
    experiment_name: str
    features: Dict[str, Any]
    context: Optional[Dict[str, Any]] = None  # extra assignment context

class PredictionResponse(BaseModel):
    """Uniform response body for /predict and /ab_test/predict."""
    request_id: Optional[str]
    prediction: Any
    probabilities: Optional[List[float]] = None
    model_name: str
    model_version: str
    latency_ms: float
    timestamp: str
    experiment_group: Optional[str] = None  # A/B test group, when applicable

class HealthResponse(BaseModel):
    """Response body for GET /health."""
    status: str
    timestamp: str
    uptime_seconds: float
    model_count: int
    healthy_model_count: int

# FastAPI application instance
app = FastAPI(
    title="ML Model Serving API",
    description="机器学习模型在线服务API",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# CORS middleware.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# wide open; restrict the origin list before exposing this in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level service singletons, initialised in startup_event.
model_manager: Optional[ModelManager] = None
feature_service: Optional[FeatureEngineeringService] = None
ab_test_manager: Optional['ABTestManager'] = None  # class defined later in the file

# FastAPI dependency providers
async def get_model_manager():
    """Dependency: the global ModelManager (None before startup)."""
    return model_manager

async def get_feature_service():
    """Dependency: the global FeatureEngineeringService."""
    return feature_service

async def get_ab_test_manager():
    """Dependency: the global ABTestManager."""
    return ab_test_manager

# Startup event
@app.on_event("startup")
async def startup_event():
    """Initialise services from the config file and pre-load models."""
    global model_manager, feature_service, ab_test_manager

    # Service configuration (models, worker pool size, A/B settings).
    with open('config/model_service_config.json', 'r') as f:
        config = json.load(f)

    model_manager = ModelManager(config)
    feature_service = FeatureEngineeringService()

    # A/B test manager gets its own config sub-section.
    ab_test_manager = ABTestManager(config.get('ab_test', {}))

    # Pre-load every model declared in the configuration.
    for model_config in config.get('models', []):
        await model_manager.load_model(
            model_config['name'],
            model_config['version']
        )

    # BUG FIX: the module never defined a bare `logger`; use the logging
    # module (imported at the top of the file) explicitly.
    logging.getLogger(__name__).info("ML Model Serving API启动完成")

@app.on_event("shutdown")
async def shutdown_event():
    """Release resources on application shutdown."""
    # BUG FIX: the module never defined a bare `logger`.
    logging.getLogger(__name__).info("正在关闭ML Model Serving API…")

    # Drain the model manager's worker pool before exit.
    if model_manager:
        model_manager.executor.shutdown(wait=True)

# API routes
@app.get("/")
async def root():
    """Service banner: name, version, status and current time."""
    banner = {
        "service": "ML Model Serving API",
        "version": "1.0.0",
        "status": "running",
        "timestamp": datetime.now().isoformat(),
    }
    return banner

@app.get("/health")
async def health_check(
    mm: ModelManager = Depends(get_model_manager)
) -> HealthResponse:
    """Service health: aggregates the health of every loaded model."""
    if not mm:
        raise HTTPException(status_code=503, detail="ModelManager未初始化")

    health_status = await mm.health_check_all()

    return HealthResponse(
        status="healthy" if health_status['healthy_models'] > 0 else "unhealthy",
        timestamp=datetime.now().isoformat(),
        # NOTE(review): health_check_all never returns an 'uptime' key, so
        # this is always 0 — confirm intended or compute real process uptime.
        uptime_seconds=health_status.get('uptime', 0),
        model_count=health_status['total_models'],
        healthy_model_count=health_status['healthy_models']
    )

@app.get("/models")
async def list_models(
    mm: ModelManager = Depends(get_model_manager)
):
    """List every loaded model with status, load time and metadata."""
    if not mm:
        raise HTTPException(status_code=503, detail="ModelManager未初始化")

    return {
        "timestamp": datetime.now().isoformat(),
        "models": mm.list_models()
    }

@app.post("/models/load")
async def load_model(
    request: ModelLoadRequest,
    mm: ModelManager = Depends(get_model_manager),
    background_tasks: BackgroundTasks = None
):
    """Load a model, either inline or as a background task.

    Registers the request's path/framework so the manager knows where the
    artifact lives, then triggers the load.
    """
    if not mm:
        raise HTTPException(status_code=503, detail="ModelManager未初始化")

    # BUG FIX: register the requested model configuration. The original
    # ignored request.model_path / request.framework entirely, so loading a
    # model not already present in the config always failed.
    model_key = f"{request.model_name}:{request.model_version}"
    mm.model_registry.setdefault(model_key, {
        'name': request.model_name,
        'version': request.model_version,
        'model_path': request.model_path,
        'framework': request.framework,
    })

    if background_tasks:
        background_tasks.add_task(
            mm.load_model,
            request.model_name,
            request.model_version
        )
        # BUG FIX: ModelLoadRequest has no request_id field; the original
        # `request.request_id` raised AttributeError. getattr keeps the
        # response shape while tolerating the missing field.
        return {"message": "模型加载任务已提交", "request_id": getattr(request, "request_id", None)}
    else:
        success = await mm.load_model(request.model_name, request.model_version)
        if success:
            return {"message": "模型加载成功", "request_id": getattr(request, "request_id", None)}
        else:
            raise HTTPException(status_code=500, detail="模型加载失败")

@app.post("/predict")
async def predict(
    request: PredictionRequest,
    mm: ModelManager = Depends(get_model_manager),
    fs: FeatureEngineeringService = Depends(get_feature_service),
    atm: 'ABTestManager' = Depends(get_ab_test_manager),
    x_request_id: Optional[str] = Header(None, alias="X-Request-ID")
):
    """Single prediction endpoint.

    Resolves a request id (body > header > generated), optionally applies
    feature engineering, predicts, logs the call, and returns the typed
    response with any A/B group annotation.
    """
    start_time = time.time()
    # Prefer an explicit id from the body, then the header, then generate one.
    request_id = request.request_id or x_request_id or str(uuid.uuid4())

    try:
        # 1. Feature engineering (placeholder: raw data is passed through
        # unless the caller already supplied engineered features).
        features = request.features
        if not features and fs and request.data:
            # TODO: wire a pipeline name into the request and call
            # fs.transform(...) here.
            pass

        # 2. Run the prediction.
        prediction_result = await mm.predict(
            request.model_name,
            request.model_version,
            features or request.data
        )

        # 3. Attach the A/B experiment group, if any.
        experiment_group = None
        if atm and request_id:
            experiment_group = atm.get_experiment_group(request_id)

        # 4. Record the prediction.
        # BUG FIX: the original computed latency with an en-dash instead of
        # the minus operator (SyntaxError).
        await log_prediction(
            request_id=request_id,
            model_name=request.model_name,
            model_version=request.model_version,
            features=features,
            prediction=prediction_result,
            latency=time.time() - start_time
        )

        # 5. Build the typed response.
        return PredictionResponse(
            request_id=request_id,
            prediction=prediction_result.get("prediction"),
            probabilities=prediction_result.get("probabilities"),
            model_name=prediction_result.get("model_name"),
            model_version=prediction_result.get("model_version"),
            latency_ms=prediction_result.get("latency_ms", 0),
            timestamp=prediction_result.get("timestamp"),
            experiment_group=experiment_group
        )

    except Exception as e:
        # BUG FIX: `logger` was never defined at module level.
        logging.getLogger(__name__).error(f"预测失败: {request_id}, 错误: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/batch_predict")
async def batch_predict(
    request: BatchPredictionRequest,
    mm: ModelManager = Depends(get_model_manager)
):
    """Batch prediction endpoint with per-row success/failure accounting."""
    if not mm:
        raise HTTPException(status_code=503, detail="ModelManager未初始化")

    try:
        results = await mm.batch_predict(
            request.model_name,
            request.model_version,
            request.data
        )

        return {
            "request_id": request.request_id,
            "timestamp": datetime.now().isoformat(),
            "total": len(results),
            # Per-row errors are reported, not raised (see batch_predict impl).
            "success": sum(1 for r in results if "error" not in r),
            "failed": sum(1 for r in results if "error" in r),
            "results": results
        }

    except Exception as e:
        # BUG FIX: `logger` was never defined at module level.
        logging.getLogger(__name__).error(f"批量预测失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/ab_test/predict")
async def ab_test_predict(
    request: ABTestRequest,
    mm: ModelManager = Depends(get_model_manager),
    atm: 'ABTestManager' = Depends(get_ab_test_manager)
):
    """A/B test prediction: assign the user to a group, then serve with
    that group's model and log the exposure event."""
    if not mm or not atm:
        raise HTTPException(status_code=503, detail="服务未初始化")

    request_id = str(uuid.uuid4())

    try:
        # 1. Assign the user to an experiment group.
        assignment = atm.assign_experiment(
            user_id=request.user_id,
            experiment_name=request.experiment_name,
            context=request.context
        )

        if not assignment:
            raise HTTPException(status_code=400, detail="A/B测试分配失败")

        # 2. Resolve the model for the assigned group.
        model_name = assignment.get('model_name')
        model_version = assignment.get('model_version')

        if not model_name or not model_version:
            raise HTTPException(status_code=400, detail="A/B测试配置错误")

        # 3. Run the prediction.
        prediction_result = await mm.predict(
            model_name,
            model_version,
            request.features
        )

        # 4. Record the A/B exposure event.
        atm.log_event(
            request_id=request_id,
            user_id=request.user_id,
            experiment_name=request.experiment_name,
            group_name=assignment.get('group_name'),
            action='prediction',
            metadata={
                'features': request.features,
                'prediction': prediction_result,
                'context': request.context
            }
        )

        # 5. Return the typed response with the group annotation.
        return PredictionResponse(
            request_id=request_id,
            prediction=prediction_result.get("prediction"),
            probabilities=prediction_result.get("probabilities"),
            model_name=prediction_result.get("model_name"),
            model_version=prediction_result.get("model_version"),
            latency_ms=prediction_result.get("latency_ms", 0),
            timestamp=prediction_result.get("timestamp"),
            experiment_group=assignment.get('group_name')
        )

    except HTTPException:
        # BUG FIX: the original caught its own 400-level HTTPExceptions below
        # and re-raised them as opaque 500s; let them propagate unchanged.
        raise
    except Exception as e:
        # BUG FIX: `logger` was never defined at module level.
        logging.getLogger(__name__).error(f"A/B测试预测失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/ab_test/experiments")
async def list_experiments(
    atm: 'ABTestManager' = Depends(get_ab_test_manager)
):
    """List all A/B experiments known to the manager."""
    if not atm:
        raise HTTPException(status_code=503, detail="ABTestManager未初始化")

    experiments = atm.list_experiments()

    return {
        "timestamp": datetime.now().isoformat(),
        "total": len(experiments),
        "experiments": experiments
    }

@app.get("/ab_test/experiment/{experiment_name}/stats")
async def get_experiment_stats(
    experiment_name: str,
    atm: 'ABTestManager' = Depends(get_ab_test_manager)
):
    """Return aggregated statistics for one experiment, 404 if unknown."""
    if not atm:
        raise HTTPException(status_code=503, detail="ABTestManager未初始化")

    stats = atm.get_experiment_stats(experiment_name)

    if not stats:
        raise HTTPException(status_code=404, detail="实验不存在")

    return stats

async def log_prediction(
    request_id: str,
    model_name: str,
    model_version: str,
    features: Any,
    prediction: Any,
    latency: float
):
    """Record one prediction for audit/monitoring.

    Builds a structured entry and schedules _write_log without awaiting it,
    so logging latency is not added to the request path.
    """
    log_entry = {
        "request_id": request_id,
        "timestamp": datetime.now().isoformat(),
        "model": f"{model_name}:{model_version}",
        "features": features,
        "prediction": prediction,
        "latency": latency,
        "type": "prediction"
    }

    # Fire-and-forget. NOTE(review): the task reference is not retained, so
    # the task can be garbage-collected before it runs — consider keeping a
    # reference (e.g. a module-level set) per asyncio docs.
    asyncio.create_task(_write_log(log_entry))

async def _write_log(log_entry: Dict[str, Any]):
"""写入日志(异步)"""
try:
# 这里可以实现实际的日志写入逻辑
# 例如写入文件、数据库或发送到日志服务
logger.info(f"预测日志: {json.dumps(log_entry)}")
except Exception as e:
logger.error(f"日志写入失败: {e}")

# Middleware: request-id injection and timing
@app.middleware("http")
async def add_request_id(request, call_next):
    """Attach an X-Request-ID (propagated or generated) and timing headers."""
    request_id = request.headers.get('X-Request-ID') or str(uuid.uuid4())

    # Expose the id to route handlers and the exception handler.
    request.state.request_id = request_id

    start_time = time.time()
    response = await call_next(request)
    # BUG FIX: the original used an en-dash instead of the minus operator.
    process_time = time.time() - start_time

    response.headers['X-Request-ID'] = request_id
    response.headers['X-Process-Time'] = str(process_time)

    return response

# Error handling
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Last-resort handler: log with the request id, return a 500 payload."""
    request_id = getattr(request.state, 'request_id', 'unknown')

    # BUG FIX: `logger` was never defined at module level.
    logging.getLogger(__name__).error(f"请求异常: {request_id}, 错误: {exc}")

    return JSONResponse(
        status_code=500,
        content={
            "request_id": request_id,
            "error": str(exc),
            "timestamp": datetime.now().isoformat()
        }
    )

def start_server(host: str = "0.0.0.0", port: int = 8000):
    """Run the API under uvicorn (blocking call)."""
    server_options = dict(
        host=host,
        port=port,
        log_level="info",
        access_log=True,
    )
    uvicorn.run(app, **server_options)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="ML Model Serving API")
    # BUG FIX: the option strings used en-dashes ("–host"); argparse long
    # options require ASCII double hyphens.
    parser.add_argument("--host", default="0.0.0.0", help="服务器地址")
    parser.add_argument("--port", type=int, default=8000, help="服务器端口")

    args = parser.parse_args()
    start_server(args.host, args.port)

二、A/B测试系统实现

2.1 A/B测试核心架构

python

复制

下载

"""
A/B测试系统核心实现
"""
import hashlib
import random
import json
from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
import asyncio
from collections import defaultdict
import statistics
import numpy as np
from scipy import stats # 用于统计检验

# ==================== Data models ====================
class ExperimentStatus(Enum):
    """Lifecycle state of an A/B experiment."""
    DRAFT = "draft"          # not yet started
    RUNNING = "running"      # actively assigning users
    PAUSED = "paused"        # temporarily halted
    STOPPED = "stopped"      # aborted before completion
    COMPLETED = "completed"  # finished normally

class AssignmentAlgorithm(Enum):
    """Supported user-to-group assignment strategies."""
    RANDOM = "random"          # weighted random, non-sticky between calls
    HASH_BASED = "hash_based"  # deterministic per (experiment, user)
    WEIGHTED = "weighted"      # rebalances toward target ratios
    BANDIT = "bandit"          # multi-armed bandit
    CUSTOM = "custom"          # user-supplied algorithm

@dataclass
class ExperimentGroup:
    """One arm of an A/B experiment."""
    name: str                          # group label
    weight: float = 1.0                # relative traffic share
    model_name: Optional[str] = None   # model served to this arm
    model_version: Optional[str] = None
    parameters: Dict[str, Any] = field(default_factory=dict)  # custom params
    description: Optional[str] = None

    @property
    def model_key(self) -> Optional[str]:
        """'name:version' identifier, or None when either part is missing."""
        has_model = bool(self.model_name) and bool(self.model_version)
        return f"{self.model_name}:{self.model_version}" if has_model else None

@dataclass
class Experiment:
    """Definition and lifecycle state of one A/B experiment."""
    name: str
    description: Optional[str] = None
    status: ExperimentStatus = ExperimentStatus.DRAFT
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    groups: List[ExperimentGroup] = field(default_factory=list)
    assignment_algorithm: AssignmentAlgorithm = AssignmentAlgorithm.RANDOM
    target_users: Optional[List[str]] = None  # allow-list; None = everyone
    sample_rate: float = 1.0                  # fraction of traffic enrolled
    metrics: List[str] = field(default_factory=list)  # monitored metric names
    hypotheses: Optional[str] = None
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)

    def get_group(self, group_name: str) -> Optional[ExperimentGroup]:
        """Return the group named *group_name*, or None when absent."""
        return next((g for g in self.groups if g.name == group_name), None)

    def is_running(self) -> bool:
        """True when status is RUNNING and now lies inside the time window."""
        if self.status != ExperimentStatus.RUNNING:
            return False

        now = datetime.now()
        after_start = self.start_time is None or now >= self.start_time
        before_end = self.end_time is None or now <= self.end_time
        return after_start and before_end

@dataclass
class Assignment:
    """Record of one user being assigned to an experiment group."""
    experiment_name: str
    user_id: str
    group_name: str
    assigned_at: datetime
    assignment_id: str
    context: Optional[Dict[str, Any]] = None  # info used at assignment time

@dataclass
class Event:
    """One tracked event (exposure, conversion, ...) within an experiment."""
    event_id: str
    experiment_name: str
    user_id: str
    group_name: str
    event_type: str
    timestamp: datetime
    metadata: Dict[str, Any] = field(default_factory=dict)  # free-form payload

@dataclass
class ExperimentStats:
    """Aggregated results of one experiment."""
    experiment_name: str
    start_time: datetime
    end_time: Optional[datetime]
    total_users: int
    total_events: int
    group_stats: Dict[str, 'GroupStats']  # per-group breakdown
    significance_test: Optional[Dict[str, Any]] = None  # e.g. t-test output

@dataclass
class GroupStats:
    """Aggregated per-group experiment statistics."""
    group_name: str
    user_count: int
    event_counts: Dict[str, int]           # event type -> count
    metric_values: Dict[str, List[float]]  # metric name -> observed values
    conversions: Dict[str, float]          # named conversion rates

    @property
    def conversion_rate(self) -> float:
        """Share of 'conversion' events among all recorded events (0.0 when none)."""
        if 'conversion' not in self.event_counts:
            return 0.0
        total = sum(self.event_counts.values())
        return self.event_counts['conversion'] / total if total > 0 else 0.0

# ==================== Assignment algorithm implementations ====================
class AssignmentAlgorithmBase(ABC):
    """Strategy interface: maps a user to an experiment group name."""

    @abstractmethod
    def assign(self, experiment: Experiment, user_id: str, context: Optional[Dict] = None) -> Optional[str]:
        """Return the chosen group name, or None when assignment is impossible."""
        pass

class RandomAssignment(AssignmentAlgorithmBase):
    """Weighted random assignment (non-sticky: re-rolls on every call)."""

    def assign(self, experiment: Experiment, user_id: str, context: Optional[Dict] = None) -> Optional[str]:
        """Pick a group with probability proportional to its weight."""
        groups = experiment.groups
        if not groups:
            return None

        total_weight = sum(g.weight for g in groups)
        if total_weight <= 0:
            return None

        # Draw a point in [0, total_weight) and walk the cumulative weights.
        threshold = random.random() * total_weight
        running = 0
        for g in groups:
            running += g.weight
            if threshold <= running:
                return g.name

        # Float-rounding fallback; practically unreachable.
        return groups[-1].name

class HashBasedAssignment(AssignmentAlgorithmBase):
    """Deterministic assignment: the same (experiment, user, context) always
    hashes to the same group, giving sticky bucketing without storage."""

    def assign(self, experiment: Experiment, user_id: str, context: Optional[Dict] = None) -> Optional[str]:
        """Bucket the user via an MD5 hash of experiment name + user id."""
        groups = experiment.groups
        if not groups:
            return None

        key = f"{experiment.name}:{user_id}"
        if context:
            # Context participates in the hash, so a different context can
            # place the same user into a different bucket.
            key += f":{json.dumps(context, sort_keys=True)}"

        digest = int(hashlib.md5(key.encode()).hexdigest(), 16)

        total = sum(g.weight for g in groups)
        if total <= 0:
            return None

        # Map the hash into [0, total) and walk the cumulative weights.
        bucket = digest % total
        running = 0
        for g in groups:
            running += g.weight
            if bucket < running:
                return g.name
        return groups[-1].name

class WeightedAssignment(AssignmentAlgorithmBase):
    """Assignment that nudges traffic toward groups that are currently
    under-allocated relative to their configured weight."""

    def assign(self, experiment: Experiment, user_id: str, context: Optional[Dict] = None) -> Optional[str]:
        """Assign randomly, with weights adjusted by the historical skew.

        BUG FIX: the original used a Unicode en-dash ("–") instead of a
        minus sign in the weight adjustment, which is a SyntaxError.
        """
        if not experiment.groups:
            return None

        # Past assignments (stubbed; a real implementation reads storage).
        history = self._get_historical_assignments(experiment.name)
        group_counts = defaultdict(int)
        for past in history:
            group_counts[past.group_name] += 1
        total_assigned = sum(group_counts.values())

        total_weight = sum(g.weight for g in experiment.groups)
        adjusted = []
        for group in experiment.groups:
            expected = group.weight / total_weight
            actual = group_counts[group.name] / total_assigned if total_assigned > 0 else 0
            # Under-allocated groups (actual < expected) get a boost; the 0.1
            # floor keeps every group reachable.
            adjusted.append(max(0.1, expected - actual + 1.0))

        total_adjusted = sum(adjusted)
        if total_adjusted <= 0:
            return None

        # Weighted random draw over the adjusted weights.
        point = random.random() * total_adjusted
        running = 0
        for weight, group in zip(adjusted, experiment.groups):
            running += weight
            if point <= running:
                return group.name
        return experiment.groups[-1].name

    def _get_historical_assignments(self, experiment_name: str) -> List[Assignment]:
        """Fetch past assignments for the experiment (simplified: none)."""
        # A real implementation would query the assignment store here.
        return []

class BanditAssignment(AssignmentAlgorithmBase):
    """Thompson-sampling multi-armed bandit over experiment groups."""

    def __init__(self, alpha: float = 1.0, beta: float = 1.0):
        self.alpha = alpha  # Beta prior: pseudo-count of successes
        self.beta = beta    # Beta prior: pseudo-count of failures
        # Per-group observed outcomes, updated via update().
        self.group_stats = defaultdict(lambda: {'success': 0, 'failure': 0})

    def assign(self, experiment: Experiment, user_id: str, context: Optional[Dict] = None) -> Optional[str]:
        """Pick a group by Thompson sampling; explore unseen groups first."""
        if not experiment.groups:
            return None

        # Force exploration of any group with no observations yet.
        fresh = [
            g.name for g in experiment.groups
            if self.group_stats[g.name]['success'] + self.group_stats[g.name]['failure'] == 0
        ]
        if fresh:
            return random.choice(fresh)

        # Sample each group's success rate from its posterior Beta
        # distribution and exploit the best draw.
        draws = []
        for g in experiment.groups:
            observed = self.group_stats[g.name]
            draw = np.random.beta(
                observed['success'] + self.alpha,
                observed['failure'] + self.beta,
            )
            draws.append((draw, g.name))
        draws.sort(reverse=True)
        return draws[0][1]

    def update(self, group_name: str, success: bool):
        """Record one observed outcome for *group_name*."""
        outcome = 'success' if success else 'failure'
        self.group_stats[group_name][outcome] += 1

# ==================== A/B测试管理器 ====================
class ABTestManager:
"""A/B测试管理器"""

def __init__(self, config: Dict[str, Any]):
    """Build the manager and load experiments declared in *config*."""
    self.config = config
    self.experiments: Dict[str, Experiment] = {}
    # assignment_id -> Assignment record
    self.assignments: Dict[str, Assignment] = {}
    self.events: List[Event] = []

    # Registry mapping algorithm enum members to strategy instances.
    self.algorithms = {
        AssignmentAlgorithm.RANDOM: RandomAssignment(),
        AssignmentAlgorithm.HASH_BASED: HashBasedAssignment(),
        AssignmentAlgorithm.WEIGHTED: WeightedAssignment(),
        AssignmentAlgorithm.BANDIT: BanditAssignment(),
    }

    self.logger = logging.getLogger("ABTestManager")

    # Pull experiment definitions from the supplied configuration.
    self._load_experiments()

def _load_experiments(self):
    """Instantiate Experiment objects from the 'experiments' config list."""
    for cfg in self.config.get('experiments', []):
        groups = [
            ExperimentGroup(
                name=g['name'],
                weight=g.get('weight', 1.0),
                model_name=g.get('model_name'),
                model_version=g.get('model_version'),
                parameters=g.get('parameters', {}),
                description=g.get('description'),
            )
            for g in cfg.get('groups', [])
        ]
        experiment = Experiment(
            name=cfg['name'],
            description=cfg.get('description'),
            status=ExperimentStatus(cfg.get('status', 'draft')),
            groups=groups,
            assignment_algorithm=AssignmentAlgorithm(
                cfg.get('assignment_algorithm', 'random')
            ),
            target_users=cfg.get('target_users'),
            sample_rate=cfg.get('sample_rate', 1.0),
            metrics=cfg.get('metrics', []),
            hypotheses=cfg.get('hypotheses'),
        )
        self.experiments[experiment.name] = experiment
        self.logger.info(f"加载实验: {experiment.name}")

def create_experiment(self, experiment: Experiment) -> bool:
    """Register a new experiment; False on duplicate name or invalid config."""
    if experiment.name in self.experiments:
        self.logger.error(f"实验已存在: {experiment.name}")
        return False

    if not self._validate_experiment(experiment):
        return False

    self.experiments[experiment.name] = experiment
    self.logger.info(f"创建实验: {experiment.name}")

    # Persist the updated definitions (file/db in a real deployment).
    self._save_experiments()
    return True

def start_experiment(self, experiment_name: str) -> bool:
    """Move an experiment into RUNNING state, stamping its start time."""
    experiment = self.experiments.get(experiment_name)
    if experiment is None:
        self.logger.error(f"实验不存在: {experiment_name}")
        return False

    # Idempotent: starting an already-running experiment is a no-op success.
    if experiment.status == ExperimentStatus.RUNNING:
        self.logger.warning(f"实验已在运行: {experiment_name}")
        return True

    experiment.status = ExperimentStatus.RUNNING
    experiment.start_time = datetime.now()
    experiment.updated_at = datetime.now()

    self.logger.info(f"启动实验: {experiment_name}")
    self._save_experiments()
    return True

def stop_experiment(self, experiment_name: str) -> bool:
    """Move an experiment into STOPPED state, stamping its end time."""
    experiment = self.experiments.get(experiment_name)
    if experiment is None:
        self.logger.error(f"实验不存在: {experiment_name}")
        return False

    # Idempotent: stopping an already-stopped experiment is a no-op success.
    if experiment.status == ExperimentStatus.STOPPED:
        return True

    experiment.status = ExperimentStatus.STOPPED
    experiment.end_time = datetime.now()
    experiment.updated_at = datetime.now()

    self.logger.info(f"停止实验: {experiment_name}")
    self._save_experiments()
    return True

def assign_experiment(
    self,
    user_id: str,
    experiment_name: str,
    context: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
    """Assign *user_id* to a group of the named experiment.

    Returns a dict describing the assignment (group, model, parameters),
    or None when the experiment is missing or not running, the user is not
    targeted, the user falls outside the sampled traffic, or the
    assignment algorithm fails.
    """
    experiment = self.experiments.get(experiment_name)
    if experiment is None:
        self.logger.error(f"实验不存在: {experiment_name}")
        return None

    if not experiment.is_running():
        self.logger.warning(f"实验未运行: {experiment_name}")
        return None

    # Optional allow-list of target users.
    if experiment.target_users and user_id not in experiment.target_users:
        self.logger.debug(f"用户不在目标列表中: {user_id}")
        return None

    # Deterministic traffic sampling: hash the user into 1000 buckets so the
    # same user is consistently in or out of the sample.
    if experiment.sample_rate < 1.0:
        digest = int(hashlib.md5(f"{experiment_name}:{user_id}".encode()).hexdigest(), 16)
        if (digest % 1000) / 1000.0 > experiment.sample_rate:
            self.logger.debug(f"用户未采样: {user_id}")
            return None

    algorithm = self.algorithms.get(experiment.assignment_algorithm)
    if not algorithm:
        self.logger.error(f"不支持的分配算法: {experiment.assignment_algorithm}")
        algorithm = RandomAssignment()  # fall back rather than fail outright

    group_name = algorithm.assign(experiment, user_id, context)
    if not group_name:
        self.logger.error(f"分配失败: {experiment_name}, {user_id}")
        return None

    group = experiment.get_group(group_name)
    if not group:
        self.logger.error(f"分组不存在: {group_name}")
        return None

    # Record the assignment under a time-salted id.
    assignment_id = hashlib.md5(
        f"{experiment_name}:{user_id}:{datetime.now().isoformat()}".encode()
    ).hexdigest()
    assignment = Assignment(
        experiment_name=experiment_name,
        user_id=user_id,
        group_name=group_name,
        assigned_at=datetime.now(),
        assignment_id=assignment_id,
        context=context,
    )
    self.assignments[assignment_id] = assignment

    # The assignment itself is logged as an event for later analysis.
    self.log_event(
        request_id=assignment_id,
        user_id=user_id,
        experiment_name=experiment_name,
        group_name=group_name,
        action='assignment',
        metadata={
            'algorithm': experiment.assignment_algorithm.value,
            'context': context,
        },
    )

    self.logger.debug(f"分配用户 {user_id} 到实验 {experiment_name} 分组 {group_name}")

    return {
        'experiment_name': experiment_name,
        'group_name': group_name,
        'assignment_id': assignment_id,
        'model_name': group.model_name,
        'model_version': group.model_version,
        'parameters': group.parameters,
    }

def get_experiment_group(self, request_id: str) -> Optional[str]:
    """Look up the assigned group for a request.

    Simplified: the request id is assumed to equal the assignment id; a
    real implementation would keep a request->assignment mapping.
    """
    assignment = self.assignments.get(request_id)
    return assignment.group_name if assignment else None

def log_event(
    self,
    request_id: str,
    user_id: str,
    experiment_name: str,
    group_name: str,
    action: str,
    metadata: Optional[Dict[str, Any]] = None,
):
    """Append an experiment event and feed bandit statistics when relevant."""
    event = Event(
        event_id=hashlib.md5(
            f"{request_id}:{action}:{datetime.now().isoformat()}".encode()
        ).hexdigest(),
        experiment_name=experiment_name,
        user_id=user_id,
        group_name=group_name,
        event_type=action,
        timestamp=datetime.now(),
        metadata=metadata or {},
    )
    self.events.append(event)

    # Bandit experiments learn online from each observed outcome.
    experiment = self.experiments.get(experiment_name)
    if experiment is not None and experiment.assignment_algorithm == AssignmentAlgorithm.BANDIT:
        algorithm = self.algorithms.get(AssignmentAlgorithm.BANDIT)
        if isinstance(algorithm, BanditAssignment):
            algorithm.update(group_name, self._is_success_event(action, metadata))

    self.logger.debug(f"记录事件: {experiment_name}, {group_name}, {action}")

def _is_success_event(self, action: str, metadata: Optional[Dict]) -> bool:
"""判断事件是否成功"""
# 这里需要根据业务逻辑实现
# 例如:如果是购买事件,且金额大于0,则认为是成功
if action == 'purchase' and metadata:
amount = metadata.get('amount', 0)
return amount > 0
elif action == 'conversion':
return True
return False

def list_experiments(self) -> List[Dict[str, Any]]:
    """Return a JSON-friendly summary of every registered experiment."""
    summaries = []
    for exp in self.experiments.values():
        summaries.append({
            'name': exp.name,
            'description': exp.description,
            'status': exp.status.value,
            'start_time': exp.start_time.isoformat() if exp.start_time else None,
            'end_time': exp.end_time.isoformat() if exp.end_time else None,
            'groups': [
                {
                    'name': g.name,
                    'weight': g.weight,
                    'model_name': g.model_name,
                    'model_version': g.model_version,
                }
                for g in exp.groups
            ],
            'assignment_algorithm': exp.assignment_algorithm.value,
            'sample_rate': exp.sample_rate,
            'created_at': exp.created_at.isoformat(),
            'updated_at': exp.updated_at.isoformat(),
        })
    return summaries

def get_experiment_stats(self, experiment_name: str) -> Optional[Dict[str, Any]]:
    """Aggregate per-group assignment/event statistics for an experiment."""
    experiment = self.experiments.get(experiment_name)
    if experiment is None:
        return None

    exp_assignments = [a for a in self.assignments.values()
                       if a.experiment_name == experiment_name]
    exp_events = [e for e in self.events if e.experiment_name == experiment_name]

    group_stats: Dict[str, GroupStats] = {}
    for group in experiment.groups:
        members = [a for a in exp_assignments if a.group_name == group.name]
        events = [e for e in exp_events if e.group_name == group.name]

        event_counts = defaultdict(int)
        metric_values = defaultdict(list)
        for event in events:
            event_counts[event.event_type] += 1
            # Pull any configured metric present in the event metadata.
            for metric in experiment.metrics:
                value = event.metadata.get(metric)
                if isinstance(value, (int, float)):
                    metric_values[metric].append(value)

        # "Conversion" here means events per distinct user, per event type.
        distinct_users = len({a.user_id for a in members})
        conversions = {}
        if distinct_users > 0:
            conversions = {
                event_type: count / distinct_users
                for event_type, count in event_counts.items()
            }

        group_stats[group.name] = GroupStats(
            group_name=group.name,
            user_count=distinct_users,
            event_counts=dict(event_counts),
            metric_values=dict(metric_values),
            conversions=conversions,
        )

    # Only meaningful with a control and at least one treatment group.
    significance_test = None
    if len(experiment.groups) >= 2:
        significance_test = self._calculate_significance(experiment, group_stats)

    stats = ExperimentStats(
        experiment_name=experiment_name,
        start_time=experiment.start_time or experiment.created_at,
        end_time=experiment.end_time,
        total_users=len({a.user_id for a in exp_assignments}),
        total_events=len(exp_events),
        group_stats=group_stats,
        significance_test=significance_test,
    )
    return self._format_stats(stats)

def _calculate_significance(
self,
experiment: Experiment,
group_stats: Dict[str, GroupStats]
) -> Dict[str, Any]:
"""计算显著性检验"""
if len(experiment.groups) < 2:
return None

# 选择控制组(通常第一个分组)
control_group = experiment.groups[0].name
treatment_groups = [g.name for g in experiment.groups[1:]]

results = {}

for treatment_group in treatment_groups:
# 获取转化率数据
control_stats = group_stats[control_group]
treatment_stats = group_stats[treatment_group]

# 检查是否有足够的数据
if control_stats.user_count == 0 or treatment_stats.user_count == 0:
results[treatment_group] = {
'significant': False,
'p_value': 1.0,
'effect_size': 0.0,
'error': 'Insufficient data'
}
continue

# 提取主要指标(假设第一个指标是主要指标)
primary_metric = experiment.metrics[0] if experiment.metrics else 'conversion'

# 获取指标值
control_values = control_stats.metric_values.get(primary_metric, [])
treatment_values = treatment_stats.metric_values.get(primary_metric, [])

if not control_values or not treatment_values:
# 如果没有连续值指标,使用二项分布检验
control_conversions = control_stats.event_counts.get('conversion', 0)
treatment_conversions = treatment_stats.event_counts.get('conversion', 0)

control_non_conversions = control_stats.user_count – control_conversions
treatment_non_conversions = treatment_stats.user_count – treatment_conversions

# 卡方检验
from scipy.stats import chi2_contingency
contingency_table = [
[control_conversions, control_non_conversions],
[treatment_conversions, treatment_non_conversions]
]

chi2, p_value, dof, expected = chi2_contingency(contingency_table)

# 计算效应大小
control_rate = control_conversions / control_stats.user_count if control_stats.user_count > 0 else 0
treatment_rate = treatment_conversions / treatment_stats.user_count if treatment_stats.user_count > 0 else 0
effect_size = treatment_rate – control_rate

results[treatment_group] = {
'significant': p_value < 0.05,
'p_value': float(p_value),
'effect_size': float(effect_size),
'relative_improvement': float(effect_size / control_rate) if control_rate > 0 else 0.0,
'test_type': 'chi_square'
}
else:
# t检验(连续值指标)
from scipy.stats import ttest_ind

t_stat, p_value = ttest_ind(control_values, treatment_values, equal_var=False)

# 计算效应大小(Cohen's d)
control_mean = np.mean(control_values)
treatment_mean = np.mean(treatment_values)
control_std = np.std(control_values, ddof=1)
treatment_std = np.std(treatment_values, ddof=1)

pooled_std = np.sqrt((control_std**2 + treatment_std**2) / 2)
effect_size = (treatment_mean – control_mean) / pooled_std if pooled_std > 0 else 0

results[treatment_group] = {
'significant': p_value < 0.05,
'p_value': float(p_value),
'effect_size': float(effect_size),
'relative_improvement': float((treatment_mean – control_mean) / control_mean) if control_mean > 0 else 0.0,
'test_type': 't_test'
}

return results

def _format_stats(self, stats: ExperimentStats) -> Dict[str, Any]:
"""格式化统计信息"""
return {
'experiment_name': stats.experiment_name,
'start_time': stats.start_time.isoformat(),
'end_time': stats.end_time.isoformat() if stats.end_time else None,
'total_users': stats.total_users,
'total_events': stats.total_events,
'groups': {
group_name: {
'user_count': group_stats.user_count,
'event_counts': group_stats.event_counts,
'conversions': group_stats.conversions,
'metric_means': {
metric: np.mean(values) if values else 0
for metric, values in group_stats.metric_values.items()
},
'metric_stds': {
metric: np.std(values) if values else 0
for metric, values in group_stats.metric_values.items()
}
}
for group_name, group_stats in stats.group_stats.items()
},
'significance_test': stats.significance_test,
'recommendation': self._generate_recommendation(stats)
}

def _generate_recommendation(self, stats: ExperimentStats) -> str:
"""生成实验建议"""
if not stats.significance_test:
return "Insufficient data for recommendation"

# 检查是否有显著提升的分组
best_group = None
best_improvement = 0

for treatment_group, test_result in stats.significance_test.items():
if test_result.get('significant', False):
improvement = test_result.get('relative_improvement', 0)
if improvement > best_improvement:
best_improvement = improvement
best_group = treatment_group

if best_group and best_improvement > 0:
return f"Recommend implementing {best_group} (improvement: {best_improvement:.2%})"
elif best_group and best_improvement < 0:
return f"Warning: {best_group} performs worse than control"
else:
return "No significant difference detected"

def _validate_experiment(self, experiment: Experiment) -> bool:
"""验证实验配置"""
if not experiment.name:
self.logger.error("实验名称不能为空")
return False

if not experiment.groups:
self.logger.error("实验必须包含至少一个分组")
return False

# 检查分组名称是否唯一
group_names = [g.name for g in experiment.groups]
if len(group_names) != len(set(group_names)):
self.logger.error("分组名称必须唯一")
return False

# 检查权重是否合理
total_weight = sum(g.weight for g in experiment.groups)
if total_weight <= 0:
self.logger.error("分组权重总和必须大于0")
return False

# 检查采样率
if not 0 <= experiment.sample_rate <= 1:
self.logger.error("采样率必须在0到1之间")
return False

return True

def _save_experiments(self):
"""保存实验配置"""
# 实际实现中应保存到数据库或配置文件
experiments_data = []
for experiment in self.experiments.values():
experiments_data.append({
'name': experiment.name,
'description': experiment.description,
'status': experiment.status.value,
'groups': [
{
'name': g.name,
'weight': g.weight,
'model_name': g.model_name,
'model_version': g.model_version,
'parameters': g.parameters,
'description': g.description
}
for g in experiment.groups
],
'assignment_algorithm': experiment.assignment_algorithm.value,
'target_users': experiment.target_users,
'sample_rate': experiment.sample_rate,
'metrics': experiment.metrics,
'hypotheses': experiment.hypotheses
})

# 保存到文件(示例)
config_file = self.config.get('experiments_file', 'experiments.json')
try:
with open(config_file, 'w') as f:
json.dump(experiments_data, f, indent=2, ensure_ascii=False)
except Exception as e:
self.logger.error(f"保存实验配置失败: {e}")

# ==================== 实验监控面板 ====================
class ExperimentDashboard:
"""实验监控面板"""

def __init__(self, ab_test_manager: ABTestManager):
    """Wrap an ABTestManager to expose dashboard/reporting views over it."""
    self.ab_test_manager = ab_test_manager
    self.logger = logging.getLogger("ExperimentDashboard")

async def get_dashboard_data(self) -> Dict[str, Any]:
    """Collect a snapshot of all experiments plus their current stats."""
    experiments = self.ab_test_manager.list_experiments()

    return {
        'timestamp': datetime.now().isoformat(),
        'total_experiments': len(experiments),
        'running_experiments': sum(1 for e in experiments if e['status'] == 'running'),
        'experiments': [
            {
                'info': info,
                'stats': self.ab_test_manager.get_experiment_stats(info['name']),
            }
            for info in experiments
        ],
    }

def generate_report(self, experiment_name: str) -> Dict[str, Any]:
    """Build a full experiment report (summary, methodology, conclusions).

    Returns {'error': ...} when the experiment has no stats.

    BUG FIX: the duration computation used a Unicode en-dash ("–") instead
    of a minus sign, which was a SyntaxError in the original.
    """
    stats = self.ab_test_manager.get_experiment_stats(experiment_name)
    if not stats:
        return {'error': 'Experiment not found'}

    experiment = self.ab_test_manager.experiments.get(experiment_name)

    report = {
        'experiment_name': experiment_name,
        'report_date': datetime.now().isoformat(),
        'experiment_info': {
            'description': experiment.description if experiment else None,
            'hypotheses': experiment.hypotheses if experiment else None,
            'duration_days': None,  # filled in below when a start time exists
        },
        'executive_summary': self._generate_executive_summary(stats),
        'methodology': {
            'assignment_algorithm': experiment.assignment_algorithm.value if experiment else 'unknown',
            'sample_rate': experiment.sample_rate if experiment else 1.0,
            'target_users': experiment.target_users if experiment else None,
        },
        'results': self._format_results_for_report(stats),
        'conclusions': self._generate_conclusions(stats),
        'recommendations': stats.get('recommendation', 'No recommendation'),
    }

    if experiment and experiment.start_time:
        report['experiment_info']['duration_days'] = (datetime.now() - experiment.start_time).days

    return report

def _generate_executive_summary(self, stats: Dict[str, Any]) -> str:
"""生成执行摘要"""
total_users = stats.get('total_users', 0)
total_events = stats.get('total_events', 0)

summary = f"""
Experiment: {stats.get('experiment_name')}
Total Users: {total_users}
Total Events: {total_events}
"""

significance_test = stats.get('significance_test')
if significance_test:
summary += "\\nSignificance Test Results:\\n"
for group, result in significance_test.items():
if result.get('significant', False):
improvement = result.get('relative_improvement', 0)
summary += f" – {group}: Significant improvement of {improvement:.2%}\\n"
else:
summary += f" – {group}: No significant difference\\n"

return summary

def _format_results_for_report(self, stats: Dict[str, Any]) -> Dict[str, Any]:
"""格式化结果用于报告"""
formatted_results = {
'overall_metrics': {
'total_users': stats.get('total_users', 0),
'total_events': stats.get('total_events', 0)
},
'group_performance': {}
}

groups = stats.get('groups', {})
for group_name, group_stats in groups.items():
formatted_results['group_performance'][group_name] = {
'user_count': group_stats.get('user_count', 0),
'conversion_rate': group_stats.get('conversions', {}).get('conversion', 0),
'key_metrics': {
metric: {
'mean': group_stats.get('metric_means', {}).get(metric, 0),
'std': group_stats.get('metric_stds', {}).get(metric, 0)
}
for metric in group_stats.get('metric_means', {}).keys()
}
}

return formatted_results

def _generate_conclusions(self, stats: Dict[str, Any]) -> List[str]:
"""生成结论"""
conclusions = []

significance_test = stats.get('significance_test')
if not significance_test:
conclusions.append("实验数据不足,无法得出结论")
return conclusions

for group, result in significance_test.items():
if result.get('significant', False):
improvement = result.get('relative_improvement', 0)
p_value = result.get('p_value', 1.0)
conclusions.append(
f"分组 {group} 相对于控制组有显著改善 (p={p_value:.4f}, 提升={improvement:.2%})"
)
else:
p_value = result.get('p_value', 1.0)
conclusions.append(
f"分组 {group} 相对于控制组没有显著差异 (p={p_value:.4f})"
)

return conclusions

# ==================== 配置示例 ====================
def create_sample_config() -> Dict[str, Any]:
    """Build a demo configuration containing two example experiments."""
    recommendation_test = {
        'name': 'recommendation_model_v2',
        'description': '测试新版推荐模型的效果',
        'status': 'running',
        'groups': [
            {
                'name': 'control',
                'weight': 0.5,
                'model_name': 'recommendation',
                'model_version': 'v1',
                'description': '当前生产模型',
            },
            {
                'name': 'treatment_v2',
                'weight': 0.3,
                'model_name': 'recommendation',
                'model_version': 'v2',
                'description': '新版推荐模型',
            },
            {
                'name': 'treatment_v3',
                'weight': 0.2,
                'model_name': 'recommendation',
                'model_version': 'v3',
                'description': '实验性推荐模型',
            },
        ],
        'assignment_algorithm': 'hash_based',
        'target_users': None,  # no allow-list: all users eligible
        'sample_rate': 0.1,    # 10% of traffic
        'metrics': ['click_rate', 'purchase_rate', 'revenue'],
        'hypotheses': '新版推荐模型能提高用户点击率和购买率',
    }

    pricing_test = {
        'name': 'pricing_strategy_test',
        'description': '测试不同定价策略',
        'status': 'draft',
        'groups': [
            {
                'name': 'standard_pricing',
                'weight': 0.33,
                'parameters': {'price_multiplier': 1.0, 'discount_rate': 0.0},
                'description': '标准定价',
            },
            {
                'name': 'discounted_pricing',
                'weight': 0.33,
                'parameters': {'price_multiplier': 0.9, 'discount_rate': 0.1},
                'description': '9折定价',
            },
            {
                'name': 'premium_pricing',
                'weight': 0.34,
                'parameters': {'price_multiplier': 1.1, 'premium_features': True},
                'description': '高端定价',
            },
        ],
        'assignment_algorithm': 'random',
        'sample_rate': 0.05,   # 5% of traffic
        'metrics': ['conversion_rate', 'average_order_value', 'customer_satisfaction'],
        'hypotheses': '适度的折扣能提高转化率而不显著降低客单价',
    }

    return {
        'experiments_file': 'experiments.json',
        'experiments': [recommendation_test, pricing_test],
    }

if __name__ == "__main__":
    # Demo: build a manager from the sample config, start one experiment,
    # simulate traffic, then print the stats and a full report.
    config = create_sample_config()
    ab_test_manager = ABTestManager(config)
    ab_test_manager.start_experiment('recommendation_model_v2')

    for i in range(1000):
        user_id = f"user_{i}"
        assignment = ab_test_manager.assign_experiment(
            user_id=user_id,
            experiment_name='recommendation_model_v2',
        )
        if not assignment:
            continue

        print(f"用户 {user_id} 分配到分组 {assignment['group_name']}")

        # Simulated behaviour: ~10% click, ~20% purchase.
        if random.random() < 0.1:
            ab_test_manager.log_event(
                request_id=assignment['assignment_id'],
                user_id=user_id,
                experiment_name='recommendation_model_v2',
                group_name=assignment['group_name'],
                action='click',
                metadata={'timestamp': datetime.now().isoformat()},
            )
        if random.random() < 0.2:
            ab_test_manager.log_event(
                request_id=assignment['assignment_id'],
                user_id=user_id,
                experiment_name='recommendation_model_v2',
                group_name=assignment['group_name'],
                action='purchase',
                metadata={
                    'amount': random.uniform(10, 100),
                    'timestamp': datetime.now().isoformat(),
                },
            )

    stats = ab_test_manager.get_experiment_stats('recommendation_model_v2')
    print(json.dumps(stats, indent=2, ensure_ascii=False))

    dashboard = ExperimentDashboard(ab_test_manager)
    report = dashboard.generate_report('recommendation_model_v2')
    print(json.dumps(report, indent=2, ensure_ascii=False))

 篇幅限制下面就只能给大家展示小册部分内容了。整理了一份核心面试笔记,包括了:Java面试、Spring、JVM、MyBatis、Redis、MySQL、并发编程、微服务、Linux、Springboot、SpringCloud、MQ、Kafka 等。

需要全套面试笔记及答案 【点击此处即可/免费获取】​​​

2.2 实验分析与监控

python

复制

下载

"""
实验分析与监控系统
"""
import asyncio
import json
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
import numpy as np
import pandas as pd
from scipy import stats
import plotly.graph_objs as go
import plotly.offline as pyo
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
import seaborn as sns
from sqlalchemy import create_engine, Column, Integer, String, Float, DateTime, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

# ==================== Database models ====================
Base = declarative_base()

class ExperimentRecord(Base):
    """One experiment event row (assignment/click/purchase/...)."""
    __tablename__ = 'experiment_records'

    id = Column(Integer, primary_key=True)
    experiment_name = Column(String(255), nullable=False, index=True)  # experiment id
    user_id = Column(String(255), nullable=False, index=True)          # acting user
    group_name = Column(String(100), nullable=False)                   # assigned group
    event_type = Column(String(100), nullable=False)                   # e.g. 'click'
    event_value = Column(Float, nullable=True)                         # optional numeric value
    event_metadata = Column(Text, nullable=True)                       # JSON-encoded extras
    timestamp = Column(DateTime, nullable=False, index=True)           # event time
    created_at = Column(DateTime, default=datetime.now)                # row insert time

class ExperimentAnalysis:
"""实验分析器"""

def __init__(self, db_url: str = "sqlite:///experiments.db"):
    """Create the analyzer, its DB engine/tables, and test settings."""
    self.engine = create_engine(db_url)
    Base.metadata.create_all(self.engine)
    self.Session = sessionmaker(bind=self.engine)

    # Statistical-test configuration.
    self.alpha = 0.05          # significance level
    self.min_sample_size = 30  # minimum sample size
    self.power = 0.8           # target statistical power

def record_event(self, event_data: Dict[str, Any]):
    """Persist one experiment event dict as an ExperimentRecord row.

    Rolls back and re-raises on failure; always closes the session.
    """
    session = self.Session()
    try:
        record = ExperimentRecord(
            experiment_name=event_data['experiment_name'],
            user_id=event_data['user_id'],
            group_name=event_data['group_name'],
            event_type=event_data['event_type'],
            event_value=event_data.get('event_value'),
            event_metadata=json.dumps(event_data.get('metadata', {})),
            timestamp=event_data.get('timestamp', datetime.now()),
        )
        session.add(record)
        session.commit()
    except Exception:
        # Undo the partial transaction but let the caller see the failure.
        session.rollback()
        raise
    finally:
        session.close()

async def analyze_experiment(
    self,
    experiment_name: str,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
) -> Dict[str, Any]:
    """Load an experiment's events from the DB and run the full analysis.

    Returns a dict with summary stats, group comparisons, time-series,
    power/sensitivity analyses and visualizations, or an error marker
    when no rows match the filters.
    """
    session = self.Session()
    try:
        query = session.query(ExperimentRecord).filter(
            ExperimentRecord.experiment_name == experiment_name
        )
        if start_date:
            query = query.filter(ExperimentRecord.timestamp >= start_date)
        if end_date:
            query = query.filter(ExperimentRecord.timestamp <= end_date)

        records = query.all()
        if not records:
            return {"error": "No data found for experiment"}

        # Flatten ORM rows into a DataFrame for the downstream analyses.
        df = pd.DataFrame([
            {
                'user_id': r.user_id,
                'group_name': r.group_name,
                'event_type': r.event_type,
                'event_value': r.event_value,
                'timestamp': r.timestamp,
                'metadata': json.loads(r.event_metadata) if r.event_metadata else {},
            }
            for r in records
        ])

        return {
            'experiment_name': experiment_name,
            'analysis_date': datetime.now().isoformat(),
            'time_period': {
                'start': start_date.isoformat() if start_date else df['timestamp'].min().isoformat(),
                'end': end_date.isoformat() if end_date else df['timestamp'].max().isoformat(),
            },
            'summary': self._generate_summary(df),
            'group_comparison': await self._compare_groups(df),
            'time_series_analysis': self._analyze_time_series(df),
            'power_analysis': self._calculate_power_analysis(df),
            'sensitivity_analysis': self._sensitivity_analysis(df),
            'visualizations': await self._generate_visualizations(df, experiment_name),
        }
    finally:
        session.close()

def _generate_summary(self, df: pd.DataFrame) -> Dict[str, Any]:
"""生成数据摘要"""
total_users = df['user_id'].nunique()
total_events = len(df)

summary = {
'total_users': int(total_users),
'total_events': int(total_events),
'events_per_user': float(total_events / total_users) if total_users > 0 else 0,
'groups': {}
}

# 按分组统计
for group_name, group_df in df.groupby('group_name'):
group_users = group_df['user_id'].nunique()
group_events = len(group_df)

# 事件类型统计
event_counts = group_df['event_type'].value_counts().to_dict()

# 数值型事件的统计
numeric_events = group_df[group_df['event_value'].notnull()]
event_stats = {}
if not numeric_events.empty:
for event_type, event_df in numeric_events.groupby('event_type'):
values = event_df['event_value'].astype(float)
event_stats[event_type] = {
'count': int(len(values)),
'mean': float(values.mean()),
'std': float(values.std()),
'min': float(values.min()),
'max': float(values.max()),
'median': float(values.median())
}

summary['groups'][group_name] = {
'user_count': int(group_users),
'event_count': int(group_events),
'event_distribution': event_counts,
'event_statistics': event_stats
}

return summary

async def _compare_groups(self, df: pd.DataFrame) -> Dict[str, Any]:
    """Compare every treatment group against the control group.

    NOTE(review): the control group is simply the first group name that
    appears in the data, which depends on row order — confirm this matches
    the experiment's configured control group.
    """
    groups = df['group_name'].unique()
    if len(groups) < 2:
        return {"error": "Need at least 2 groups for comparison"}

    control_group, treatment_groups = groups[0], groups[1:]

    comparisons = {}
    for treatment in treatment_groups:
        comparisons[treatment] = await self._compare_two_groups(df, control_group, treatment)

    return {
        'control_group': control_group,
        'treatment_groups': list(treatment_groups),
        'comparisons': comparisons,
        'overall_recommendation': self._generate_overall_recommendation(comparisons),
    }

async def _compare_two_groups(
self,
df: pd.DataFrame,
control_group: str,
treatment_group: str
) -> Dict[str, Any]:
"""比较两个分组"""
control_df = df[df['group_name'] == control_group]
treatment_df = df[df['group_name'] == treatment_group]

comparison = {
'control_group': control_group,
'treatment_group': treatment_group,
'user_counts': {
'control': int(control_df['user_id'].nunique()),
'treatment': int(treatment_df['user_id'].nunique())
},
'event_comparisons': {}
}

# 比较每种事件类型
all_event_types = set(control_df['event_type'].unique()) | set(treatment_df['event_type'].unique())

for event_type in all_event_types:
control_events = control_df[control_df['event_type'] == event_type]
treatment_events = treatment_df[treatment_df['event_type'] == event_type]

# 计算转化率
control_users = control_df['user_id'].nunique()
treatment_users = treatment_df['user_id'].nunique()

control_conversions = control_events['user_id'].nunique() if not control_events.empty else 0
treatment_conversions = treatment_events['user_id'].nunique() if not treatment_events.empty else 0

control_rate = control_conversions / control_users if control_users > 0 else 0
treatment_rate = treatment_conversions / treatment_users if treatment_users > 0 else 0

# 计算相对提升
relative_improvement = (treatment_rate – control_rate) / control_rate if control_rate > 0 else 0

# 显著性检验(比例检验)
p_value = 1.0
if control_conversions > 0 and treatment_conversions > 0:
from statsmodels.stats.proportion import proportions_ztest

count = [treatment_conversions, control_conversions]
nobs = [treatment_users, control_users]

try:
z_stat, p_value = proportions_ztest(count, nobs, alternative='larger')
except:
p_value = 1.0

# 数值型事件的比较
value_comparison = None
control_values = control_events['event_value'].dropna()
treatment_values = treatment_events['event_value'].dropna()

if not control_values.empty and not treatment_values.empty:
# t检验
t_stat, t_p_value = stats.ttest_ind(
control_values.astype(float),
treatment_values.astype(float),
equal_var=False
)

# 计算效应大小
control_mean = control_values.mean()
treatment_mean = treatment_values.mean()
control_std = control_values.std()
treatment_std = treatment_values.std()

pooled_std = np.sqrt((control_std**2 + treatment_std**2) / 2)
cohens_d = (treatment_mean – control_mean) / pooled_std if pooled_std > 0 else 0

value_comparison = {
'control_mean': float(control_mean),
'treatment_mean': float(treatment_mean),
'mean_difference': float(treatment_mean – control_mean),
't_statistic': float(t_stat),
'p_value': float(t_p_value),
'cohens_d': float(cohens_d),
'significant': t_p_value < self.alpha
}

comparison['event_comparisons'][event_type] = {
'conversion_rates': {
'control': float(control_rate),
'treatment': float(treatment_rate)
},
'absolute_difference': float(treatment_rate – control_rate),
'relative_improvement': float(relative_improvement),
'p_value': float(p_value),
'significant': p_value < self.alpha,
'value_comparison': value_comparison
}

return comparison

def _analyze_time_series(self, df: pd.DataFrame) -> Dict[str, Any]:
"""时间序列分析"""
df = df.copy()
df['date'] = df['timestamp'].dt.date

# 按日期和分组统计
daily_stats = df.groupby(['date', 'group_name']).agg({
'user_id': 'nunique',
'event_type': 'count'
}).reset_index()

daily_stats.columns = ['date', 'group_name', 'daily_users', 'daily_events']

# 计算累积统计
daily_stats['cumulative_users'] = daily_stats.groupby('group_name')['daily_users'].cumsum()
daily_stats['cumulative_events'] = daily_stats.groupby('group_name')['daily_events'].cumsum()

# 计算每日转化率
daily_stats['daily_conversion_rate'] = daily_stats['daily_events'] / daily_stats['daily_users']
daily_stats['cumulative_conversion_rate'] = daily_stats['cumulative_events'] / daily_stats['cumulative_users']

# 转换为字典格式
time_series_data = {}
for group_name, group_df in daily_stats.groupby('group_name'):
time_series_data[group_name] = {
'dates': [d.isoformat() for d in group_df['date']],
'daily_users': group_df['daily_users'].tolist(),
'daily_events': group_df['daily_events'].tolist(),
'daily_conversion_rates': group_df['daily_conversion_rate'].tolist(),
'cumulative_conversion_rates': group_df['cumulative_conversion_rate'].tolist()
}

return {
'time_series_data': time_series_data,
'trend_analysis': self._analyze_trends(daily_stats)
}

def _analyze_trends(self, daily_stats: pd.DataFrame) -> Dict[str, Any]:
"""分析趋势"""
trends = {}

for group_name, group_df in daily_stats.groupby('group_name'):
# 计算每日转化率的趋势
dates = pd.to_datetime(group_df['date'])
conversion_rates = group_df['daily_conversion_rate'].fillna(0)

if len(conversion_rates) > 1:
# 线性趋势
x = np.arange(len(conversion_rates))
slope, intercept = np.polyfit(x, conversion_rates, 1)

# 计算趋势强度(R²)
y_pred = slope * x + intercept
ss_res = np.sum((conversion_rates – y_pred) ** 2)
ss_tot = np.sum((conversion_rates – np.mean(conversion_rates)) ** 2)
r_squared = 1 – (ss_res / ss_tot) if ss_tot > 0 else 0

trends[group_name] = {
'slope': float(slope),
'intercept': float(intercept),
'r_squared': float(r_squared),
'trend_direction': 'increasing' if slope > 0 else 'decreasing',
'trend_strength': 'strong' if abs(slope) > 0.01 else 'weak'
}
else:
trends[group_name] = {
'error': 'Insufficient data for trend analysis'
}

return trends

def _calculate_power_analysis(self, df: pd.DataFrame) -> Dict[str, Any]:
    """Estimate, per group, whether the observed sample size gives enough power.

    Uses the spread (std) of per-event-type conversion rates as a rough
    effect-size proxy and solves a two-sample t-test power equation for
    the required sample size.
    """
    from statsmodels.stats.power import TTestIndPower, NormalIndPower

    results = {}

    for group_name, group_df in df.groupby('group_name'):
        users = group_df['user_id'].nunique()

        if users < 2:
            results[group_name] = {'error': 'Insufficient users for power analysis'}
            continue

        # One conversion rate per event type observed in this group.
        rates = [
            event_df['user_id'].nunique() / users
            for _, event_df in group_df.groupby('event_type')
        ]

        if not rates:
            results[group_name] = {'error': 'No conversion data available'}
            continue

        effect_size = np.std(rates)

        # Solve for the per-arm sample size that reaches the target power.
        solver = TTestIndPower()
        required_n = solver.solve_power(
            effect_size=effect_size,
            power=self.power,
            alpha=self.alpha,
            ratio=1.0
        )

        results[group_name] = {
            'current_sample_size': int(users),
            'effect_size_variability': float(effect_size),
            'required_sample_size': float(required_n) if not np.isnan(required_n) else None,
            'sufficient_power': users >= required_n if required_n else False
        }

    return results

def _sensitivity_analysis(self, df: pd.DataFrame) -> Dict[str, Any]:
"""敏感性分析"""
sensitivity = {}

# 分析不同时间段的表现
if len(df) > 0:
df = df.copy()
df['week'] = df['timestamp'].dt.isocalendar().week

weekly_analysis = {}
for week, week_df in df.groupby('week'):
week_summary = self._generate_summary(week_df)
weekly_analysis[f'week_{week}'] = week_summary

sensitivity['weekly_analysis'] = weekly_analysis

# 分析不同用户子集
user_ids = df['user_id'].unique()
if len(user_ids) > 100:
# 随机抽样分析
np.random.seed(42)
sample_sizes = [0.5, 0.7, 0.9] # 不同采样比例

sampling_analysis = {}
for sample_size in sample_sizes:
sample_users = np.random.choice(
user_ids,
size=int(len(user_ids) * sample_size),
replace=False
)
sample_df = df[df['user_id'].isin(sample_users)]
sample_summary = self._generate_summary(sample_df)
sampling_analysis[f'sample_{int(sample_size*100)}_percent'] = sample_summary

sensitivity['sampling_analysis'] = sampling_analysis

return sensitivity

async def _generate_visualizations(
self,
df: pd.DataFrame,
experiment_name: str
) -> Dict[str, str]:
"""生成可视化图表"""
visualizations = {}

try:
# 1. 转化率对比图
fig1 = self._create_conversion_rate_chart(df)
visualizations['conversion_rate_chart'] = fig1.to_html(full_html=False)

# 2. 时间序列图
fig2 = self._create_time_series_chart(df)
visualizations['time_series_chart'] = fig2.to_html(full_html=False)

# 3. 分布对比图
fig3 = self._create_distribution_chart(df)
visualizations['distribution_chart'] = fig3.to_html(full_html=False)

# 4. 累积效果图
fig4 = self._create_cumulative_effect_chart(df)
visualizations['cumulative_effect_chart'] = fig4.to_html(full_html=False)

except Exception as e:
visualizations['error'] = str(e)

return visualizations

def _create_conversion_rate_chart(self, df: pd.DataFrame) -> go.Figure:
    """Bar chart of each group's 'conversion' event rate, in percent."""
    groups = []
    rates_pct = []

    for group_name, group_df in df.groupby('group_name'):
        total_users = group_df['user_id'].nunique()
        converters = group_df[group_df['event_type'] == 'conversion']
        converted_users = converters['user_id'].nunique() if not converters.empty else 0

        share = converted_users / total_users if total_users > 0 else 0
        groups.append(group_name)
        rates_pct.append(share * 100)  # express as a percentage

    bar = go.Bar(
        x=groups,
        y=rates_pct,
        text=[f'{rate:.2f}%' for rate in rates_pct],
        textposition='auto',
        marker_color='steelblue'
    )

    fig = go.Figure(data=[bar])
    fig.update_layout(
        title='Conversion Rates by Group',
        xaxis_title='Group',
        yaxis_title='Conversion Rate (%)',
        template='plotly_white'
    )

    return fig

def _create_time_series_chart(self, df: pd.DataFrame) -> go.Figure:
    """Line chart of each group's daily events-per-user rate, in percent."""
    data = df.copy()
    data['date'] = data['timestamp'].dt.date

    fig = go.Figure()

    for group_name, group_df in data.groupby('group_name'):
        per_day = group_df.groupby('date').agg({
            'user_id': 'nunique',
            'event_type': 'count'
        }).reset_index()

        if per_day.empty:
            continue

        # Events per unique user, scaled to percent.
        per_day['conversion_rate'] = (per_day['event_type'] / per_day['user_id']) * 100

        fig.add_trace(go.Scatter(
            x=per_day['date'],
            y=per_day['conversion_rate'],
            mode='lines+markers',
            name=group_name,
            line=dict(width=2)
        ))

    fig.update_layout(
        title='Daily Conversion Rate Trends',
        xaxis_title='Date',
        yaxis_title='Conversion Rate (%)',
        template='plotly_white',
        hovermode='x unified'
    )

    return fig

def _create_distribution_chart(self, df: pd.DataFrame) -> go.Figure:
    """Violin plot comparing value distributions of the dominant numeric event."""
    numeric_rows = df[df['event_value'].notnull()].copy()

    if numeric_rows.empty:
        # Nothing numeric to plot: return a labelled empty figure.
        empty_fig = go.Figure()
        empty_fig.update_layout(
            title='No numeric data available for distribution analysis',
            template='plotly_white'
        )
        return empty_fig

    # Focus on the most frequent numeric event type.
    top_event = numeric_rows['event_type'].value_counts().index[0]
    selected = numeric_rows[numeric_rows['event_type'] == top_event]

    fig = go.Figure()

    for group_name, group_rows in selected.groupby('group_name'):
        fig.add_trace(go.Violin(
            y=group_rows['event_value'],
            name=group_name,
            box_visible=True,
            meanline_visible=True
        ))

    fig.update_layout(
        title=f'Distribution of {top_event} by Group',
        yaxis_title='Value',
        template='plotly_white'
    )

    return fig

def _create_cumulative_effect_chart(self, df: pd.DataFrame) -> go.Figure:
    """Line chart of the running 'conversion' rate per group over time.

    NOTE(review): the running denominator sums *daily* unique users, so a
    user active on several days is counted once per active day — this is
    an approximation, not a true cumulative unique-user rate.
    """
    data = df.copy()
    data['date'] = data['timestamp'].dt.date

    rows = []

    for group_name, group_df in data.groupby('group_name'):
        ordered = group_df.sort_values('date')

        users_so_far = 0
        conversions_so_far = 0

        for day in ordered['date'].unique():
            day_rows = ordered[ordered['date'] == day]
            users_so_far += day_rows['user_id'].nunique()
            conversions_so_far += len(day_rows[day_rows['event_type'] == 'conversion'])

            rate = (conversions_so_far / users_so_far * 100
                    if users_so_far > 0 else 0)

            rows.append({
                'date': day,
                'group': group_name,
                'cumulative_rate': rate
            })

    cumulative_df = pd.DataFrame(rows)

    fig = go.Figure()

    for group_name, group_rows in cumulative_df.groupby('group'):
        fig.add_trace(go.Scatter(
            x=group_rows['date'],
            y=group_rows['cumulative_rate'],
            mode='lines',
            name=group_name,
            line=dict(width=3)
        ))

    fig.update_layout(
        title='Cumulative Conversion Rate Over Time',
        xaxis_title='Date',
        yaxis_title='Cumulative Conversion Rate (%)',
        template='plotly_white',
        hovermode='x unified'
    )

    return fig

def _generate_overall_recommendation(
self,
comparisons: Dict[str, Any]
) -> Dict[str, Any]:
"""生成总体推荐"""
if not comparisons:
return {"decision": "no_data", "reason": "No comparison data available"}

# 检查是否有显著提升的分组
best_group = None
best_improvement = 0
best_p_value = 1.0

for treatment_group, comparison in comparisons.items():
event_comparisons = comparison.get('event_comparisons', {})

for event_type, event_comp in event_comparisons.items():
if event_comp.get('significant', False):
improvement = event_comp.get('relative_improvement', 0)
p_value = event_comp.get('p_value', 1.0)

if improvement > best_improvement:
best_improvement = improvement
best_p_value = p_value
best_group = treatment_group

if best_group and best_improvement > 0:
return {
"decision": "implement",
"recommended_group": best_group,
"expected_improvement": f"{best_improvement:.2%}",
"confidence_level": f"{(1 – best_p_value):.2%}",
"next_steps": [
"Roll out to 100% of traffic",
"Monitor for any negative side effects",
"Update documentation"
]
}
else:
# 检查是否有显著变差的分组
worst_group = None
worst_deterioration = 0

for treatment_group, comparison in comparisons.items():
event_comparisons = comparison.get('event_comparisons', {})

for event_type, event_comp in event_comparisons.items():
if event_comp.get('significant', False):
improvement = event_comp.get('relative_improvement', 0)
if improvement < worst_deterioration:
worst_deterioration = improvement
worst_group = treatment_group

if worst_group and worst_deterioration < -0.05: # 如果变差超过5%
return {
"decision": "reject",
"rejected_group": worst_group,
"deterioration": f"{abs(worst_deterioration):.2%}",
"reason": f"Group {worst_group} performed significantly worse than control",
"next_steps": [
"Stop traffic to this group",
"Analyze why performance was worse",
"Consider alternative approaches"
]
}
else:
return {
"decision": "continue_testing",
"reason": "No significant difference detected",
"next_steps": [
"Increase sample size if possible",
"Extend testing duration",
"Consider adjusting experiment parameters"
]
}

# ==================== 实时监控器 ====================
class ExperimentMonitor:
    """Real-time watchdog for running A/B experiments.

    Re-analyzes each monitored experiment on an hourly cadence, raises
    alerts when sample-size or safety rules are violated, and stops an
    experiment through the A/B test manager once a clear decision
    (implement/reject) has been reached.
    """

    def __init__(self, ab_test_manager: "ABTestManager", analysis: "ExperimentAnalysis"):
        """Wire up the manager used to stop experiments and the analyzer."""
        self.ab_test_manager = ab_test_manager
        self.analysis = analysis
        self.monitoring_tasks = {}  # experiment_name -> asyncio.Task
        self.alert_rules = {}       # per-experiment rule overrides
        self.logger = logging.getLogger("ExperimentMonitor")

        # Default alert thresholds, used unless overridden per experiment.
        self.default_alert_rules = {
            'sample_size': {
                'min_users': 100,
                'warning_threshold': 50
            },
            'conversion_rate': {
                'min_difference': 0.05,  # 5% difference
                'significance_level': 0.05
            },
            'safety_metrics': {
                'max_deterioration': -0.1,  # largest tolerated relative drop
                'check_frequency': 'hourly'
            }
        }

    async def start_monitoring(self, experiment_name: str):
        """Spawn the background monitoring task for an experiment (idempotent)."""
        if experiment_name in self.monitoring_tasks:
            self.logger.warning(f"Experiment {experiment_name} is already being monitored")
            return

        task = asyncio.create_task(self._monitor_experiment(experiment_name))
        self.monitoring_tasks[experiment_name] = task

        self.logger.info(f"Started monitoring experiment: {experiment_name}")

    async def stop_monitoring(self, experiment_name: str):
        """Cancel and forget the monitoring task for an experiment."""
        if experiment_name not in self.monitoring_tasks:
            return

        task = self.monitoring_tasks[experiment_name]
        task.cancel()

        try:
            await task
        except asyncio.CancelledError:
            pass  # expected after task.cancel()

        del self.monitoring_tasks[experiment_name]
        self.logger.info(f"Stopped monitoring experiment: {experiment_name}")

    async def _monitor_experiment(self, experiment_name: str):
        """Main monitoring loop: analyze, alert, and stop when warranted."""
        from datetime import timedelta  # module header imports only `datetime`

        try:
            while True:
                # Analyze the most recent 24 hours of experiment data.
                analysis_result = await self.analysis.analyze_experiment(
                    experiment_name,
                    start_date=datetime.now() - timedelta(hours=24)
                )

                if 'error' not in analysis_result:
                    alerts = await self._check_alerts(analysis_result)
                    if alerts:
                        await self._send_alerts(experiment_name, alerts)

                    if await self._check_stop_conditions(analysis_result):
                        self.logger.info(f"Stop condition met for experiment: {experiment_name}")
                        await self.ab_test_manager.stop_experiment(experiment_name)
                        break

                await asyncio.sleep(3600)  # re-check hourly

        except asyncio.CancelledError:
            raise  # propagate so stop_monitoring can observe the cancellation
        except Exception as e:
            self.logger.error(f"Error monitoring experiment {experiment_name}: {e}")

    async def _check_alerts(self, analysis_result: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Evaluate sample-size and safety rules against one analysis result."""
        alerts = []

        # Sample-size checks per group.
        summary = analysis_result.get('summary', {})
        for group_name, group_info in summary.get('groups', {}).items():
            user_count = group_info.get('user_count', 0)

            if user_count < self.default_alert_rules['sample_size']['warning_threshold']:
                alerts.append({
                    'type': 'warning',
                    'code': 'LOW_SAMPLE_SIZE',
                    'group': group_name,
                    'message': f'Group {group_name} has only {user_count} users',
                    'severity': 'low'
                })
            elif user_count < self.default_alert_rules['sample_size']['min_users']:
                alerts.append({
                    'type': 'warning',
                    'code': 'MIN_SAMPLE_NOT_REACHED',
                    'group': group_name,
                    'message': f'Group {group_name} has not reached minimum sample size',
                    'severity': 'medium'
                })

        # Safety checks: flag deterioration beyond the allowed threshold.
        group_comparison = analysis_result.get('group_comparison', {})
        comparisons = group_comparison.get('comparisons', {})

        for treatment_group, comparison in comparisons.items():
            event_comparisons = comparison.get('event_comparisons', {})

            for event_type, event_comp in event_comparisons.items():
                relative_improvement = event_comp.get('relative_improvement', 0)

                if relative_improvement < self.default_alert_rules['safety_metrics']['max_deterioration']:
                    alerts.append({
                        'type': 'critical',
                        'code': 'SAFETY_THRESHOLD_BREACHED',
                        'group': treatment_group,
                        'event_type': event_type,
                        'message': f'Group {treatment_group} shows {relative_improvement:.2%} deterioration in {event_type}',
                        'severity': 'high'
                    })

        return alerts

    async def _send_alerts(self, experiment_name: str, alerts: List[Dict[str, Any]]):
        """Dispatch alerts, routed by severity (Slack/email/SMS hooks go here)."""
        for alert in alerts:
            alert_message = (
                f"🚨 Experiment Alert: {experiment_name}\n"
                f"Type: {alert['type']}\n"
                f"Code: {alert['code']}\n"
                f"Message: {alert['message']}\n"
                f"Severity: {alert['severity']}"
            )

            self.logger.warning(alert_message)

            if alert['severity'] == 'high':
                # Urgent: the experiment may need to be stopped immediately.
                await self._send_urgent_alert(alert_message)
            elif alert['severity'] == 'medium':
                # Medium: post to the monitoring channel.
                await self._send_monitoring_alert(alert_message)
            else:
                # Low: just log it.
                self.logger.info(f"Low severity alert: {alert_message}")

    async def _send_urgent_alert(self, message: str):
        """Send an urgent alert (delivery channel not implemented yet)."""
        pass

    async def _send_monitoring_alert(self, message: str):
        """Send a monitoring-channel alert (delivery channel not implemented yet)."""
        pass

    async def _check_stop_conditions(self, analysis_result: Dict[str, Any]) -> bool:
        """Return True when the experiment has produced a final decision."""
        group_comparison = analysis_result.get('group_comparison', {})
        recommendation = group_comparison.get('overall_recommendation', {})

        decision = recommendation.get('decision', '')

        # A clear implement/reject recommendation ends the experiment.
        if decision in ['implement', 'reject']:
            return True

        # Check whether every group reached the minimum sample size.
        summary = analysis_result.get('summary', {})
        min_users_met = True

        for group_info in summary.get('groups', {}).values():
            if group_info.get('user_count', 0) < self.default_alert_rules['sample_size']['min_users']:
                min_users_met = False
                break

        # Sample size reached but still inconclusive: further stop rules
        # (e.g. a maximum duration) could be added here.
        if min_users_met and decision == 'continue_testing':
            pass

        return False

    def set_alert_rules(self, experiment_name: str, rules: Dict[str, Any]):
        """Register per-experiment alert-rule overrides."""
        self.alert_rules[experiment_name] = rules

    def get_monitoring_status(self) -> Dict[str, Any]:
        """Snapshot of which experiments are currently being monitored."""
        return {
            'timestamp': datetime.now().isoformat(),
            'monitoring_experiments': list(self.monitoring_tasks.keys()),
            'total_monitored': len(self.monitoring_tasks)
        }

# ==================== 自动化报告系统 ====================
class AutomatedReporting:
    """Generates, formats, and schedules markdown experiment reports."""

    def __init__(self, analysis: "ExperimentAnalysis"):
        """Store the analyzer and load the built-in report templates."""
        self.analysis = analysis
        self.report_templates = self._load_report_templates()
        self.scheduled_reports = {}  # schedule_id -> schedule record

    def _load_report_templates(self) -> Dict[str, str]:
        """Return the named markdown templates consumed by generate_report."""
        templates = {
            'executive': """
# Experiment Report: {experiment_name}

## Executive Summary
{executive_summary}

## Key Findings
{key_findings}

## Recommendations
{recommendations}

## Next Steps
{next_steps}
""",

            'technical': """
# Technical Analysis Report: {experiment_name}

## Methodology
{methodology}

## Statistical Analysis
{statistical_analysis}

## Results
{results}

## Appendix: Detailed Metrics
{detailed_metrics}
""",

            'dashboard': """
# Experiment Dashboard: {experiment_name}

## Overview
{overview}

## Real-time Metrics
{realtime_metrics}

## Performance Charts
{performance_charts}

## Alerts
{alerts}
"""
        }

        return templates

    async def generate_report(
        self,
        experiment_name: str,
        report_type: str = 'executive',
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """Analyze an experiment and render it through the chosen template.

        Args:
            experiment_name: Experiment to report on.
            report_type: One of 'executive', 'technical', 'dashboard'.
            start_date / end_date: Optional analysis window bounds.

        Returns:
            Dict with the rendered 'content' plus metadata and the raw
            analysis data; analyzer errors are returned unchanged.

        Raises:
            ValueError: if `report_type` has no registered template.
        """
        if report_type not in self.report_templates:
            raise ValueError(f"Unknown report type: {report_type}")

        analysis_result = await self.analysis.analyze_experiment(
            experiment_name, start_date, end_date
        )

        if 'error' in analysis_result:
            return analysis_result  # propagate analyzer errors unchanged

        formatted_data = await self._format_report_data(
            analysis_result, report_type
        )

        report_content = self.report_templates[report_type].format(**formatted_data)

        # Embed any pre-rendered chart HTML at the end of the report.
        if 'visualizations' in analysis_result:
            report_content += "\n\n## Visualizations\n"
            for viz_name, viz_html in analysis_result['visualizations'].items():
                if viz_name != 'error':
                    report_content += f"\n### {viz_name.replace('_', ' ').title()}\n"
                    report_content += f"{viz_html}\n"

        return {
            'experiment_name': experiment_name,
            'report_type': report_type,
            'generated_at': datetime.now().isoformat(),
            'time_period': {
                'start': start_date.isoformat() if start_date else analysis_result['time_period']['start'],
                'end': end_date.isoformat() if end_date else analysis_result['time_period']['end']
            },
            'content': report_content,
            'analysis_data': analysis_result
        }

    async def _format_report_data(
        self,
        analysis_result: Dict[str, Any],
        report_type: str
    ) -> Dict[str, str]:
        """Build the placeholder values consumed by every template."""
        recommendations = (
            analysis_result.get('group_comparison', {})
            .get('overall_recommendation', {})
            .get('next_steps', ['No recommendations'])
        )
        # `next_steps` is a list; render one item per line instead of
        # letting str.format print the raw list repr.
        if isinstance(recommendations, list):
            recommendations = "\n".join(recommendations)

        return {
            'experiment_name': analysis_result['experiment_name'],
            'executive_summary': self._generate_executive_summary(analysis_result),
            'key_findings': self._generate_key_findings(analysis_result),
            'recommendations': recommendations,
            'next_steps': self._generate_next_steps(analysis_result),
            'methodology': self._format_methodology(analysis_result),
            'statistical_analysis': self._format_statistical_analysis(analysis_result),
            'results': self._format_results(analysis_result),
            'detailed_metrics': self._format_detailed_metrics(analysis_result),
            'overview': self._format_overview(analysis_result),
            'realtime_metrics': self._format_realtime_metrics(analysis_result),
            'performance_charts': "Performance charts will be embedded here",
            'alerts': "No active alerts"
        }

    def _generate_executive_summary(self, analysis_result: Dict[str, Any]) -> str:
        """Render the executive summary paragraph."""
        summary = analysis_result.get('summary', {})
        comparison = analysis_result.get('group_comparison', {})

        total_users = summary.get('total_users', 0)
        recommendation = comparison.get('overall_recommendation', {})

        exec_summary = f"""
This experiment involved {total_users} users over the analysis period.

Key outcome: {recommendation.get('decision', 'No decision reached')}.

Primary metric performance varied across groups, with treatment groups showing
{'improvement' if recommendation.get('decision') == 'implement' else 'no significant change'}
compared to the control group.
"""

        return exec_summary

    def _generate_key_findings(self, analysis_result: Dict[str, Any]) -> str:
        """List every statistically significant per-event result as a bullet."""
        comparison = analysis_result.get('group_comparison', {})
        comparisons = comparison.get('comparisons', {})

        findings = []

        for treatment_group, comp_data in comparisons.items():
            event_comparisons = comp_data.get('event_comparisons', {})

            for event_type, event_comp in event_comparisons.items():
                if event_comp.get('significant', False):
                    improvement = event_comp.get('relative_improvement', 0)
                    p_value = event_comp.get('p_value', 1.0)

                    findings.append(
                        f"• Group {treatment_group} showed {improvement:.2%} improvement "
                        f"in {event_type} (p={p_value:.4f})"
                    )

        if not findings:
            findings.append("• No statistically significant differences were found between groups")

        return "\n".join(findings)

    def _generate_next_steps(self, analysis_result: Dict[str, Any]) -> str:
        """Pick a next-steps checklist matching the overall decision."""
        recommendation = analysis_result.get('group_comparison', {}).get('overall_recommendation', {})
        decision = recommendation.get('decision', '')

        if decision == 'implement':
            return """
1. Roll out the winning variant to 100% of traffic
2. Monitor key metrics for any negative impact
3. Document the change and update relevant systems
4. Plan follow-up experiments for further optimization
"""
        elif decision == 'reject':
            return """
1. Stop traffic to the underperforming variant
2. Analyze reasons for poor performance
3. Consider alternative approaches
4. Document learnings for future experiments
"""
        else:
            return """
1. Continue running the experiment
2. Increase sample size if possible
3. Consider adjusting experiment parameters
4. Set up additional monitoring
"""

    def _format_methodology(self, analysis_result: Dict[str, Any]) -> str:
        """Render the (static) methodology section."""
        methodology = """
## Experimental Design
- Random assignment of users to treatment groups
- Control group vs. treatment group(s) comparison
- Minimum sample size: 100 users per group

## Statistical Methods
- Conversion rates compared using proportion z-tests
- Continuous metrics compared using t-tests
- Significance level: α = 0.05
- Confidence intervals: 95%

## Data Collection
- Real-time event tracking
- User-level attribution
- Time-series analysis
"""

        return methodology

    def _format_statistical_analysis(self, analysis_result: Dict[str, Any]) -> str:
        """Render power-analysis and significance-testing sections."""
        power_analysis = analysis_result.get('power_analysis', {})
        comparison = analysis_result.get('group_comparison', {})

        stats_text = "## Statistical Power Analysis\n\n"

        for group_name, power_info in power_analysis.items():
            if 'error' not in power_info:
                stats_text += (
                    f"**{group_name}**:\n"
                    f"- Current sample: {power_info.get('current_sample_size', 0)}\n"
                    f"- Required sample: {power_info.get('required_sample_size', 'N/A')}\n"
                    f"- Sufficient power: {power_info.get('sufficient_power', False)}\n\n"
                )

        stats_text += "## Significance Testing Results\n\n"

        comparisons = comparison.get('comparisons', {})
        for treatment_group, comp_data in comparisons.items():
            stats_text += f"### {treatment_group} vs Control\n"

            event_comparisons = comp_data.get('event_comparisons', {})
            for event_type, event_comp in event_comparisons.items():
                stats_text += (
                    f"**{event_type}**: "
                    f"p-value = {event_comp.get('p_value', 1.0):.4f}, "
                    f"Significant: {event_comp.get('significant', False)}\n"
                )

        return stats_text

    def _format_results(self, analysis_result: Dict[str, Any]) -> str:
        """Render overall totals and per-group event breakdowns."""
        summary = analysis_result.get('summary', {})

        results_text = "## Overall Results\n\n"
        results_text += f"Total Users: {summary.get('total_users', 0)}\n"
        results_text += f"Total Events: {summary.get('total_events', 0)}\n"
        results_text += f"Events per User: {summary.get('events_per_user', 0):.2f}\n\n"

        results_text += "## Group Performance\n\n"

        for group_name, group_info in summary.get('groups', {}).items():
            results_text += f"### {group_name}\n"
            results_text += f"- Users: {group_info.get('user_count', 0)}\n"
            results_text += f"- Events: {group_info.get('event_count', 0)}\n"

            # Per-event conversion shares within the group.
            event_dist = group_info.get('event_distribution', {})
            user_count = group_info.get('user_count', 1)

            for event_type, count in event_dist.items():
                rate = count / user_count if user_count > 0 else 0
                results_text += f"- {event_type}: {count} events ({rate:.2%})\n"

            results_text += "\n"

        return results_text

    def _format_detailed_metrics(self, analysis_result: Dict[str, Any]) -> str:
        """Render per-group descriptive statistics of numeric events."""
        summary = analysis_result.get('summary', {})

        metrics_text = "## Detailed Metrics by Group\n\n"

        for group_name, group_info in summary.get('groups', {}).items():
            metrics_text += f"### {group_name}\n"

            event_stats = group_info.get('event_statistics', {})
            if event_stats:
                for event_type, stats in event_stats.items():
                    metrics_text += (
                        f"**{event_type}**:\n"
                        f"- Count: {stats.get('count', 0)}\n"
                        f"- Mean: {stats.get('mean', 0):.2f}\n"
                        f"- Std: {stats.get('std', 0):.2f}\n"
                        f"- Min: {stats.get('min', 0):.2f}\n"
                        f"- Max: {stats.get('max', 0):.2f}\n"
                        f"- Median: {stats.get('median', 0):.2f}\n\n"
                    )
            else:
                metrics_text += "No numeric event data available\n\n"

        return metrics_text

    def _format_overview(self, analysis_result: Dict[str, Any]) -> str:
        """Render the dashboard overview section.

        NOTE(review): allocation/sample-rate figures here are static
        template text, not derived from the analysis result.
        """
        start = analysis_result.get('time_period', {}).get('start', 'N/A')
        try:
            duration_days = (datetime.now() - datetime.fromisoformat(start)).days
        except (TypeError, ValueError):
            duration_days = 'N/A'  # start missing or not ISO-formatted

        return """
## Experiment Status
- Status: Running
- Start Date: {start_date}
- Duration: {duration_days} days
- Sample Rate: 10%

## Key Metrics
- Primary Metric: Conversion Rate
- Guardrail Metrics: Revenue, User Satisfaction
- Statistical Power: 80%

## Current Allocation
- Control: 50%
- Treatment A: 25%
- Treatment B: 25%
""".format(start_date=start, duration_days=duration_days)

    def _format_realtime_metrics(self, analysis_result: Dict[str, Any]) -> str:
        """Render the most recent data point of each group's time series."""
        time_series = analysis_result.get('time_series_analysis', {}).get('time_series_data', {})

        if not time_series:
            return "No time series data available"

        latest_data = {}
        for group_name, group_data in time_series.items():
            if group_data['dates'] and group_data['daily_conversion_rates']:
                latest_idx = -1  # last recorded day
                latest_data[group_name] = {
                    'date': group_data['dates'][latest_idx],
                    'conversion_rate': group_data['daily_conversion_rates'][latest_idx],
                    'users': group_data['daily_users'][latest_idx]
                }

        metrics_text = "## Latest Metrics (Last 24 Hours)\n\n"

        for group_name, data in latest_data.items():
            metrics_text += (
                f"**{group_name}**:\n"
                f"- Date: {data['date']}\n"
                f"- Conversion Rate: {data['conversion_rate']:.2%}\n"
                f"- Users: {data['users']}\n\n"
            )

        return metrics_text

    def schedule_report(
        self,
        experiment_name: str,
        frequency: str,
        recipients: List[str]
    ):
        """Register a recurring report and return its schedule id."""
        schedule_id = f"{experiment_name}_{frequency}_{datetime.now().timestamp()}"

        self.scheduled_reports[schedule_id] = {
            'experiment_name': experiment_name,
            'frequency': frequency,
            'recipients': recipients,
            'last_sent': None,
            'next_scheduled': self._calculate_next_run(frequency)
        }

        return schedule_id

    def _calculate_next_run(self, frequency: str) -> datetime:
        """Compute the next report time; unknown frequencies default to daily."""
        from datetime import timedelta  # module header imports only `datetime`
        import calendar

        now = datetime.now()

        if frequency == 'weekly':
            return now + timedelta(weeks=1)
        if frequency == 'monthly':
            # Same day next month, clamped to that month's length
            # (e.g. Jan 31 -> Feb 28/29 instead of raising ValueError).
            if now.month == 12:
                year, month = now.year + 1, 1
            else:
                year, month = now.year, now.month + 1
            day = min(now.day, calendar.monthrange(year, month)[1])
            return datetime(year, month, day)
        # 'daily' and any unrecognized frequency.
        return now + timedelta(days=1)

    async def send_scheduled_reports(self):
        """Generate and deliver every schedule whose due time has passed."""
        now = datetime.now()

        for schedule_id, schedule in self.scheduled_reports.items():
            if now >= schedule['next_scheduled']:
                report = await self.generate_report(
                    schedule['experiment_name'],
                    report_type='executive'
                )

                await self._send_report(report, schedule['recipients'])

                # Advance the schedule to its next slot.
                schedule['last_sent'] = now
                schedule['next_scheduled'] = self._calculate_next_run(schedule['frequency'])

    async def _send_report(self, report: Dict[str, Any], recipients: List[str]):
        """Deliver a report (email/Slack delivery not implemented yet)."""
        print(f"Would send report to: {recipients}")
        print(f"Report content length: {len(str(report))} characters")

# ==================== Main entry point ====================
async def main():
    """Demo driver for the A/B-testing stack.

    Wires together the manager/analysis/monitor/reporting components,
    simulates 1000 users' worth of traffic, runs the analysis, prints an
    executive report and briefly exercises the monitor.

    NOTE(review): relies on create_sample_config, ABTestManager,
    ExperimentAnalysis, ExperimentMonitor, AutomatedReporting and the
    `random` module being defined/imported elsewhere in this file — confirm.
    NOTE(review): the doubled backslashes in string literals below ("\\n")
    look like copy/paste escaping of "\n" — verify against the original.
    """
    # Configure logging for the demo run.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s – %(name)s – %(levelname)s – %(message)s'
    )

    logger = logging.getLogger(__name__)

    try:
        # 1. Create the A/B test manager.
        config = create_sample_config()
        ab_test_manager = ABTestManager(config)

        # 2. Create the experiment analyzer.
        experiment_analysis = ExperimentAnalysis()

        # 3. Create the experiment monitor.
        experiment_monitor = ExperimentMonitor(ab_test_manager, experiment_analysis)

        # 4. Create the automated reporting system.
        reporting = AutomatedReporting(experiment_analysis)

        # 5. Simulate some traffic.
        logger.info("Simulating experiment data…")
        for i in range(1000):
            user_id = f"user_{i}"

            # Assign this user to an experiment group.
            assignment = ab_test_manager.assign_experiment(
                user_id=user_id,
                experiment_name='recommendation_model_v2'
            )

            if assignment:
                # Record the assignment event.
                experiment_analysis.record_event({
                    'experiment_name': 'recommendation_model_v2',
                    'user_id': user_id,
                    'group_name': assignment['group_name'],
                    'event_type': 'assignment',
                    'timestamp': datetime.now(),
                    'metadata': {'assignment_id': assignment['assignment_id']}
                })

                # Simulate user behaviour (control and treatments perform differently).
                if assignment['group_name'] == 'control':
                    click_prob = 0.08  # control group: 8% click rate
                    purchase_prob = 0.15  # 15% purchase rate after a click
                elif assignment['group_name'] == 'treatment_v2':
                    click_prob = 0.10  # treatment v2: 10% click rate
                    purchase_prob = 0.18  # 18% purchase rate after a click
                else:  # treatment_v3
                    click_prob = 0.12  # treatment v3: 12% click rate
                    purchase_prob = 0.20  # 20% purchase rate after a click

                # Simulate a click event.
                if random.random() < click_prob:
                    experiment_analysis.record_event({
                        'experiment_name': 'recommendation_model_v2',
                        'user_id': user_id,
                        'group_name': assignment['group_name'],
                        'event_type': 'click',
                        'timestamp': datetime.now(),
                        'metadata': {'source': 'recommendation'}
                    })

                    # Simulate a purchase event (conditional on the click,
                    # per the "purchase rate after a click" probabilities).
                    if random.random() < purchase_prob:
                        purchase_amount = random.uniform(10, 100)
                        experiment_analysis.record_event({
                            'experiment_name': 'recommendation_model_v2',
                            'user_id': user_id,
                            'group_name': assignment['group_name'],
                            'event_type': 'purchase',
                            'event_value': purchase_amount,
                            'timestamp': datetime.now(),
                            'metadata': {
                                'amount': purchase_amount,
                                'items': random.randint(1, 5)
                            }
                        })

        # 6. Analyze the experiment data.
        logger.info("Analyzing experiment data…")
        analysis_result = await experiment_analysis.analyze_experiment(
            'recommendation_model_v2'
        )

        # 7. Generate a report.
        logger.info("Generating report…")
        report = await reporting.generate_report(
            'recommendation_model_v2',
            report_type='executive'
        )

        # 8. Print the results (content truncated to 2000 characters).
        print("\\n" + "="*80)
        print("EXPERIMENT ANALYSIS REPORT")
        print("="*80)
        print(f"Experiment: {report['experiment_name']}")
        print(f"Report Type: {report['report_type']}")
        print(f"Generated: {report['generated_at']}")
        print("\\n" + "="*80)
        print("CONTENT")
        print("="*80)
        print(report['content'][:2000] + "…" if len(report['content']) > 2000 else report['content'])

        # 9. Start monitoring.
        logger.info("Starting experiment monitoring…")
        await experiment_monitor.start_monitoring('recommendation_model_v2')

        # Let the monitor run for a few seconds.
        await asyncio.sleep(5)

        # 10. Check monitoring status.
        status = experiment_monitor.get_monitoring_status()
        print(f"\\nMonitoring Status: {status}")

    except Exception as e:
        logger.error(f"Error in main: {e}")
        raise

# Script entry point: run the demo pipeline on the asyncio event loop.
if __name__ == "__main__":
    asyncio.run(main())

三、生产环境部署与优化

3.1 容器化部署配置

以下是容器化部署的 docker-compose 配置(yaml):

# docker-compose.yaml
# Local/staging stack for the ML serving platform: model API, feature service,
# A/B-test service, PostgreSQL, Redis, monitoring (Prometheus + Grafana),
# an Airflow scheduler and an ELK logging stack.
version: '3.8'

services:
  # Model serving API
  ml-model-api:
    build:
      context: .
      dockerfile: Dockerfile.api
    ports:
      - "8000:8000"
    environment:
      - ENVIRONMENT=production
      - LOG_LEVEL=INFO
      - MODEL_CONFIG_PATH=/app/config/models.json
      - AB_TEST_CONFIG_PATH=/app/config/ab_test.json
      - DATABASE_URL=postgresql://user:password@db:5432/ml_models
    volumes:
      - ./models:/app/models
      - ./config:/app/config
      - ./logs:/app/logs
    depends_on:
      - db
      - redis
      - prometheus
    networks:
      - ml-network
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
      interval: 30s
      timeout: 10s
      retries: 3
    # NOTE: `deploy` is only honoured under Docker Swarm (or `--compatibility`).
    deploy:
      replicas: 3
      resources:
        limits:
          cpus: '2'
          memory: 4G
        reservations:
          cpus: '0.5'
          memory: 1G

  # Feature engineering service
  feature-service:
    build:
      context: .
      dockerfile: Dockerfile.feature
    environment:
      - REDIS_HOST=redis
      - FEATURE_STORE_TYPE=redis
      - LOG_LEVEL=INFO
    volumes:
      - ./features:/app/features
    depends_on:
      - redis
    networks:
      - ml-network
    deploy:
      replicas: 2

  # A/B testing service
  ab-test-service:
    build:
      context: .
      dockerfile: Dockerfile.abtest
    ports:
      - "8001:8001"
    environment:
      - DATABASE_URL=postgresql://user:password@db:5432/ab_test
      - REDIS_HOST=redis
      - CACHE_TTL=3600
    depends_on:
      - db
      - redis
    networks:
      - ml-network
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8001/health"]

  # Database (PostgreSQL)
  db:
    image: postgres:14
    environment:
      - POSTGRES_USER=user
      - POSTGRES_PASSWORD=password
      - POSTGRES_DB=ml_models
    volumes:
      - postgres_data:/var/lib/postgresql/data
      - ./init.sql:/docker-entrypoint-initdb.d/init.sql
    networks:
      - ml-network
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U user"]
      interval: 10s
      timeout: 5s
      retries: 5

  # Redis cache
  redis:
    image: redis:7-alpine
    command: redis-server --appendonly yes
    volumes:
      - redis_data:/data
    networks:
      - ml-network
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 10s
      timeout: 5s
      retries: 3

  # Monitoring (Prometheus + Grafana)
  prometheus:
    image: prom/prometheus
    volumes:
      - ./prometheus.yml:/etc/prometheus/prometheus.yml
      - prometheus_data:/prometheus
    ports:
      - "9090:9090"
    networks:
      - ml-network

  grafana:
    image: grafana/grafana
    environment:
      - GF_SECURITY_ADMIN_PASSWORD=admin
    volumes:
      - grafana_data:/var/lib/grafana
      - ./grafana/dashboards:/etc/grafana/provisioning/dashboards
    ports:
      - "3000:3000"
    networks:
      - ml-network
    depends_on:
      - prometheus

  # Model training pipeline (Airflow)
  airflow-scheduler:
    image: apache/airflow:2.5.0
    environment:
      - AIRFLOW__CORE__EXECUTOR=CeleryExecutor
      - AIRFLOW__DATABASE__SQL_ALCHEMY_CONN=postgresql+psycopg2://airflow:airflow@airflow-db:5432/airflow
      - AIRFLOW__CELERY__RESULT_BACKEND=db+postgresql://airflow:airflow@airflow-db:5432/airflow
      - AIRFLOW__CELERY__BROKER_URL=redis://:@redis:6379/0
    volumes:
      - ./airflow/dags:/opt/airflow/dags
      - ./airflow/logs:/opt/airflow/logs
    # NOTE(review): `airflow-db` is referenced here but never defined in this
    # file — add that service (or point at `db`) before deploying.
    depends_on:
      - airflow-db
      - redis
    networks:
      - ml-network

  # Log collection (ELK stack)
  elasticsearch:
    image: elasticsearch:8.6.0
    environment:
      - discovery.type=single-node
      - xpack.security.enabled=false
    volumes:
      - elasticsearch_data:/usr/share/elasticsearch/data
    ports:
      - "9200:9200"
    networks:
      - ml-network

  logstash:
    image: logstash:8.6.0
    volumes:
      - ./logstash/logstash.conf:/usr/share/logstash/pipeline/logstash.conf
    depends_on:
      - elasticsearch
    networks:
      - ml-network

  kibana:
    image: kibana:8.6.0
    environment:
      - ELASTICSEARCH_HOSTS=http://elasticsearch:9200
    ports:
      - "5601:5601"
    depends_on:
      - elasticsearch
    networks:
      - ml-network

networks:
  ml-network:
    driver: bridge

volumes:
  postgres_data:
  redis_data:
  prometheus_data:
  grafana_data:
  elasticsearch_data:

3.2 Kubernetes部署配置

以下是 Kubernetes 部署配置(yaml):

# k8s-deployment.yaml
# Production manifests for the model-serving API: Deployment, HPA, Service,
# Ingress, ConfigMap, PVC and RBAC. Documents are separated with `---` so the
# file is valid as a single multi-document manifest for `kubectl apply -f`.
apiVersion: apps/v1
kind: Deployment
metadata:
  name: ml-model-api
  namespace: ml-production
  labels:
    app: ml-model-api
    component: model-serving
spec:
  replicas: 3
  selector:
    matchLabels:
      app: ml-model-api
  template:
    metadata:
      labels:
        app: ml-model-api
        component: model-serving
      annotations:
        prometheus.io/scrape: "true"
        prometheus.io/port: "8000"
        prometheus.io/path: "/metrics"
    spec:
      serviceAccountName: ml-service-account
      containers:
        - name: model-api
          image: registry.example.com/ml-model-api:v1.2.0
          imagePullPolicy: IfNotPresent
          ports:
            - containerPort: 8000
              name: http
          env:
            - name: ENVIRONMENT
              value: "production"
            - name: LOG_LEVEL
              value: "INFO"
            - name: MODEL_CONFIG_PATH
              value: "/app/config/models.json"
            - name: DATABASE_URL
              valueFrom:
                secretKeyRef:
                  name: ml-secrets
                  key: database-url
            - name: REDIS_HOST
              value: "redis-master.redis.svc.cluster.local"
            - name: PROMETHEUS_PUSH_GATEWAY
              value: "prometheus-pushgateway.monitoring.svc.cluster.local:9091"
          resources:
            requests:
              memory: "1Gi"
              cpu: "500m"
            limits:
              memory: "4Gi"
              cpu: "2000m"
          volumeMounts:
            - name: model-storage
              mountPath: /app/models
              readOnly: true
            - name: config-volume
              mountPath: /app/config
            - name: logs-volume
              mountPath: /app/logs
          livenessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 30
            periodSeconds: 10
            timeoutSeconds: 5
            failureThreshold: 3
          readinessProbe:
            httpGet:
              path: /health
              port: 8000
            initialDelaySeconds: 5
            periodSeconds: 5
            timeoutSeconds: 3
          startupProbe:
            httpGet:
              path: /health
              port: 8000
            failureThreshold: 30
            periodSeconds: 10
      volumes:
        - name: model-storage
          persistentVolumeClaim:
            claimName: model-pvc
        - name: config-volume
          configMap:
            name: ml-config
        - name: logs-volume
          emptyDir: {}
      affinity:
        # Spread replicas across nodes where possible.
        podAntiAffinity:
          preferredDuringSchedulingIgnoredDuringExecution:
            - weight: 100
              podAffinityTerm:
                labelSelector:
                  matchExpressions:
                    - key: app
                      operator: In
                      values:
                        - ml-model-api
                topologyKey: kubernetes.io/hostname
      nodeSelector:
        node-type: model-serving
      tolerations:
        - key: "model-serving"
          operator: "Equal"
          value: "true"
          effect: "NoSchedule"
---
# Horizontal Pod Autoscaler
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: ml-model-api-hpa
  namespace: ml-production
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: ml-model-api
  minReplicas: 3
  maxReplicas: 10
  metrics:
    - type: Resource
      resource:
        name: cpu
        target:
          type: Utilization
          averageUtilization: 70
    - type: Resource
      resource:
        name: memory
        target:
          type: Utilization
          averageUtilization: 80
    - type: Pods
      pods:
        metric:
          name: model_requests_per_second
        target:
          type: AverageValue
          averageValue: 1000
  behavior:
    scaleDown:
      stabilizationWindowSeconds: 300
      policies:
        - type: Percent
          value: 10
          periodSeconds: 60
    scaleUp:
      stabilizationWindowSeconds: 60
      policies:
        - type: Percent
          value: 100
          periodSeconds: 60
---
# Service
apiVersion: v1
kind: Service
metadata:
  name: ml-model-api
  namespace: ml-production
  annotations:
    prometheus.io/scrape: "true"
    prometheus.io/port: "8000"
spec:
  selector:
    app: ml-model-api
  ports:
    - name: http
      port: 8000
      targetPort: 8000
  type: ClusterIP
---
# Ingress
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: ml-model-api-ingress
  namespace: ml-production
  annotations:
    nginx.ingress.kubernetes.io/rewrite-target: /
    nginx.ingress.kubernetes.io/ssl-redirect: "true"
    nginx.ingress.kubernetes.io/proxy-body-size: "50m"
    nginx.ingress.kubernetes.io/proxy-read-timeout: "300"
    nginx.ingress.kubernetes.io/proxy-send-timeout: "300"
spec:
  ingressClassName: nginx
  tls:
    - hosts:
        - ml-api.example.com
      secretName: ml-api-tls
  rules:
    - host: ml-api.example.com
      http:
        paths:
          - path: /
            pathType: Prefix
            backend:
              service:
                name: ml-model-api
                port:
                  number: 8000
          - path: /metrics
            pathType: Prefix
            backend:
              service:
                name: ml-model-api
                port:
                  number: 8000
---
# ConfigMap
apiVersion: v1
kind: ConfigMap
metadata:
  name: ml-config
  namespace: ml-production
data:
  models.json: |
    {
      "models": [
        {
          "name": "recommendation",
          "version": "v1",
          "framework": "tensorflow",
          "model_path": "/app/models/recommendation/v1/model.h5",
          "metadata_path": "/app/models/recommendation/v1/metadata.json",
          "min_memory_mb": 512,
          "max_batch_size": 100,
          "feature_pipeline": "recommendation_features"
        },
        {
          "name": "fraud_detection",
          "version": "v2",
          "framework": "pytorch",
          "model_path": "/app/models/fraud/v2/model.pt",
          "metadata_path": "/app/models/fraud/v2/metadata.json",
          "min_memory_mb": 1024,
          "max_batch_size": 50
        }
      ],
      "feature_pipelines": {
        "recommendation_features": {
          "pipeline": "user_item_features",
          "version": "v1"
        }
      },
      "performance": {
        "default_timeout_ms": 5000,
        "max_concurrent_requests": 100,
        "circuit_breaker_threshold": 0.5
      }
    }
  ab_test.json: |
    {
      "experiments": [
        {
          "name": "recommendation_model_v2",
          "description": "测试新版推荐模型的效果",
          "status": "running",
          "groups": [
            {
              "name": "control",
              "weight": 0.5,
              "model_name": "recommendation",
              "model_version": "v1"
            },
            {
              "name": "treatment_v2",
              "weight": 0.3,
              "model_name": "recommendation",
              "model_version": "v2"
            },
            {
              "name": "treatment_v3",
              "weight": 0.2,
              "model_name": "recommendation",
              "model_version": "v3"
            }
          ],
          "assignment_algorithm": "hash_based",
          "sample_rate": 0.1,
          "metrics": ["click_rate", "purchase_rate", "revenue"],
          "alert_rules": {
            "sample_size": {
              "min_users": 1000,
              "warning_threshold": 100
            },
            "safety": {
              "max_deterioration": -0.1
            }
          }
        }
      ]
    }
---
# PersistentVolumeClaim for models
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: model-pvc
  namespace: ml-production
spec:
  accessModes:
    - ReadOnlyMany
  resources:
    requests:
      storage: 100Gi
  storageClassName: ssd
---
# ServiceAccount and RBAC
apiVersion: v1
kind: ServiceAccount
metadata:
  name: ml-service-account
  namespace: ml-production
---
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
  name: ml-service-role
  namespace: ml-production
rules:
  - apiGroups: [""]
    resources: ["configmaps"]
    verbs: ["get", "list", "watch"]
  - apiGroups: [""]
    resources: ["pods", "services"]
    verbs: ["get", "list"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
  name: ml-service-role-binding
  namespace: ml-production
roleRef:
  apiGroup: rbac.authorization.k8s.io
  kind: Role
  name: ml-service-role
subjects:
  - kind: ServiceAccount
    name: ml-service-account
    namespace: ml-production

篇幅限制下面就只能给大家展示小册部分内容了。整理了一份核心面试笔记,包括了:Java面试、Spring、JVM、MyBatis、Redis、MySQL、并发编程、微服务、Linux、Springboot、SpringCloud、MQ、Kafka

需要全套面试笔记及答案 【点击此处即可/免费获取】​​​

​​​​​​​​​​​​

3.3 性能优化与监控

以下是性能优化与监控系统的实现代码(python):

"""
性能优化与监控系统
"""
import asyncio
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Callable
import threading
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import psutil
import gc
from prometheus_client import start_http_server, Summary, Counter, Gauge, Histogram
import prometheus_client
from circuitbreaker import circuit
import backoff
import redis
from functools import wraps, lru_cache
import numpy as np
from contextlib import contextmanager

# ==================== Performance-monitoring decorator ====================
def monitor_performance(name: str):
    """Decorator that records Prometheus metrics around a sync or async callable.

    Increments a request counter on entry, observes duration and success on
    exit, counts errors by exception type, and samples process RSS in a
    ``finally`` block. The original used en-dash characters in the duration
    arithmetic (``time.time() – start_time``), which is a syntax error; fixed
    to a proper minus.

    NOTE(review): REQUEST_COUNTER / REQUEST_DURATION / REQUEST_SUCCESS_COUNTER /
    REQUEST_ERROR_COUNTER / MEMORY_USAGE are assumed to be module-level metrics
    defined elsewhere in this file — confirm they exist before use.
    """
    def decorator(func):
        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            start_time = time.time()
            try:
                # Record the request start.
                REQUEST_COUNTER.labels(function=name).inc()

                result = await func(*args, **kwargs)

                # Record success and latency.
                REQUEST_DURATION.labels(function=name).observe(time.time() - start_time)
                REQUEST_SUCCESS_COUNTER.labels(function=name).inc()

                return result
            except Exception as e:
                # Record the failure, tagged by exception class.
                REQUEST_ERROR_COUNTER.labels(function=name, error_type=type(e).__name__).inc()
                raise
            finally:
                # Sample resident memory after every call.
                MEMORY_USAGE.labels(function=name).set(psutil.Process().memory_info().rss)

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            start_time = time.time()
            try:
                REQUEST_COUNTER.labels(function=name).inc()

                result = func(*args, **kwargs)

                REQUEST_DURATION.labels(function=name).observe(time.time() - start_time)
                REQUEST_SUCCESS_COUNTER.labels(function=name).inc()

                return result
            except Exception as e:
                REQUEST_ERROR_COUNTER.labels(function=name, error_type=type(e).__name__).inc()
                raise
            finally:
                MEMORY_USAGE.labels(function=name).set(psutil.Process().memory_info().rss)

        # Pick the wrapper matching the wrapped callable's flavour.
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    return decorator

# ==================== Circuit breaker pattern ====================
class CircuitBreaker:
    """Three-state circuit breaker (CLOSED -> OPEN -> HALF_OPEN).

    After `failure_threshold` failures the breaker opens and rejects calls;
    once `recovery_timeout` seconds have elapsed since the last failure, a
    trial call is allowed (HALF_OPEN). A successful trial closes the breaker
    and resets the failure count. Fixes the en-dash minus in the elapsed-time
    arithmetic of the original.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: int = 60,
        expected_exceptions: tuple = (Exception,)
    ):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout  # seconds before a recovery attempt
        self.expected_exceptions = expected_exceptions  # only these count as failures

        self.failures = 0
        self.state = "CLOSED"  # CLOSED, OPEN, HALF_OPEN
        self.last_failure_time = None
        self.metrics = {
            'total_calls': 0,
            'successful_calls': 0,
            'failed_calls': 0,
            'circuit_opened': 0
        }

    @contextmanager
    def protect(self):
        """Guard the wrapped block; raises CircuitBreakerOpenError while OPEN."""
        self.metrics['total_calls'] += 1

        # Reject (or allow a trial) depending on the breaker state.
        if self.state == "OPEN":
            if self._can_attempt_recovery():
                self.state = "HALF_OPEN"
            else:
                self.metrics['circuit_opened'] += 1
                raise CircuitBreakerOpenError("Circuit breaker is OPEN")

        try:
            yield
            self._on_success()

        except self.expected_exceptions:
            self._on_failure()
            raise

    def _can_attempt_recovery(self) -> bool:
        """True once recovery_timeout has elapsed since the last failure."""
        if self.last_failure_time is None:
            return True

        elapsed = time.time() - self.last_failure_time
        return elapsed >= self.recovery_timeout

    def _on_success(self):
        """Count the success; a HALF_OPEN success closes the breaker."""
        self.metrics['successful_calls'] += 1

        if self.state == "HALF_OPEN":
            self.state = "CLOSED"
            self.failures = 0

    def _on_failure(self):
        """Count the failure and open the breaker once the threshold is hit."""
        self.metrics['failed_calls'] += 1
        self.failures += 1
        self.last_failure_time = time.time()

        if self.failures >= self.failure_threshold:
            self.state = "OPEN"

class CircuitBreakerOpenError(Exception):
    """Raised when a call is rejected because the circuit breaker is OPEN."""

# ==================== Cache optimization ====================
class ModelPredictionCache:
    """Two-tier prediction cache: in-process dict backed by optional Redis.

    Local entries live at most min(ttl, 30) seconds; Redis entries live `ttl`
    seconds. Redis failures are swallowed so caching stays best-effort.
    Fix over the original: expired local entries are now evicted on access —
    previously they stayed in the dict forever, growing without bound.
    """

    def __init__(self, redis_client=None, ttl: int = 300):
        self.redis_client = redis_client
        self.ttl = ttl  # Redis TTL in seconds
        self.local_cache = {}      # cache_key -> cached result
        self.local_cache_ttl = {}  # cache_key -> absolute expiry timestamp
        self.hits = 0
        self.misses = 0

    def get_cache_key(self, model_name: str, model_version: str, features_hash: str) -> str:
        """Build the namespaced cache key."""
        return f"prediction:{model_name}:{model_version}:{features_hash}"

    def get(self, model_name: str, model_version: str, features: Dict) -> Optional[Any]:
        """Return a cached prediction, or None on a miss."""
        features_hash = self._hash_features(features)
        cache_key = self.get_cache_key(model_name, model_version, features_hash)

        # 1. Local cache.
        if cache_key in self.local_cache:
            if self.local_cache_ttl.get(cache_key, 0) > time.time():
                self.hits += 1
                return self.local_cache[cache_key]
            # Expired: evict immediately so stale entries cannot accumulate.
            self.local_cache.pop(cache_key, None)
            self.local_cache_ttl.pop(cache_key, None)

        # 2. Redis layer.
        if self.redis_client:
            try:
                cached = self.redis_client.get(cache_key)
                if cached:
                    result = json.loads(cached)

                    # Promote into the local cache (capped at 30s).
                    self.local_cache[cache_key] = result
                    self.local_cache_ttl[cache_key] = time.time() + min(self.ttl, 30)

                    self.hits += 1
                    return result
            except Exception:
                # Redis unavailable: fall through to a miss.
                pass

        self.misses += 1
        return None

    def set(self, model_name: str, model_version: str, features: Dict, result: Any):
        """Store a prediction in both tiers (Redis writes are best-effort)."""
        features_hash = self._hash_features(features)
        cache_key = self.get_cache_key(model_name, model_version, features_hash)

        # 1. Local tier.
        self.local_cache[cache_key] = result
        self.local_cache_ttl[cache_key] = time.time() + min(self.ttl, 30)

        # 2. Redis tier.
        if self.redis_client:
            try:
                self.redis_client.setex(
                    cache_key,
                    self.ttl,
                    json.dumps(result, default=str)
                )
            except Exception:
                pass  # best-effort: ignore Redis write failures

    def _hash_features(self, features: Dict) -> str:
        """Deterministic MD5 over the sorted-key JSON encoding of the features."""
        import hashlib
        features_str = json.dumps(features, sort_keys=True)
        return hashlib.md5(features_str.encode()).hexdigest()

    def get_stats(self) -> Dict[str, Any]:
        """Hit/miss counters and current local cache size."""
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0

        return {
            'hits': self.hits,
            'misses': self.misses,
            'total': total,
            'hit_rate': hit_rate,
            'local_cache_size': len(self.local_cache)
        }

    def clear(self):
        """Drop the local tier and reset counters (Redis entries are untouched)."""
        self.local_cache.clear()
        self.local_cache_ttl.clear()
        self.hits = 0
        self.misses = 0

# ==================== 批量处理优化 ====================
class BatchProcessor:
"""批量处理器"""

def __init__(self, max_batch_size: int = 100, max_wait_time: float = 0.1):
self.max_batch_size = max_batch_size
self.max_wait_time = max_wait_time

self.queue = asyncio.Queue()
self.processing = False
self.processing_task = None

# 统计信息
self.total_batches = 0
self.total_items = 0
self.average_batch_size = 0

async def add_item(self, item: Any) -> asyncio.Future:
"""添加项目到批量处理器"""
future = asyncio.Future()
await self.queue.put((item, future))

# 如果没有在处理,启动处理任务
if not self.processing:
self.start_processing()

return future

def start_processing(self):
"""启动处理"""
if self.processing:
return

self.processing = True
self.processing_task = asyncio.create_task(self._process_batches())

async def _process_batches(self):
"""处理批次"""
while self.processing:
try:
# 收集一批项目
batch = []
start_time = time.time()

# 收集直到达到最大批次大小或超时
while len(batch) < self.max_batch_size:
try:
item, future = await asyncio.wait_for(
self.queue.get(),
timeout=self.max_wait_time
)
batch.append((item, future))
except asyncio.TimeoutError:
# 超时,处理当前批次
break

if not batch:
# 队列为空,停止处理
self.processing = False
break

# 处理批次
await self._process_batch(batch)

# 更新统计
self.total_batches += 1
self.total_items += len(batch)
self.average_batch_size = self.total_items / self.total_batches

except Exception as e:
logger.error(f"Batch processing error: {e}")
# 设置所有未完成的future为异常
for _, future in batch:
if not future.done():
future.set_exception(e)

self.processing = False

async def _process_batch(self, batch: List[tuple]):
"""处理单个批次"""
# 提取项目和future
items = [item for item, _ in batch]
futures = [future for _, future in batch]

try:
# 执行批量处理
results = await self.process_items(items)

# 设置future结果
for future, result in zip(futures, results):
if not future.done():
future.set_result(result)

except Exception as e:
# 设置所有future为异常
for future in futures:
if not future.done():
future.set_exception(e)

async def process_items(self, items: List[Any]) -> List[Any]:
"""处理项目(子类重写)"""
raise NotImplementedError

def get_stats(self) -> Dict[str, Any]:
"""获取统计信息"""
return {
'queue_size': self.queue.qsize(),
'processing': self.processing,
'total_batches': self.total_batches,
'total_items': self.total_items,
'average_batch_size': self.average_batch_size,
'max_batch_size': self.max_batch_size,
'max_wait_time': self.max_wait_time
}

class ModelBatchProcessor(BatchProcessor):
    """BatchProcessor specialisation that feeds batches to a model's batch_predict."""

    def __init__(self, model, max_batch_size: int = 100, max_wait_time: float = 0.1):
        super().__init__(max_batch_size, max_wait_time)
        # The wrapped model; must expose an async batch_predict(items) method.
        self.model = model

    async def process_items(self, items: List[Dict]) -> List[Dict]:
        """Run one batched prediction; errors are logged and re-raised."""
        try:
            return await self.model.batch_predict(items)
        except Exception as e:
            logger.error(f"Batch prediction error: {e}")
            raise

# ==================== Memory management optimization ====================
class MemoryManager:
    """Watches process/system memory and triggers garbage collection.

    Forces a GC when system memory usage exceeds ``max_memory_usage`` and at
    least every ``gc_interval`` seconds regardless. Fixes the en-dash minus
    in the periodic-GC elapsed-time check of the original.

    NOTE(review): the Prometheus gauges are registered in the default registry
    inside __init__, so constructing this class twice in one process raises a
    duplicate-timeseries error — confirm a single instance is used.
    """

    def __init__(self, max_memory_usage: float = 0.8):
        self.max_memory_usage = max_memory_usage  # system usage ratio that triggers a forced GC
        self.memory_warnings = []  # bounded history of high-usage events
        self.last_gc_time = time.time()
        self.gc_interval = 300  # force a GC at least every 5 minutes

        # Prometheus gauges for dashboards/alerts.
        self.memory_usage_gauge = Gauge('memory_usage_percent', 'Memory usage percentage')
        self.gc_count_gauge = Gauge('gc_collections', 'Garbage collection count')

    def check_memory_usage(self) -> Dict[str, Any]:
        """Sample memory usage, force a GC when over threshold, return a snapshot."""
        process = psutil.Process()
        memory_info = process.memory_info()

        # System-wide memory information.
        system_memory = psutil.virtual_memory()

        # Usage ratios: process RSS vs. total RAM, and system-wide.
        process_usage = memory_info.rss / system_memory.total
        system_usage = system_memory.used / system_memory.total

        self.memory_usage_gauge.set(system_usage * 100)

        result = {
            'process_memory_mb': memory_info.rss / 1024 / 1024,
            'process_memory_percent': process_usage * 100,
            'system_memory_percent': system_usage * 100,
            'system_memory_available_mb': system_memory.available / 1024 / 1024,
            'memory_warnings': len(self.memory_warnings)
        }

        # Above threshold: record a warning and force a collection.
        if system_usage > self.max_memory_usage:
            warning = {
                'timestamp': datetime.now().isoformat(),
                'memory_usage': system_usage,
                'threshold': self.max_memory_usage,
                'message': 'High memory usage detected'
            }
            self.memory_warnings.append(warning)

            self.force_gc()

            # Keep only the most recent 100 warnings.
            if len(self.memory_warnings) > 100:
                self.memory_warnings = self.memory_warnings[-100:]

        # Periodic GC regardless of pressure.
        if time.time() - self.last_gc_time > self.gc_interval:
            self.force_gc()

        return result

    def force_gc(self):
        """Run gc.collect(), reset the periodic timer and bump the counter gauge."""
        gc.collect()
        self.last_gc_time = time.time()
        self.gc_count_gauge.inc()

    def get_memory_stats(self) -> Dict[str, Any]:
        """Detailed process/system memory snapshot plus recent warnings."""
        process = psutil.Process()

        return {
            'timestamp': datetime.now().isoformat(),
            'process': {
                'rss_mb': process.memory_info().rss / 1024 / 1024,
                'vms_mb': process.memory_info().vms / 1024 / 1024,
                'percent': process.memory_percent(),
                'threads': process.num_threads()
            },
            'system': {
                'total_mb': psutil.virtual_memory().total / 1024 / 1024,
                'available_mb': psutil.virtual_memory().available / 1024 / 1024,
                'percent': psutil.virtual_memory().percent,
                'swap_percent': psutil.swap_memory().percent if hasattr(psutil, 'swap_memory') else 0
            },
            'warnings': self.memory_warnings[-10:] if self.memory_warnings else []
        }

# ==================== Load balancing optimization ====================
class LoadBalancer:
    """Round-robin load balancer with background HTTP health checks.

    Fixes the en-dash minus in the healthy-endpoint count of the original.

    NOTE(review): __init__ starts a daemon health-check thread; `logger`
    (module level) and the `requests` package are resolved lazily inside that
    thread — confirm both are available before the first sweep (~30s in).
    """

    def __init__(self, endpoints: List[str]):
        self.endpoints = endpoints
        self.current_index = 0  # round-robin cursor (monotonically increasing)
        self.endpoint_stats = {endpoint: {'success': 0, 'failure': 0, 'latency': []} for endpoint in endpoints}
        self.lock = threading.Lock()

        # Health checking configuration/state.
        self.health_check_interval = 30  # seconds between sweeps
        self.unhealthy_endpoints = set()

        # Daemon thread so it never blocks interpreter shutdown.
        self.health_check_thread = threading.Thread(target=self._health_check_loop, daemon=True)
        self.health_check_thread.start()

    def get_endpoint(self) -> Optional[str]:
        """Return the next healthy endpoint, or None if all are down."""
        with self.lock:
            # Skip endpoints currently marked unhealthy.
            healthy_endpoints = [e for e in self.endpoints if e not in self.unhealthy_endpoints]

            if not healthy_endpoints:
                return None

            # Plain round-robin over the healthy subset.
            endpoint = healthy_endpoints[self.current_index % len(healthy_endpoints)]
            self.current_index += 1

            return endpoint

    def record_success(self, endpoint: str, latency: float):
        """Record a successful call and its latency (seconds)."""
        with self.lock:
            if endpoint in self.endpoint_stats:
                self.endpoint_stats[endpoint]['success'] += 1
                self.endpoint_stats[endpoint]['latency'].append(latency)

                # Cap the latency history at the most recent 1000 samples.
                if len(self.endpoint_stats[endpoint]['latency']) > 1000:
                    self.endpoint_stats[endpoint]['latency'] = self.endpoint_stats[endpoint]['latency'][-1000:]

    def record_failure(self, endpoint: str):
        """Record a failed call."""
        with self.lock:
            if endpoint in self.endpoint_stats:
                self.endpoint_stats[endpoint]['failure'] += 1

    def _health_check_loop(self):
        """Background loop: sweep endpoint health every interval."""
        while True:
            time.sleep(self.health_check_interval)
            self._check_endpoints_health()

    def _check_endpoints_health(self):
        """Re-classify each endpoint as healthy/unhealthy and log transitions."""
        for endpoint in self.endpoints:
            is_healthy = self._check_endpoint_health(endpoint)

            with self.lock:
                if is_healthy and endpoint in self.unhealthy_endpoints:
                    self.unhealthy_endpoints.remove(endpoint)
                    logger.info(f"Endpoint {endpoint} is now healthy")
                elif not is_healthy and endpoint not in self.unhealthy_endpoints:
                    self.unhealthy_endpoints.add(endpoint)
                    logger.warning(f"Endpoint {endpoint} is now unhealthy")

    def _check_endpoint_health(self, endpoint: str) -> bool:
        """GET {endpoint}/health with a 5s timeout; any exception counts as down."""
        try:
            # Simple HTTP health probe (third-party `requests`, imported lazily).
            import requests
            response = requests.get(f"{endpoint}/health", timeout=5)
            return response.status_code == 200
        except Exception:
            return False

    def get_stats(self) -> Dict[str, Any]:
        """Per-endpoint counters plus a healthy/total summary."""
        with self.lock:
            stats = {}

            for endpoint, endpoint_stat in self.endpoint_stats.items():
                total = endpoint_stat['success'] + endpoint_stat['failure']
                success_rate = endpoint_stat['success'] / total if total > 0 else 0

                latency_values = endpoint_stat['latency']
                avg_latency = sum(latency_values) / len(latency_values) if latency_values else 0

                stats[endpoint] = {
                    'success': endpoint_stat['success'],
                    'failure': endpoint_stat['failure'],
                    'total': total,
                    'success_rate': success_rate,
                    'avg_latency_ms': avg_latency * 1000,
                    'is_healthy': endpoint not in self.unhealthy_endpoints
                }

            return {
                'total_endpoints': len(self.endpoints),
                'healthy_endpoints': len(self.endpoints) - len(self.unhealthy_endpoints),
                'endpoint_stats': stats
            }

# ==================== Monitoring dashboard ====================
class PerformanceDashboard:
    """Prometheus-backed performance dashboard with simple callable alert rules.

    Fixes the en-dash minus in the one-hour active-alert window computation.

    NOTE(review): __init__ binds an HTTP port via start_http_server, so
    constructing two instances on the same port will fail.
    """

    def __init__(self, port: int = 9091):
        self.port = port
        self.metrics = {}      # name -> prometheus metric object
        self.alerts = []       # fired alert records (bounded to 1000)
        self.alert_rules = []  # registered rule dicts

        # Expose /metrics for Prometheus scraping.
        start_http_server(self.port)

    def add_metric(self, name: str, metric_type: str, **kwargs):
        """Register a metric of type 'counter' | 'gauge' | 'histogram' | 'summary'."""
        if metric_type == 'counter':
            self.metrics[name] = Counter(name, kwargs.get('description', ''), kwargs.get('labelnames', []))
        elif metric_type == 'gauge':
            self.metrics[name] = Gauge(name, kwargs.get('description', ''), kwargs.get('labelnames', []))
        elif metric_type == 'histogram':
            self.metrics[name] = Histogram(name, kwargs.get('description', ''), kwargs.get('labelnames', []))
        elif metric_type == 'summary':
            self.metrics[name] = Summary(name, kwargs.get('description', ''), kwargs.get('labelnames', []))

    def add_alert_rule(self, name: str, condition: Callable, action: Callable, severity: str = 'warning'):
        """Register an alert: when `condition()` is truthy, `action()` runs and the alert is recorded."""
        self.alert_rules.append({
            'name': name,
            'condition': condition,
            'action': action,
            'severity': severity,
            'last_triggered': None
        })

    async def check_alerts(self):
        """Evaluate every rule once; fire actions and record triggered alerts."""
        for rule in self.alert_rules:
            try:
                if rule['condition']():
                    # Fire the rule's action.
                    rule['action']()

                    alert = {
                        'name': rule['name'],
                        'severity': rule['severity'],
                        'timestamp': datetime.now().isoformat(),
                        'message': f"Alert {rule['name']} triggered"
                    }

                    self.alerts.append(alert)
                    rule['last_triggered'] = datetime.now()

                    # Keep only the most recent 1000 alerts.
                    if len(self.alerts) > 1000:
                        self.alerts = self.alerts[-1000:]

            except Exception as e:
                logger.error(f"Error checking alert rule {rule['name']}: {e}")

    def get_dashboard_data(self) -> Dict[str, Any]:
        """Snapshot system + process stats, recent alerts and metric values."""
        # System-wide metrics (cpu_percent blocks for 1s to take its sample).
        cpu_percent = psutil.cpu_percent(interval=1)
        memory = psutil.virtual_memory()
        disk = psutil.disk_usage('/')
        network = psutil.net_io_counters()

        # Per-process metrics.
        process = psutil.Process()
        process_memory = process.memory_info()

        dashboard_data = {
            'timestamp': datetime.now().isoformat(),
            'system': {
                'cpu_percent': cpu_percent,
                'memory_percent': memory.percent,
                'memory_available_mb': memory.available / 1024 / 1024,
                'disk_percent': disk.percent,
                'disk_free_gb': disk.free / 1024 / 1024 / 1024,
                'network_bytes_sent': network.bytes_sent,
                'network_bytes_recv': network.bytes_recv
            },
            'process': {
                'memory_rss_mb': process_memory.rss / 1024 / 1024,
                'memory_percent': process.memory_percent(),
                'cpu_percent': process.cpu_percent(),
                'threads': process.num_threads(),
                # NOTE(review): Process.connections() is deprecated in newer
                # psutil in favour of net_connections() — confirm the pinned version.
                'connections': len(process.connections())
            },
            'alerts': self.alerts[-20:] if self.alerts else [],
            # Alerts fired within the last hour.
            'active_alerts': len([a for a in self.alerts if
                                  datetime.fromisoformat(a['timestamp']) >
                                  datetime.now() - timedelta(hours=1)]),
            'metrics': self._get_metrics_snapshot()
        }

        return dashboard_data

    def _get_metrics_snapshot(self) -> Dict[str, Any]:
        """Best-effort dump of current metric values.

        NOTE(review): reaches into prometheus_client private internals
        (_metrics / _value); may break across library versions.
        """
        snapshot = {}

        for name, metric in self.metrics.items():
            if hasattr(metric, '_metrics'):
                metric_data = metric._metrics
                snapshot[name] = {}

                for labels, value in metric_data.items():
                    snapshot[name][str(labels)] = value._value if hasattr(value, '_value') else str(value)

        return snapshot

    async def start_monitoring(self):
        """Kick off the background alert-checking loop."""
        asyncio.create_task(self._monitoring_loop())

        logger.info(f"Performance dashboard started on port {self.port}")

    async def _monitoring_loop(self):
        """Run check_alerts once per minute, forever."""
        while True:
            try:
                await self.check_alerts()
                await asyncio.sleep(60)  # check once per minute
            except Exception as e:
                logger.error(f"Error in monitoring loop: {e}")
                await asyncio.sleep(60)

# ==================== Main program ====================
async def main_optimized():
    """Optimized serving entry point.

    Boots the performance dashboard, memory manager, prediction cache,
    load balancer and circuit breaker, then loops forever sampling and
    logging system status roughly once a minute.

    NOTE(review): assumes Redis on localhost:6379 is reachable, and depends
    on PerformanceDashboard / MemoryManager / ModelPredictionCache /
    LoadBalancer / CircuitBreaker defined elsewhere in this file.
    """
    # Configure logging.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s – %(name)s – %(levelname)s – %(message)s'
    )

    logger = logging.getLogger(__name__)

    try:
        logger.info("Starting optimized ML Model Serving system…")

        # 1. Performance dashboard (also starts the Prometheus HTTP server).
        dashboard = PerformanceDashboard(port=9091)

        # Register serving metrics.
        dashboard.add_metric('model_predictions_total', 'counter',
                             description='Total model predictions',
                             labelnames=['model_name', 'model_version'])

        dashboard.add_metric('prediction_latency_seconds', 'histogram',
                             description='Prediction latency in seconds',
                             labelnames=['model_name', 'model_version'])

        dashboard.add_metric('active_connections', 'gauge',
                             description='Active connections')

        # Register alert rules.
        dashboard.add_alert_rule(
            name='high_cpu_usage',
            condition=lambda: psutil.cpu_percent() > 80,
            action=lambda: logger.warning("High CPU usage detected"),
            severity='warning'
        )

        dashboard.add_alert_rule(
            name='high_memory_usage',
            condition=lambda: psutil.virtual_memory().percent > 85,
            action=lambda: logger.error("High memory usage detected"),
            severity='critical'
        )

        # 2. Memory manager.
        memory_manager = MemoryManager(max_memory_usage=0.8)

        # 3. Prediction cache backed by Redis.
        redis_client = redis.Redis(host='localhost', port=6379, db=0)
        prediction_cache = ModelPredictionCache(redis_client, ttl=300)

        # 4. Load balancer over the model-service replicas.
        endpoints = [
            "http://model-service-1:8000",
            "http://model-service-2:8000",
            "http://model-service-3:8000"
        ]
        load_balancer = LoadBalancer(endpoints)

        # 5. Circuit breaker for downstream calls.
        circuit_breaker = CircuitBreaker(
            failure_threshold=5,
            recovery_timeout=60
        )

        # 6. Start the dashboard's alert loop.
        await dashboard.start_monitoring()

        logger.info("Optimized ML Model Serving system started successfully")

        # Keep the process alive, sampling status every second.
        while True:
            # Periodic memory check (may force a GC).
            memory_info = memory_manager.check_memory_usage()

            # Cache statistics.
            cache_stats = prediction_cache.get_stats()

            # Load-balancer statistics.
            lb_stats = load_balancer.get_stats()

            # Dashboard snapshot.
            dashboard_data = dashboard.get_dashboard_data()

            # Log a status line roughly once a minute.
            if time.time() % 60 < 1:
                logger.info(f"Memory usage: {memory_info['system_memory_percent']:.1f}%")
                logger.info(f"Cache hit rate: {cache_stats['hit_rate']:.2%}")
                logger.info(f"Healthy endpoints: {lb_stats['healthy_endpoints']}/{lb_stats['total_endpoints']}")

            await asyncio.sleep(1)

    except Exception as e:
        logger.error(f"Error in optimized system: {e}")
        raise

# ==================== 性能测试 ====================
class PerformanceTest:
    """Load/performance tester for a model-serving HTTP API.

    Fires concurrent POST /predict requests at ``api_url`` and aggregates
    throughput, success-rate and latency statistics into a markdown report.
    NOTE(review): ``_make_request`` assumes ``aiohttp`` is imported at the
    top of this file — confirm, it is not visible in this chunk.
    """

    def __init__(self, api_url: str):
        # Base URL of the service under test, e.g. "http://host:8000".
        self.api_url = api_url
        # Per-request result dicts appended by _make_request().
        self.results: List[Dict[str, Any]] = []

    async def run_concurrent_test(
        self,
        num_requests: int,
        num_concurrent: int,
        payload: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Send ``num_requests`` requests with at most ``num_concurrent`` in flight.

        Returns aggregate statistics: throughput (requests/s), success rate,
        and average/P95/P99 latencies. Latencies are measured per request in
        seconds and reported in milliseconds.
        """
        start_time = time.time()

        # A semaphore gives true bounded concurrency; the previous
        # batch-and-gather approach stalled until each whole batch finished.
        semaphore = asyncio.Semaphore(max(1, num_concurrent))

        async def _bounded_request() -> None:
            async with semaphore:
                await self._make_request(payload)

        await asyncio.gather(*(_bounded_request() for _ in range(num_requests)))

        elapsed = time.time() - start_time

        # Analyze collected per-request results.
        successful = sum(1 for r in self.results if r['success'])
        failed = len(self.results) - successful

        latencies = [r['latency'] for r in self.results if r['success']]
        if latencies:
            avg_latency = sum(latencies) / len(latencies)
            p95_latency = float(np.percentile(latencies, 95))
            p99_latency = float(np.percentile(latencies, 99))
        else:
            avg_latency = p95_latency = p99_latency = 0.0

        return {
            'total_requests': num_requests,
            'concurrent_requests': num_concurrent,
            'total_time_seconds': elapsed,
            'requests_per_second': num_requests / elapsed if elapsed > 0 else 0,
            'successful_requests': successful,
            'failed_requests': failed,
            'success_rate': successful / num_requests if num_requests > 0 else 0,
            'average_latency_ms': avg_latency * 1000,
            'p95_latency_ms': p95_latency * 1000,
            'p99_latency_ms': p99_latency * 1000
        }

    async def _make_request(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Send one POST /predict request; record success flag, latency and status."""
        request_start = time.time()

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    f"{self.api_url}/predict",
                    json=payload,
                    timeout=30
                ) as response:

                    if response.status == 200:
                        result = {
                            'success': True,
                            'latency': time.time() - request_start,
                            'status_code': response.status
                        }
                    else:
                        result = {
                            'success': False,
                            'latency': time.time() - request_start,
                            'status_code': response.status,
                            'error': await response.text()
                        }

        except Exception as e:
            # Connection errors / timeouts count as failed requests,
            # not test aborts.
            result = {
                'success': False,
                'latency': time.time() - request_start,
                'error': str(e)
            }

        self.results.append(result)
        return result

    def run_load_test_scenarios(self) -> Dict[str, Any]:
        """Run the predefined load scenarios and return per-scenario aggregates.

        Bug fix: the aggregate dict returned by run_concurrent_test() is now
        captured directly; previously the code stored self.results[-1], a
        single raw request record, as the scenario result.
        """
        scenarios = [
            {'name': 'low_load', 'requests': 100, 'concurrent': 10},
            {'name': 'medium_load', 'requests': 1000, 'concurrent': 50},
            {'name': 'high_load', 'requests': 10000, 'concurrent': 100},
            {'name': 'stress_test', 'requests': 50000, 'concurrent': 200}
        ]

        all_results: Dict[str, Any] = {}

        for scenario in scenarios:
            logger.info(f"Running {scenario['name']} test…")

            # Reset per-request results before each scenario so scenarios
            # don't contaminate each other's statistics.
            self.results = []

            all_results[scenario['name']] = asyncio.run(self.run_concurrent_test(
                scenario['requests'],
                scenario['concurrent'],
                self._create_test_payload()
            ))

        return all_results

    def _create_test_payload(self) -> Dict[str, Any]:
        """Build a representative prediction request (100 features)."""
        return {
            "model_name": "recommendation",
            "model_version": "v1",
            "data": {
                "user_id": "test_user",
                "features": [0.1, 0.2, 0.3, 0.4, 0.5] * 20  # 100 features
            }
        }

    def generate_performance_report(self, test_results: Dict[str, Any]) -> str:
        """Render per-scenario aggregates as a markdown report.

        Flags scenarios with success rate below 95% or P99 latency above
        1000 ms as potential bottlenecks, and appends tuning recommendations
        when peak throughput is under 100 req/s.
        """
        # Bug fix: the report previously used "\\n", emitting literal
        # backslash-n sequences instead of newlines.
        report = "# Performance Test Report\n\n"

        for scenario_name, results in test_results.items():
            report += f"## {scenario_name.upper()}\n\n"

            report += f"- Total Requests: {results.get('total_requests', 0)}\n"
            report += f"- Concurrent Requests: {results.get('concurrent_requests', 0)}\n"
            report += f"- Total Time: {results.get('total_time_seconds', 0):.2f} seconds\n"
            report += f"- Requests per Second: {results.get('requests_per_second', 0):.2f}\n"
            report += f"- Success Rate: {results.get('success_rate', 0):.2%}\n"
            report += f"- Average Latency: {results.get('average_latency_ms', 0):.2f} ms\n"
            report += f"- P95 Latency: {results.get('p95_latency_ms', 0):.2f} ms\n"
            report += f"- P99 Latency: {results.get('p99_latency_ms', 0):.2f} ms\n\n"

        report += "## Summary\n\n"

        # Flag scenarios that breach the success-rate / latency thresholds.
        bottlenecks = []
        for scenario_name, results in test_results.items():
            if results.get('success_rate', 1) < 0.95:
                bottlenecks.append(f"{scenario_name}: Low success rate ({results.get('success_rate', 0):.2%})")

            if results.get('p99_latency_ms', 0) > 1000:  # over 1 second
                bottlenecks.append(f"{scenario_name}: High P99 latency ({results.get('p99_latency_ms', 0):.2f} ms)")

        if bottlenecks:
            report += "### Potential Bottlenecks\n\n"
            for bottleneck in bottlenecks:
                report += f"- {bottleneck}\n"
            report += "\n"
        else:
            report += "No significant bottlenecks detected.\n\n"

        report += "### Recommendations\n\n"

        # default=0 guards against an empty results dict (previously raised
        # ValueError on max() of an empty sequence).
        max_rps = max((r.get('requests_per_second', 0) for r in test_results.values()), default=0)
        if max_rps < 100:
            report += "- Consider optimizing model inference performance\n"
            report += "- Implement batching for predictions\n"
            report += "- Add caching for frequent requests\n"

        return report

if __name__ == "__main__":
    # Entry point: start the optimized model-serving system on a fresh
    # asyncio event loop. main_optimized() is defined earlier in this file
    # (its body sets up monitoring, caching, load balancing, etc.).
    asyncio.run(main_optimized())

总结

本系统实现了完整的机器学习模型在线服务与A/B测试架构,包括:

核心特性:

  • 统一模型服务接口:支持多种框架(TensorFlow、PyTorch、Scikit-learn等)

  • 高性能API服务:基于FastAPI的异步API,支持批量预测

  • 完整的A/B测试系统:包含实验设计、用户分配、效果分析等功能

  • 实时监控与分析:提供详细的实验分析和统计检验

  • 生产级部署:支持容器化和Kubernetes部署

  • 性能优化:缓存、批量处理、负载均衡、断路器等

关键优势:

  • 可扩展性:微服务架构,支持水平扩展

  • 可靠性:内置故障恢复和健康检查机制

  • 可观测性:全面的监控和日志系统

  • 易用性:RESTful API接口,易于集成

  • 安全性:支持认证、授权和HTTPS

生产建议:

  • 性能调优:根据实际负载调整缓存策略和批处理参数

  • 监控告警:设置合适的阈值,及时发现并解决问题

  • 版本管理:建立模型版本控制流程

  • 安全加固:实施API网关、限流和审计日志

  • 成本优化:根据使用模式选择合适的云资源

这个系统为机器学习模型的生产部署提供了完整的解决方案,可以帮助团队快速、可靠地将模型投入生产环境。

    赞(0)
    未经允许不得转载:网硕互联帮助中心 » 米哈游Java面试被问:机器学习模型的在线服务和A/B测试
    分享到: 更多 (0)

    评论 抢沙发

    评论前必须登录!