# Listing metadata: 536 lines, 21 KiB, Python
import random
from collections import Counter
from collections import defaultdict
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sqlalchemy.orm import Session

from ..models.lottery import SSQLottery, DLTLottery


class PredictionService:
    """Number-suggestion service for the SSQ and DLT lotteries.

    Three strategies are offered:

    * per-position random-forest regression
      (``train_model`` / ``predict_next_numbers``),
    * statistical pattern matching (``get_pattern_based_prediction``),
    * a combination of both plus hot/cold frequency analysis
      (``get_ensemble_prediction``).

    Any ``lottery_type`` other than ``'ssq'`` is handled with the DLT model
    and ball layout.
    """

    # Class-level caches shared by every instance, keyed by lottery type.
    _models = {}
    _scalers = {}

    # Number of consecutive draws that make up one feature row.
    _WINDOW = 10

    # Ball layout per game: attribute names of the main-zone and bonus-zone
    # balls on a draw record, and the largest valid number in each zone.
    _LAYOUTS = {
        'ssq': {
            'main': ('red_ball_1', 'red_ball_2', 'red_ball_3',
                     'red_ball_4', 'red_ball_5', 'red_ball_6'),
            'bonus': ('blue_ball',),
            'main_max': 33,
            'bonus_max': 16,
        },
        'dlt': {
            'main': ('front_ball_1', 'front_ball_2', 'front_ball_3',
                     'front_ball_4', 'front_ball_5'),
            'bonus': ('back_ball_1', 'back_ball_2'),
            'main_max': 35,
            'bonus_max': 12,
        },
    }

    def __init__(self, db: "Session"):
        """Store the database session used to read historical draws."""
        self.db = db

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @classmethod
    def _layout(cls, lottery_type: str) -> Dict:
        """Return the ball layout for ``lottery_type`` (non-'ssq' -> DLT)."""
        return cls._LAYOUTS['ssq' if lottery_type == 'ssq' else 'dlt']

    def _recent_draws(self, lottery_type: str, limit: int) -> List:
        """Query up to ``limit`` draws ordered by ``open_time``, newest first."""
        model = SSQLottery if lottery_type == 'ssq' else DLTLottery
        return self.db.query(model).order_by(
            model.open_time.desc()).limit(limit).all()

    @classmethod
    def _main_numbers(cls, draw, lottery_type: str) -> List[int]:
        """Return the main-zone numbers of ``draw`` in stored order."""
        return [getattr(draw, name)
                for name in cls._layout(lottery_type)['main']]

    @classmethod
    def _bonus_numbers(cls, draw, lottery_type: str) -> List[int]:
        """Return the bonus-zone numbers of ``draw`` in stored order."""
        return [getattr(draw, name)
                for name in cls._layout(lottery_type)['bonus']]

    @classmethod
    def _encode_draw(cls, draw, lottery_type: str) -> List[int]:
        """Encode a draw as sorted main numbers followed by bonus numbers.

        Only the main zone is sorted; the bonus numbers keep their stored
        order so the two zones never get mixed.
        """
        return (sorted(cls._main_numbers(draw, lottery_type))
                + cls._bonus_numbers(draw, lottery_type))

    @staticmethod
    def _top_up(picks: List[int], count: int, upper: int) -> List[int]:
        """Extend ``picks`` to ``count`` items with unused numbers in 1..upper.

        The extra numbers are drawn without replacement, so the call always
        terminates (``random.sample`` raises ``ValueError`` if the pool is
        too small instead of looping).
        """
        picks = list(picks)
        missing = count - len(picks)
        if missing <= 0:
            return picks
        pool = [n for n in range(1, upper + 1) if n not in picks]
        return picks + random.sample(pool, missing)

    @classmethod
    def _complete_selection(cls, candidates, count: int,
                            upper: int) -> List[int]:
        """Turn raw predictions into ``count`` distinct sorted valid numbers.

        Candidates are clamped to ``1..upper`` and de-duplicated; missing
        slots are filled with random unused numbers.
        """
        clamped = sorted({max(1, min(upper, int(c))) for c in candidates})
        return sorted(cls._top_up(clamped, count, upper))[:count]

    @staticmethod
    def _zone_count(numbers) -> int:
        """Count how many distinct width-5 zones (1-5, 6-10, ...) are hit."""
        return len({(n - 1) // 5 for n in numbers})

    @staticmethod
    def _adjacent_pairs(numbers) -> int:
        """Count neighbouring pairs that differ by exactly one once sorted."""
        ordered = sorted(numbers)
        return sum(1 for a, b in zip(ordered, ordered[1:]) if b - a == 1)

    @staticmethod
    def _neighbour_count(num: int, selected) -> int:
        """Count members of ``selected`` that are adjacent (+/-1) to ``num``."""
        return sum(1 for x in selected if abs(num - x) == 1)

    @staticmethod
    def _random_candidate(target_count: int, max_num: int, target_odd: int,
                          target_consecutive: int) -> List[int]:
        """Draw one random combination steered towards the target pattern.

        A run of ``target_consecutive + 1`` consecutive numbers is seeded
        first (when requested); the rest is filled randomly while respecting
        the odd/even quota. At least one quota is always open while the
        combination is incomplete, so the fill loop terminates.
        """
        numbers = set()
        if target_consecutive > 0:
            start = random.randint(1, max_num - target_consecutive)
            for offset in range(target_consecutive + 1):
                if len(numbers) < target_count:
                    numbers.add(start + offset)

        odd = sum(1 for n in numbers if n % 2 == 1)
        even_quota = target_count - target_odd
        while len(numbers) < target_count:
            n = random.randint(1, max_num)
            if n in numbers:
                continue
            if n % 2 == 1:
                if odd < target_odd:
                    numbers.add(n)
                    odd += 1
            elif len(numbers) - odd < even_quota:
                numbers.add(n)
        return sorted(numbers)

    @classmethod
    def _pattern_score(cls, numbers: List[int], sum_range: List[int],
                       target_odd: int, target_zones: int,
                       target_consecutive: int) -> int:
        """Score 0-4: one point per matched criterion."""
        odd = sum(1 for n in numbers if n % 2 == 1)
        return sum((
            sum_range[0] <= sum(numbers) <= sum_range[1],
            odd == target_odd,
            cls._zone_count(numbers) == target_zones,
            cls._adjacent_pairs(numbers) == target_consecutive,
        ))

    @classmethod
    def _mid_frequency_pick(cls, history, count: int,
                            upper: int) -> List[int]:
        """Pick ``count`` numbers from the middle third of the frequency rank.

        Falls back to random numbers in ``1..upper`` when that band holds
        fewer than ``count`` values.
        """
        ranked = [num for num, _ in Counter(history).most_common()]
        band = ranked[len(ranked) // 3:(len(ranked) * 2) // 3]
        picks = random.sample(band, min(count, len(band)))
        return sorted(cls._top_up(picks, count, upper))

    @classmethod
    def _hot_cold_recommendation(cls, freq_result, lottery_type: str) -> Dict:
        """Build the hot/cold-number recommendation entry.

        Assumes ``freq_result['hot_reds']`` / ``['cold_reds']`` (and the
        optional ``['hot_blues']``) are lists of ints ordered by relevance
        -- TODO confirm against LotteryAnalysisService.get_hot_cold_numbers.
        """
        layout = cls._layout(lottery_type)
        main_count = len(layout['main'])
        bonus_max = layout['bonus_max']

        hot_reds = freq_result['hot_reds']
        cold_reds = freq_result['cold_reds']

        # Up to three hot numbers, then up to two cold numbers that do not
        # sit next to an already selected number.
        selected = set(hot_reds[:3])
        for num in cold_reds[:2]:
            if cls._neighbour_count(num, selected) == 0:
                selected.add(num)

        # Fill from "warm" numbers (neither hot nor cold) first and from any
        # remaining number second. A candidate is accepted only if it touches
        # at most one selected number. Neighbour counts only grow as the
        # selection grows, so one shuffled pass per pool is equivalent to
        # retrying random choices and is guaranteed to finish.
        universe = set(range(1, layout['main_max'] + 1))
        warm = universe - set(hot_reds) - set(cold_reds)
        for pool in (warm, universe):
            candidates = list(pool - selected)
            random.shuffle(candidates)
            for num in candidates:
                if len(selected) >= main_count:
                    break
                if cls._neighbour_count(num, selected) <= 1:
                    selected.add(num)

        hot_blues = None
        if 'hot_blues' in freq_result and freq_result['hot_blues']:
            hot_blues = freq_result['hot_blues']

        blue = None
        blues = None
        if lottery_type == 'ssq':
            # One blue ball: random among the three hottest, else any valid.
            if hot_blues:
                blue = random.choice(hot_blues[:3])
            else:
                blue = random.randint(1, bonus_max)
        else:
            # Two back-zone balls: random among the four hottest, topped up
            # with random valid numbers when fewer than two are available.
            shortlist = list(hot_blues[:4]) if hot_blues else []
            picks = random.sample(shortlist, min(2, len(shortlist)))
            blues = sorted(cls._top_up(picks, 2, bonus_max))

        return {
            'method': '热冷号分析',
            'numbers': sorted(selected),
            'blue': blue,
            'blues': blues,
            'confidence': '中'
        }

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def prepare_features(self, lottery_type: str, periods: int = 100) -> Tuple[np.ndarray, np.ndarray]:
        """Build the supervised-learning dataset from recent draws.

        Each sample's features are ten consecutive draws in chronological
        order (oldest first), every draw encoded as sorted main numbers
        followed by the bonus numbers. The label is the draw that came
        immediately after that window, encoded the same way.

        Args:
            lottery_type: Lottery type ('ssq' or 'dlt').
            periods: Number of most recent draws to load.

        Returns:
            Tuple: (feature matrix, label matrix). Both are empty arrays
            when ten or fewer draws are available.
        """
        # The query is newest-first; reverse it so that a window is always
        # followed (not preceded) by the draw it is supposed to predict.
        draws = list(reversed(self._recent_draws(lottery_type, periods)))
        encoded = [self._encode_draw(draw, lottery_type) for draw in draws]

        features = []
        labels = []
        for end in range(self._WINDOW, len(encoded)):
            window = encoded[end - self._WINDOW:end]
            features.append([n for row in window for n in row])
            labels.append(encoded[end])

        return np.array(features), np.array(labels)

    def train_model(self, lottery_type: str, periods: int = 100) -> Dict:
        """Train one random-forest regressor per ball position.

        Args:
            lottery_type: Lottery type ('ssq' or 'dlt').
            periods: Number of most recent draws to train on.

        Returns:
            Dict: ``success`` and ``message``; on success also
            ``accuracies`` (per-position ``model.score`` on a 20% hold-out,
            i.e. R^2 for a regressor), ``avg_accuracy`` and
            ``training_samples``. Errors are reported in the dict rather
            than raised.
        """
        try:
            features, labels = self.prepare_features(lottery_type, periods)

            if len(features) < 20:  # not enough data
                return {"success": False, "message": "数据不足,无法训练模型"}

            scaler = StandardScaler()
            features_scaled = scaler.fit_transform(features)

            models = {}
            accuracies = []
            for pos in range(labels.shape[1]):
                X_train, X_test, y_train, y_test = train_test_split(
                    features_scaled, labels[:, pos],
                    test_size=0.2, random_state=42
                )
                model = RandomForestRegressor(
                    n_estimators=100, random_state=42)
                model.fit(X_train, y_train)
                accuracies.append(float(model.score(X_test, y_test)))
                models[f"pos_{pos}"] = model

            # Publish scaler and models together, only after every position
            # trained, so a failure cannot leave a new scaler paired with
            # stale models.
            self._scalers[lottery_type] = scaler
            self._models[lottery_type] = models

            return {
                "success": True,
                "message": "模型训练成功",
                "avg_accuracy": float(np.mean(accuracies)),
                "accuracies": accuracies,
                "training_samples": len(features)
            }

        except Exception as e:
            return {"success": False, "message": f"训练失败: {str(e)}"}

    def predict_next_numbers(self, lottery_type: str, periods: int = 10) -> Dict:
        """Predict the next draw with the previously trained models.

        Args:
            lottery_type: Lottery type ('ssq' or 'dlt').
            periods: Unused; kept for interface compatibility. The last ten
                draws are always used because the models were trained on
                ten-draw windows.

        Returns:
            Dict: ``success``; on success ``predicted_numbers`` (sorted main
            numbers), ``predicted_blue`` (SSQ only, else None),
            ``predicted_blues`` (DLT only, else None) and ``confidence``;
            on failure ``message``.
        """
        if lottery_type not in self._models:
            return {"success": False, "message": "模型未训练,请先训练模型"}

        try:
            recent_draws = self._recent_draws(lottery_type, self._WINDOW)
            if len(recent_draws) < self._WINDOW:
                return {"success": False, "message": "历史数据不足"}

            scaler = self._scalers.get(lottery_type)
            if scaler is None:
                return {"success": False, "message": "模型未训练,请先训练模型"}

            # Same layout as the training rows: oldest draw first.
            feature_row = [
                n
                for draw in reversed(recent_draws)
                for n in self._encode_draw(draw, lottery_type)
            ]
            feature_scaled = scaler.transform([feature_row])

            models = self._models[lottery_type]
            predictions = [
                int(round(float(
                    models[f"pos_{pos}"].predict(feature_scaled)[0])))
                for pos in range(len(models))
            ]

            layout = self._layout(lottery_type)
            red_count = len(layout['main'])
            red_predictions = self._complete_selection(
                predictions[:red_count], red_count, layout['main_max'])

            blue_prediction = None
            blue_predictions = None
            if lottery_type == 'ssq':
                blue_prediction = max(
                    1, min(layout['bonus_max'], predictions[-1]))
            else:
                blue_predictions = self._complete_selection(
                    predictions[red_count:], 2, layout['bonus_max'])

            return {
                "success": True,
                "predicted_numbers": red_predictions,
                "predicted_blue": blue_prediction if lottery_type == 'ssq' else None,
                "predicted_blues": blue_predictions if lottery_type == 'dlt' else None,
                "confidence": "基于历史数据的机器学习预测"
            }

        except Exception as e:
            return {"success": False, "message": f"预测失败: {str(e)}"}

    def get_pattern_based_prediction(self, lottery_type: str, periods: int = 100) -> Dict:
        """Suggest numbers that match the most common historical patterns.

        Four main-zone statistics are collected over the recent draws: sum,
        odd:even ratio, number of width-5 zones hit, and number of adjacent
        pairs. 100 random combinations are generated and the one matching
        the most criteria is returned.

        Args:
            lottery_type: Lottery type ('ssq' or 'dlt').
            periods: Number of most recent draws to analyse.

        Returns:
            Dict: on success ``patterns``, ``suggested_criteria``,
            ``predicted_numbers``, ``predicted_blue`` (SSQ only) and
            ``predicted_blues`` (DLT only); ``success`` is False with a
            ``message`` when there is no history to analyse.
        """
        recent_draws = self._recent_draws(lottery_type, periods)
        if not recent_draws:
            # Without data the mean/std below would be NaN and int() fails.
            return {"success": False, "message": "历史数据不足"}

        layout = self._layout(lottery_type)
        target_count = len(layout['main'])
        max_num = layout['main_max']

        sums = []
        parities = []
        zone_counts = []
        run_counts = []
        for draw in recent_draws:
            numbers = self._main_numbers(draw, lottery_type)
            odd_count = sum(1 for n in numbers if n % 2 == 1)
            sums.append(sum(numbers))
            parities.append(f"{odd_count}:{len(numbers) - odd_count}")
            zone_counts.append(self._zone_count(numbers))
            run_counts.append(self._adjacent_pairs(numbers))

        avg_sum = float(np.mean(sums))
        std_sum = float(np.std(sums))
        most_common_patterns = {
            'sum_range': {
                'avg': avg_sum,
                'std': std_sum,
                # One standard deviation either side of the mean.
                'range': [int(avg_sum - std_sum), int(avg_sum + std_sum)]
            },
            'odd_even_ratio': Counter(parities).most_common(3),
            'zone_distribution': Counter(zone_counts).most_common(3),
            'consecutive_count': Counter(run_counts).most_common(3)
        }

        # History is non-empty here, so every ranking has a first entry.
        criteria = {
            "sum_range": most_common_patterns['sum_range']['range'],
            "odd_even_ratio": most_common_patterns['odd_even_ratio'][0][0],
            "zone_distribution": most_common_patterns['zone_distribution'][0][0],
            "consecutive_count": most_common_patterns['consecutive_count'][0][0]
        }
        target_odd = int(criteria["odd_even_ratio"].split(':')[0])
        target_zones = criteria["zone_distribution"]
        target_consecutive = criteria["consecutive_count"]

        best_numbers = None
        best_score = -1
        for _ in range(100):
            candidate = self._random_candidate(
                target_count, max_num, target_odd, target_consecutive)
            score = self._pattern_score(
                candidate, criteria["sum_range"], target_odd,
                target_zones, target_consecutive)
            if score > best_score:
                best_score = score
                best_numbers = candidate

        # Bonus zone: prefer numbers of middling recent frequency.
        bonus_history = [
            n
            for draw in recent_draws
            for n in self._bonus_numbers(draw, lottery_type)
        ]
        bonus_pick = self._mid_frequency_pick(
            bonus_history, len(layout['bonus']), layout['bonus_max'])

        return {
            "success": True,
            "patterns": most_common_patterns,
            "suggested_criteria": criteria,
            "predicted_numbers": best_numbers,
            "predicted_blue": bonus_pick[0] if lottery_type == 'ssq' else None,
            "predicted_blues": bonus_pick if lottery_type == 'dlt' else None
        }

    def get_ensemble_prediction(self, lottery_type: str, periods: int = 100) -> Dict:
        """Combine the ML, hot/cold and pattern strategies.

        Args:
            lottery_type: Lottery type ('ssq' or 'dlt').
            periods: Number of most recent draws to analyse.

        Returns:
            Dict: ``success``, ``recommendations`` (one entry per strategy
            that produced a result, each with ``method``, ``numbers``,
            ``blue``, ``blues`` and ``confidence``), ``pattern_analysis``
            and ``frequency_analysis``.
        """
        ml_result = self.predict_next_numbers(lottery_type, periods)
        pattern_result = self.get_pattern_based_prediction(
            lottery_type, periods)

        # Function-scope import; presumably avoids an import cycle with
        # analysis_service -- TODO confirm.
        from .analysis_service import LotteryAnalysisService
        analysis_service = LotteryAnalysisService(self.db)
        freq_result = analysis_service.get_hot_cold_numbers(
            lottery_type, periods)

        is_ssq = lottery_type == 'ssq'
        is_dlt = lottery_type == 'dlt'
        recommendations = []

        if ml_result.get('success'):
            recommendations.append({
                'method': '机器学习',
                'numbers': ml_result['predicted_numbers'],
                'blue': ml_result.get('predicted_blue') if is_ssq else None,
                'blues': ml_result.get('predicted_blues') if is_dlt else None,
                'confidence': '高'
            })

        if freq_result:
            recommendations.append(
                self._hot_cold_recommendation(freq_result, lottery_type))

        if pattern_result and pattern_result.get('success'):
            recommendations.append({
                'method': '模式分析',
                'numbers': pattern_result['predicted_numbers'],
                # .get(): never fail on a missing bonus key.
                'blue': pattern_result.get('predicted_blue') if is_ssq else None,
                'blues': pattern_result.get('predicted_blues') if is_dlt else None,
                'confidence': '中'
            })

        return {
            "success": True,
            "recommendations": recommendations,
            "pattern_analysis": pattern_result.get('suggested_criteria', {}) if pattern_result else {},
            "frequency_analysis": freq_result or {}
        }