64 lines
1.8 KiB
Python
64 lines
1.8 KiB
Python
from fastapi import APIRouter, Depends, HTTPException
|
|
from sqlalchemy.orm import Session
|
|
from typing import Dict
|
|
from ...core.database import get_db
|
|
from ...services.prediction_service import PredictionService
|
|
|
|
# Router that collects the prediction endpoints declared in this module.
router = APIRouter()
|
|
|
|
|
|
@router.post("/train/{lottery_type}")
def train_prediction_model(
    lottery_type: str,
    periods: int = 100,
    db: Session = Depends(get_db),
):
    """Train the prediction model for a lottery type.

    Args:
        lottery_type: Path parameter; must be ``"ssq"`` or ``"dlt"``.
        periods: Query parameter forwarded to ``PredictionService.train_model``
            (presumably the number of historical draws to train on -- confirm
            against the service). Must be at least 1.
        db: Database session injected through ``get_db``.

    Returns:
        The value returned by ``PredictionService.train_model``.

    Raises:
        HTTPException: 400 if ``lottery_type`` is unsupported or ``periods``
            is not a positive integer.
    """
    if lottery_type not in ("ssq", "dlt"):
        raise HTTPException(status_code=400, detail="Invalid lottery type")
    # periods comes straight from the query string; reject non-positive values
    # with a clear 400 instead of forwarding a meaningless count to the service.
    if periods < 1:
        raise HTTPException(status_code=400, detail="periods must be a positive integer")

    service = PredictionService(db)
    return service.train_model(lottery_type, periods)
|
|
|
|
|
|
@router.get("/predict/{lottery_type}")
def predict_next_numbers(
    lottery_type: str,
    periods: int = 10,
    db: Session = Depends(get_db),
):
    """Predict the numbers for the next draw.

    Args:
        lottery_type: Path parameter; must be ``"ssq"`` or ``"dlt"``.
        periods: Query parameter forwarded to
            ``PredictionService.predict_next_numbers`` (presumably how many
            recent draws the prediction looks at -- confirm against the
            service). Must be at least 1.
        db: Database session injected through ``get_db``.

    Returns:
        The value returned by ``PredictionService.predict_next_numbers``.

    Raises:
        HTTPException: 400 if ``lottery_type`` is unsupported or ``periods``
            is not a positive integer.
    """
    if lottery_type not in ("ssq", "dlt"):
        raise HTTPException(status_code=400, detail="Invalid lottery type")
    # periods comes straight from the query string; reject non-positive values
    # with a clear 400 instead of forwarding a meaningless count to the service.
    if periods < 1:
        raise HTTPException(status_code=400, detail="periods must be a positive integer")

    service = PredictionService(db)
    return service.predict_next_numbers(lottery_type, periods)
|
|
|
|
|
|
@router.get("/pattern/{lottery_type}")
def get_pattern_prediction(
    lottery_type: str,
    periods: int = 100,
    db: Session = Depends(get_db),
):
    """Return a pattern-based prediction for a lottery type.

    Args:
        lottery_type: Path parameter; must be ``"ssq"`` or ``"dlt"``.
        periods: Query parameter forwarded to
            ``PredictionService.get_pattern_based_prediction`` (presumably the
            number of historical draws analysed -- confirm against the
            service). Must be at least 1.
        db: Database session injected through ``get_db``.

    Returns:
        The value returned by ``PredictionService.get_pattern_based_prediction``.

    Raises:
        HTTPException: 400 if ``lottery_type`` is unsupported or ``periods``
            is not a positive integer.
    """
    if lottery_type not in ("ssq", "dlt"):
        raise HTTPException(status_code=400, detail="Invalid lottery type")
    # periods comes straight from the query string; reject non-positive values
    # with a clear 400 instead of forwarding a meaningless count to the service.
    if periods < 1:
        raise HTTPException(status_code=400, detail="periods must be a positive integer")

    service = PredictionService(db)
    return service.get_pattern_based_prediction(lottery_type, periods)
|
|
|
|
|
|
@router.get("/ensemble/{lottery_type}")
def get_ensemble_prediction(
    lottery_type: str,
    periods: int = 100,
    db: Session = Depends(get_db),
):
    """Return an ensemble prediction for a lottery type.

    Args:
        lottery_type: Path parameter; must be ``"ssq"`` or ``"dlt"``.
        periods: Query parameter forwarded to
            ``PredictionService.get_ensemble_prediction`` (presumably the
            number of historical draws used -- confirm against the service).
            Must be at least 1.
        db: Database session injected through ``get_db``.

    Returns:
        The value returned by ``PredictionService.get_ensemble_prediction``.

    Raises:
        HTTPException: 400 if ``lottery_type`` is unsupported or ``periods``
            is not a positive integer.
    """
    if lottery_type not in ("ssq", "dlt"):
        raise HTTPException(status_code=400, detail="Invalid lottery type")
    # periods comes straight from the query string; reject non-positive values
    # with a clear 400 instead of forwarding a meaningless count to the service.
    if periods < 1:
        raise HTTPException(status_code=400, detail="periods must be a positive integer")

    service = PredictionService(db)
    return service.get_ensemble_prediction(lottery_type, periods)
|