#!/usr/bin/env python3
import json
import os
import pickle
import sqlite3
import warnings
from contextlib import closing
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# Silence every warning category process-wide so the backtest's console
# output stays readable.
warnings.filterwarnings(action='ignore')
class ConfigurableEmotionalDamageStrategy:
    """Fear & Greed driven rotation between QQQ, cash and volatile stocks.

    ``run_backtest`` walks the merged Fear & Greed / QQQ history day by day
    and drives this state machine::

        QQQ_HOLD --(F&G < fear_threshold)--> FEAR_TRANSITION --> CASH_WAIT
        CASH_WAIT --(F&G >= fear_threshold, stocks found)--> GREED_TRANSITION
                  --> VOLATILE_STOCKS --(F&G > greed_threshold)--> QQQ_TRANSITION
        CASH_WAIT --(F&G >= fear_threshold, no stocks)--> QQQ_TRANSITION
        QQQ_TRANSITION --> QQQ_HOLD

    Each ``*_TRANSITION`` state lasts ``transition_steps`` days and moves one
    tranche per day. Positions are integer share counts. All parameters come
    from the JSON config loaded by ``load_config``. Prices are read from the
    SQLite database at ``config['paths']['database_path']``.
    """

    def __init__(self, config_path='config.json'):
        """Initialize strategy with config file.

        Args:
            config_path: Path to the JSON configuration file.
        """
        self.load_config(config_path)
        self.cash = self.config['strategy_params']['initial_capital']
        self.positions = {}        # ticker -> integer share count
        self.portfolio_value = []  # one record per backtest day
        self.trades = []           # one record per executed trade

        # State management
        self.state = 'QQQ_HOLD'
        self.current_step = 0
        self.target_allocation = {}
        self.last_fear_date = None

        # For gradual transitions
        self.transition_plan = {}
        self.transition_cash_pool = 0

    def load_config(self, config_path):
        """Load the JSON config and expose frequently used parameters as attributes.

        Args:
            config_path: Path to the JSON configuration file (UTF-8, as JSON
                requires).

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If a required section or key is missing.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        # Strategy parameters
        params = self.config['strategy_params']
        self.initial_capital = params['initial_capital']
        self.fear_threshold = params['fear_threshold']
        self.greed_threshold = params['greed_threshold']
        self.stop_loss_threshold = params['stop_loss_threshold']
        self.top_stocks_count = params['top_stocks_count']
        self.volatility_threshold = params['volatility_threshold']
        self.volatility_calculation_days = params['volatility_calculation_days']
        self.transition_steps = params['transition_steps']

        # Technical indicators
        tech = self.config['technical_indicators']
        self.rsi_threshold = tech['rsi_threshold']
        self.required_indicators = tech['required_indicators']
        self.sma5_above_sma20 = tech['sma5_above_sma20']
        self.macd_convergence = tech['macd_convergence']

        if self.config['output_settings']['show_console_output']:
            print(f"✅ 配置已加载:")
            print(f" 初始资金: ${self.initial_capital:,}")
            print(f" 恐慌阈值: {self.fear_threshold}")
            print(f" 贪婪阈值: {self.greed_threshold}")
            print(f" 止损阈值: {self.stop_loss_threshold*100}%")
            print(f" 转换步数: {self.transition_steps}")
            print(f" 选股数量: {self.top_stocks_count}")
            print("")

    @staticmethod
    def _quote_identifier(ticker):
        """Return *ticker*, lower-cased, as a quoted SQLite identifier.

        Per-ticker tables are addressed by name, and a name cannot be bound
        as a query parameter. Quoting it (and doubling any embedded quote)
        keeps tickers that collide with SQL keywords (e.g. ``ALL``) or that
        contain punctuation from breaking or altering the statement.
        """
        name = str(ticker).lower().replace('"', '""')
        return f'"{name}"'

    def _connect(self):
        """Open the configured database; the connection is closed on ``with`` exit."""
        return closing(sqlite3.connect(self.config['paths']['database_path']))

    def get_data(self):
        """Load the Fear & Greed index, QQQ closes and the tradable ticker list.

        Populates ``self.data`` and ``self.available_tickers``.

        ``self.data`` is the inner join of ``fear_greed_index`` and ``qqq``
        on date. It is restricted to the optional
        ``data_settings.start_date`` / ``end_date`` and sorted by date.

        ``self.available_tickers`` holds the ``ticker_list`` entries with
        more than ``data_settings.min_ticker_records`` records.
        """
        db_path = self.config['paths']['database_path']
        show = self.config['output_settings']['show_console_output']
        if show:
            print(f"连接数据库: {db_path}")

        with closing(sqlite3.connect(db_path)) as conn:
            # Get Fear & Greed Index
            fg_data = pd.read_sql_query('''
                SELECT date, fear_greed_index
                FROM fear_greed_index
                ORDER BY date
            ''', conn)

            # Get real QQQ price data
            qqq_data = pd.read_sql_query('''
                SELECT date, close as qqq_close
                FROM qqq
                ORDER BY date
            ''', conn)

            # Get available tickers (threshold bound as a parameter, not interpolated)
            min_records = self.config['data_settings']['min_ticker_records']
            cursor = conn.execute(
                'SELECT ticker FROM ticker_list WHERE records > ?', (min_records,))
            self.available_tickers = [row[0] for row in cursor.fetchall()]

        for frame in (fg_data, qqq_data):
            frame['date'] = pd.to_datetime(frame['date'])
            frame.set_index('date', inplace=True)

        self.data = pd.merge(fg_data, qqq_data, left_index=True, right_index=True, how='inner')

        # Apply date filters if specified
        data_settings = self.config['data_settings']
        if data_settings['start_date']:
            start_date = pd.to_datetime(data_settings['start_date'])
            self.data = self.data[self.data.index >= start_date]
        if data_settings['end_date']:
            end_date = pd.to_datetime(data_settings['end_date'])
            self.data = self.data[self.data.index <= end_date]

        self.data = self.data.sort_index()

        if show:
            if not self.data.empty:
                print(f"数据加载完成: {self.data.index.min().strftime('%Y-%m-%d')} 到 {self.data.index.max().strftime('%Y-%m-%d')}")
            print(f"可用股票数量: {len(self.available_tickers)}")
            print("")

    def get_stock_price(self, ticker, date):
        """Return *ticker*'s latest close on or before *date*, or None if there is none.

        sqlite3 errors propagate, e.g. when the ticker's table does not exist.
        """
        query = f'''
            SELECT close FROM {self._quote_identifier(ticker)}
            WHERE date <= ?
            ORDER BY date DESC
            LIMIT 1
        '''
        with self._connect() as conn:
            result = conn.execute(query, (date.strftime('%Y-%m-%d'),)).fetchone()
        return result[0] if result else None

    def calculate_volatility(self, ticker, current_date):
        """Return the annualized volatility of *ticker* over the configured look-back.

        Uses the daily close-to-close returns in the
        ``volatility_calculation_days`` calendar days up to *current_date*.
        The result is their standard deviation times sqrt(252).

        Returns 0 when there are 10 or fewer closes in the window or when
        reading/computing fails (best effort: such a ticker simply does not
        qualify).
        """
        with self._connect() as conn:
            try:
                start_date = current_date - timedelta(days=self.volatility_calculation_days)
                query = f'''
                    SELECT date, close FROM {self._quote_identifier(ticker)}
                    WHERE date >= ? AND date <= ?
                    ORDER BY date
                '''
                df = pd.read_sql_query(query, conn, params=(
                    start_date.strftime('%Y-%m-%d'),
                    current_date.strftime('%Y-%m-%d'),
                ))
                if len(df) > 10:
                    return df['close'].pct_change().std() * np.sqrt(252)
            except Exception:
                # Deliberate best effort: an unreadable ticker is just skipped.
                pass
        return 0

    def check_technical_indicators(self, ticker, date):
        """Return True if enough of the RSI / MACD / SMA conditions hold on *date*.

        Uses up to the last 50 closes on or before *date*. Three conditions
        are evaluated:

        * RSI: RSI > ``rsi_threshold``.
        * MACD: if ``macd_convergence``, the MACD/signal gap is narrowing
          versus the previous day; otherwise MACD > signal.
        * SMA: fast SMA > slow SMA when ``sma5_above_sma20``; when that
          option is off, the condition always counts as met.

        At least ``required_indicators`` conditions must hold. Returns False
        with fewer than 20 closes or on any error.
        """
        with self._connect() as conn:
            try:
                tech = self.config['technical_indicators']
                query = f'''
                    SELECT date, close FROM {self._quote_identifier(ticker)}
                    WHERE date <= ?
                    ORDER BY date DESC
                    LIMIT 50
                '''
                df = pd.read_sql_query(query, conn, params=(date.strftime('%Y-%m-%d'),))
                if len(df) < 20:
                    return False

                # The query returns newest first; indicators need chronological order.
                close = df.sort_values('date').reset_index(drop=True)['close']

                # RSI (simple rolling mean of gains and losses)
                rsi_period = tech['rsi_period']
                delta = close.diff()
                gain = delta.where(delta > 0, 0).rolling(window=rsi_period).mean()
                loss = (-delta.where(delta < 0, 0)).rolling(window=rsi_period).mean()
                rsi = 100 - (100 / (1 + gain / loss))

                # MACD
                ema = tech['ema_periods']
                macd = close.ewm(span=ema['fast']).mean() - close.ewm(span=ema['slow']).mean()
                signal = macd.ewm(span=ema['signal']).mean()

                # SMA
                sma = tech['sma_periods']
                sma_fast = close.rolling(window=sma['fast']).mean()
                sma_slow = close.rolling(window=sma['slow']).mean()

                rsi_ok = rsi.iloc[-1] > self.rsi_threshold

                if self.macd_convergence and len(macd) >= 2:
                    prev_diff = abs(macd.iloc[-2] - signal.iloc[-2])
                    current_diff = abs(macd.iloc[-1] - signal.iloc[-1])
                    macd_ok = current_diff < prev_diff  # Lines are converging
                else:
                    macd_ok = macd.iloc[-1] > signal.iloc[-1]  # Traditional golden cross

                if self.sma5_above_sma20:
                    sma_ok = sma_fast.iloc[-1] > sma_slow.iloc[-1]
                else:
                    sma_ok = True  # SMA check disabled

                score = sum([rsi_ok, macd_ok, sma_ok])
                return score >= self.required_indicators
            except Exception:
                # Missing tables or malformed data simply disqualify the ticker.
                return False

    def select_volatile_stocks(self, fear_start_date, fear_end_date):
        """Pick the most volatile tickers that also pass the technical filter.

        A ticker qualifies when ``check_technical_indicators`` passes on
        *fear_end_date* and its ``calculate_volatility`` exceeds
        ``volatility_threshold``. Returns at most ``top_stocks_count``
        tickers, most volatile first. *fear_start_date* is currently unused.
        """
        qualified_stocks = []
        for ticker in self.available_tickers:
            if not self.check_technical_indicators(ticker, fear_end_date):
                continue
            vol = self.calculate_volatility(ticker, fear_end_date)
            if vol > self.volatility_threshold:
                qualified_stocks.append((ticker, vol))

        qualified_stocks.sort(key=lambda item: item[1], reverse=True)
        return [ticker for ticker, _ in qualified_stocks[:self.top_stocks_count]]

    def execute_trade(self, date, action, ticker=None, shares=None, price=None, value=None):
        """Record a trade that the caller has already applied to cash/positions.

        The record also captures the day's Fear & Greed value (None if *date*
        is not in ``self.data``), cash and total assets after the trade, and
        the current state.
        """
        fg_index = self.data.loc[date, 'fear_greed_index'] if date in self.data.index else None
        total_assets = self.calculate_portfolio_value(date)

        self.trades.append({
            'date': date,
            'action': action,
            'ticker': ticker,
            'shares': shares,
            'price': price,
            'value': value,
            'fg_index': fg_index,
            'cnn_fear_greed': fg_index,  # Same as fg_index but with clearer name
            'cash_after': self.cash,
            'total_assets': total_assets,
            'portfolio_state': self.state,
        })

    def calculate_portfolio_value(self, date):
        """Return cash plus the market value of all positions on *date*.

        QQQ is priced from ``self.data``; other tickers via ``get_stock_price``.
        Positions without a (truthy) price contribute nothing.
        """
        total_value = self.cash
        for ticker, shares in self.positions.items():
            if ticker == 'QQQ':
                price = self.data.loc[date, 'qqq_close']
            else:
                price = self.get_stock_price(ticker, date)
            if price:
                total_value += shares * price
        return total_value

    def _average_cost(self, ticker):
        """Return the average purchase price of the currently open *ticker* position.

        Replays this ticker's trades in order:

        * ``BUY_GRADUAL`` adds to the cost basis.
        * ``SELL_GRADUAL`` and ``STOP_LOSS`` remove shares at the running
          average cost.
        * The basis resets once the position is fully closed.

        This keeps purchases from earlier, already-exited cycles from
        distorting the stop-loss reference price. Returns None when there is
        no open cost basis.
        """
        shares_held = 0
        cost = 0.0
        for trade in self.trades:
            if trade['ticker'] != ticker:
                continue
            if trade['action'] == 'BUY_GRADUAL':
                shares_held += trade['shares']
                cost += trade['price'] * trade['shares']
            elif trade['action'] in ('SELL_GRADUAL', 'STOP_LOSS') and shares_held > 0:
                sold = min(trade['shares'], shares_held)
                cost -= cost * sold / shares_held
                shares_held -= sold
                if shares_held <= 0:
                    shares_held, cost = 0, 0.0
        if shares_held <= 0:
            return None
        return cost / shares_held

    def check_stop_loss(self, date):
        """Stop out losing non-QQQ positions and move the proceeds into QQQ.

        A position is sold in full when its price on *date* is at least
        ``stop_loss_threshold`` (a fraction) below its average cost. The
        proceeds buy whole QQQ shares at that day's QQQ close.
        """
        for ticker in list(self.positions):
            if ticker == 'QQQ':
                continue

            current_price = self.get_stock_price(ticker, date)
            if not current_price:
                continue

            avg_price = self._average_cost(ticker)
            if not avg_price:
                continue

            loss_pct = (current_price - avg_price) / avg_price
            if loss_pct > -self.stop_loss_threshold:
                continue

            # Sell the whole position...
            shares = self.positions.pop(ticker)
            value = shares * current_price
            self.cash += value
            self.execute_trade(date, 'STOP_LOSS', ticker, shares, current_price, value)

            # ...and buy QQQ with integer shares
            qqq_price = self.data.loc[date, 'qqq_close']
            qqq_shares = int(value / qqq_price)
            if qqq_shares > 0:
                actual_qqq_value = qqq_shares * qqq_price
                self.positions['QQQ'] = self.positions.get('QQQ', 0) + qqq_shares
                self.cash -= actual_qqq_value
                self.execute_trade(date, 'BUY_QQQ_STOPLOSS', 'QQQ', qqq_shares, qqq_price, actual_qqq_value)

            if self.config['output_settings']['show_console_output']:
                print(f"{date.strftime('%Y-%m-%d')}: Stop loss triggered for {ticker}, loss: {loss_pct*100:.1f}%")

    def _non_qqq_position_value(self, date):
        """Market value on *date* of every non-QQQ position (unpriced ones count as 0)."""
        total = 0
        for ticker, shares in self.positions.items():
            if ticker == 'QQQ':
                continue
            price = self.get_stock_price(ticker, date)
            if price:
                total += shares * price
        return total

    def start_transition(self, date, target_type, stocks=None):
        """Initialize ``transition_plan`` for a gradual move to *target_type*.

        * 'CASH': plan to sell every current position (QQQ included).
        * 'QQQ': plan to sell every non-QQQ position and invest current cash
          plus their value into QQQ.
        * 'VOLATILE' (with *stocks*): plan to invest current cash plus the
          value of non-QQQ positions into *stocks*.
        """
        self.transition_plan = {'type': target_type, 'stocks': stocks}

        if target_type == 'CASH':
            self.transition_plan['positions_to_sell'] = dict(self.positions)

        elif target_type == 'QQQ':
            self.transition_cash_pool = self.cash + self._non_qqq_position_value(date)
            self.transition_plan['total_cash_to_invest'] = self.transition_cash_pool
            self.transition_plan['positions_to_sell'] = {
                ticker: shares for ticker, shares in self.positions.items() if ticker != 'QQQ'
            }

        elif target_type == 'VOLATILE' and stocks:
            self.transition_plan['total_cash_to_invest'] = self.cash + self._non_qqq_position_value(date)

    def _sell_planned_positions(self, date, step_size, is_final_step):
        """Sell one tranche of every position listed in the transition plan.

        Regular steps sell ``int(planned_shares * step_size)`` shares. The
        final step sells whatever is still held, so integer rounding (or a
        holding smaller than ``transition_steps``) cannot leave residual
        shares behind.
        """
        planned = self.transition_plan.get('positions_to_sell', {})
        for ticker, planned_shares in list(planned.items()):
            if ticker not in self.positions:
                continue
            held = self.positions[ticker]
            shares_to_sell = held if is_final_step else int(planned_shares * step_size)
            if shares_to_sell <= 0 or shares_to_sell > held:
                continue

            price = self.get_stock_price(ticker, date)
            if not price:
                continue

            value = shares_to_sell * price
            self.cash += value
            self.positions[ticker] -= shares_to_sell
            if self.positions[ticker] <= 0:
                del self.positions[ticker]
            self.execute_trade(date, 'SELL_GRADUAL', ticker, shares_to_sell, price, value)

    def _step_budget(self, step_size, is_final_step):
        """Return the cash to deploy in this buy step (0 means skip the step).

        The final step deploys all remaining cash. Regular steps deploy
        ``total_cash_to_invest * step_size``, and only when that much cash is
        available.
        """
        if is_final_step:
            return max(self.cash, 0)
        cash_this_step = self.transition_plan.get('total_cash_to_invest', 0) * step_size
        if cash_this_step > 0 and self.cash >= cash_this_step:
            return cash_this_step
        return 0

    def _buy(self, date, ticker, price, budget):
        """Buy as many whole shares of *ticker* at *price* as *budget* allows."""
        shares = int(budget / price)
        if shares > 0:
            actual_value = shares * price
            self.positions[ticker] = self.positions.get(ticker, 0) + shares
            self.cash -= actual_value
            self.execute_trade(date, 'BUY_GRADUAL', ticker, shares, price, actual_value)

    def gradual_transition(self, date, target_type, stocks=None):
        """Execute one day's tranche of the transition set up by ``start_transition``.

        ``run_backtest`` increments ``current_step`` after each call, so a
        transition runs steps ``0 .. transition_steps - 1``. Regular steps
        move ``1 / transition_steps`` of the planned amount. The final step
        sells every remaining planned position and invests all remaining
        cash, so rounding or a skipped tranche cannot leave stray shares or
        idle cash.

        Args:
            date: Trading day of this step.
            target_type: 'CASH', 'VOLATILE' or 'QQQ'.
            stocks: Tickers to buy for 'VOLATILE'. Step ``i`` buys
                ``stocks[min(i, len(stocks) - 1)]``.
        """
        step_size = 1.0 / self.transition_steps
        is_final_step = self.current_step >= self.transition_steps - 1

        if target_type == 'CASH':
            self._sell_planned_positions(date, step_size, is_final_step)

        elif target_type == 'VOLATILE' and stocks:
            budget = self._step_budget(step_size, is_final_step)
            if budget > 0:
                ticker = stocks[min(self.current_step, len(stocks) - 1)]
                price = self.get_stock_price(ticker, date)
                if price:
                    self._buy(date, ticker, price, budget)

        elif target_type == 'QQQ':
            # Sell positions gradually, then buy QQQ with this step's budget
            self._sell_planned_positions(date, step_size, is_final_step)
            budget = self._step_budget(step_size, is_final_step)
            if budget > 0:
                self._buy(date, 'QQQ', self.data.loc[date, 'qqq_close'], budget)

    def run_backtest(self):
        """Run the strategy backtest over ``self.data``.

        Loads data, then starts fully in QQQ (whole shares, remainder kept as
        cash). For each day it:

        1. checks stop losses,
        2. advances the state machine described on the class,
        3. appends ``{'date', 'value', 'state', 'fg_index'}`` to
           ``self.portfolio_value``.

        Raises:
            ValueError: If no data remains after loading and date filtering.
        """
        show = self.config['output_settings']['show_console_output']
        if show:
            print("🚀 开始运行Enhanced Emotional Damage Strategy...")
            print("")

        self.get_data()
        if self.data.empty:
            raise ValueError("No Fear & Greed / QQQ data available for the configured date range")

        # Start with 100% QQQ
        first_date = self.data.index[0]
        qqq_price = self.data.loc[first_date, 'qqq_close']
        qqq_shares = int(self.initial_capital / qqq_price)
        self.positions['QQQ'] = qqq_shares
        self.cash = self.initial_capital - (qqq_shares * qqq_price)

        fear_start_date = None

        for date, row in self.data.iterrows():
            fg_index = row['fear_greed_index']

            self.check_stop_loss(date)

            if self.state == 'QQQ_HOLD':
                if fg_index < self.fear_threshold:
                    fear_start_date = date
                    self.state = 'FEAR_TRANSITION'
                    self.current_step = 0
                    self.start_transition(date, 'CASH')
                    if show:
                        print(f"{date.strftime('%Y-%m-%d')}: Fear threshold hit ({fg_index:.1f}), starting transition to cash")

            elif self.state == 'FEAR_TRANSITION':
                self.gradual_transition(date, 'CASH')
                self.current_step += 1
                if self.current_step >= self.transition_steps:
                    self.state = 'CASH_WAIT'
                    if show:
                        print(f"{date.strftime('%Y-%m-%d')}: Transition to cash complete")

            elif self.state == 'CASH_WAIT':
                if fg_index >= self.fear_threshold and fear_start_date:
                    top_stocks = self.select_volatile_stocks(fear_start_date, date)
                    self.current_step = 0
                    if top_stocks:
                        self.state = 'GREED_TRANSITION'
                        self.transition_stocks = top_stocks
                        self.start_transition(date, 'VOLATILE', top_stocks)
                        if show:
                            print(f"{date.strftime('%Y-%m-%d')}: Fear recovered, starting transition to volatile stocks: {top_stocks}")
                    else:
                        self.state = 'QQQ_TRANSITION'
                        self.start_transition(date, 'QQQ')
                        if show:
                            print(f"{date.strftime('%Y-%m-%d')}: Fear recovered, no suitable stocks, returning to QQQ")

            elif self.state == 'GREED_TRANSITION':
                self.gradual_transition(date, 'VOLATILE', self.transition_stocks)
                self.current_step += 1
                if self.current_step >= self.transition_steps:
                    self.state = 'VOLATILE_STOCKS'
                    if show:
                        print(f"{date.strftime('%Y-%m-%d')}: Transition to volatile stocks complete")

            elif self.state == 'VOLATILE_STOCKS':
                if fg_index > self.greed_threshold:
                    self.state = 'QQQ_TRANSITION'
                    self.current_step = 0
                    self.start_transition(date, 'QQQ')
                    if show:
                        print(f"{date.strftime('%Y-%m-%d')}: Greed threshold hit ({fg_index:.1f}), starting transition to QQQ")

            elif self.state == 'QQQ_TRANSITION':
                self.gradual_transition(date, 'QQQ')
                self.current_step += 1
                if self.current_step >= self.transition_steps:
                    self.state = 'QQQ_HOLD'
                    if show:
                        print(f"{date.strftime('%Y-%m-%d')}: Transition to QQQ complete")

            # Record portfolio value
            self.portfolio_value.append({
                'date': date,
                'value': self.calculate_portfolio_value(date),
                'state': self.state,
                'fg_index': fg_index,
            })

        if show:
            print("")
            print(f"✅ 回测完成! 总交易数: {len(self.trades)}")
            print("")

    def calculate_performance_metrics(self, returns):
        """Compute summary statistics for an equity curve.

        Args:
            returns: Date-indexed series of portfolio *values* (despite the
                name), one per trading day.

        Returns:
            A dict with these keys:

            * ``total_return`` (percent).
            * ``annual_return`` (percent, annualized assuming 252 trading
              days).
            * ``max_drawdown`` (percent, <= 0).
            * ``sharpe_ratio`` (annualized, zero risk-free rate).
            * ``annual_returns``: ``{year: percent change from that year's
              first to last observation}``, only for years with more than
              one point.
        """
        growth = returns.iloc[-1] / returns.iloc[0]
        total_return = (growth - 1) * 100
        annual_return = (growth ** (252 / len(returns)) - 1) * 100

        # Max drawdown
        peak = returns.expanding().max()
        max_drawdown = ((returns - peak) / peak).min() * 100

        # Sharpe ratio
        daily_returns = returns.pct_change().dropna()
        sharpe_ratio = np.sqrt(252) * daily_returns.mean() / daily_returns.std()

        # Annual returns by year
        annual_rets = {}
        for year in returns.index.year.unique():
            year_data = returns[returns.index.year == year]
            if len(year_data) > 1:
                annual_rets[year] = (year_data.iloc[-1] / year_data.iloc[0] - 1) * 100

        return {
            'total_return': total_return,
            'annual_return': annual_return,
            'max_drawdown': max_drawdown,
            'sharpe_ratio': sharpe_ratio,
            'annual_returns': annual_rets,
        }
def _format_optional(value, template):
    """Format *value* with *template*, or return "N/A" when it is missing (None/NaN)."""
    if value is None or pd.isna(value):
        return "N/A"
    return template.format(value)


def _write_detailed_trades(f, trades, trades_df):
    """Write the fixed-width detailed trade log for *trades* to the open text file *f*.

    *trades_df* must be ``pd.DataFrame(trades)``. It is only used for the
    summary section, which is skipped when there are no trades.
    """
    f.write("Enhanced Emotional Damage Strategy - Detailed Trades Report\n")
    f.write("=" * 120 + "\n\n")

    # Summary
    f.write("📊 交易摘要:\n")
    f.write(f"总交易数: {len(trades)}\n")
    if trades:
        f.write(f"交易时间: {trades_df['date'].min().strftime('%Y-%m-%d')} 到 {trades_df['date'].max().strftime('%Y-%m-%d')}\n")

    # Trade types
    f.write("\n交易类型统计:\n")
    if trades:
        for action, count in trades_df['action'].value_counts().items():
            f.write(f" {action}: {count}\n")

    # Detailed trades
    f.write("\n📋 详细交易记录:\n")
    f.write("=" * 150 + "\n")
    f.write(f"{'No':>3s} {'Date':>10s} {'Action':>15s} {'Ticker':>5s} {'Shares':>8s} {'Price':>7s} {'Value':>12s} {'F&G':>4s} {'Cash':>12s} {'Assets':>12s} {'State':>15s}\n")
    f.write("=" * 150 + "\n")

    for i, trade in enumerate(trades, 1):
        # Explicit None/NaN checks: a legitimate 0 (e.g. $0 cash) must not show as N/A.
        fg_str = _format_optional(trade.get('fg_index'), "{:.0f}")
        cash_str = _format_optional(trade.get('cash_after'), "${:,.0f}")
        assets_str = _format_optional(trade.get('total_assets'), "${:,.0f}")
        state_str = trade.get('portfolio_state', 'N/A')
        f.write(f"{i:3d} {trade['date'].strftime('%Y-%m-%d'):>10s} {trade['action']:>15s} {trade['ticker']:>5s} "
                f"{trade['shares']:>8.0f} ${trade['price']:>7.2f} ${trade['value']:>12,.0f} {fg_str:>4s} {cash_str:>12s} {assets_str:>12s} {state_str:>15s}\n")


def generate_reports(strategy):
    """Generate all reports enabled in ``strategy.config['output_settings']``.

    Depending on the flags, this writes:

    * a pickled strategy object (results dir),
    * a trades CSV and a detailed text trade log (reports dir),
    * a PDF report, via ``generate_enhanced_pdf_report``. PDF failures are
      reported, not raised.

    All files from one call share a timestamp suffix. Output directories are
    created if needed.

    Args:
        strategy: A strategy whose backtest has already run. Its ``config``
            and ``trades`` are read.
    """
    config = strategy.config
    output = config['output_settings']

    # Create output directories
    reports_dir = config['paths']['reports_dir']
    results_dir = config['paths']['results_dir']
    os.makedirs(reports_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Save strategy object if requested.
    # NOTE: pickle can execute code on load; only unpickle files you produced yourself.
    if output['save_strategy_object']:
        strategy_file = os.path.join(results_dir, f'enhanced_strategy_{timestamp}.pkl')
        with open(strategy_file, 'wb') as f:
            pickle.dump(strategy, f)
        print(f"📦 策略对象已保存: {strategy_file}")

    # CSV / detailed trade reports
    if output['generate_csv'] or output['generate_detailed_trades']:
        trades_df = pd.DataFrame(strategy.trades)

        if output['generate_csv']:
            csv_file = os.path.join(reports_dir, f'enhanced_trades_{timestamp}.csv')
            trades_df.to_csv(csv_file, index=False)
            print(f"📊 交易CSV已保存: {csv_file}")

        if output['generate_detailed_trades']:
            detailed_file = os.path.join(reports_dir, f'detailed_trades_{timestamp}.txt')
            with open(detailed_file, 'w', encoding='utf-8') as f:
                _write_detailed_trades(f, strategy.trades, trades_df)
            print(f"📝 详细交易报告已保存: {detailed_file}")

    # PDF report (best effort: report the failure and continue)
    if output['generate_pdf']:
        try:
            generate_enhanced_pdf_report(strategy, reports_dir, timestamp)
        except Exception as e:
            print(f"⚠️ PDF生成失败: {e}")
            import traceback
            traceback.print_exc()

    print("\n🎉 所有报告生成完成!")
def generate_enhanced_pdf_report(strategy, reports_dir, timestamp):
|
|
"""Generate comprehensive PDF report with enhanced layout and proper spacing"""
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.dates as mdates
|
|
from matplotlib.backends.backend_pdf import PdfPages
|
|
import seaborn as sns
|
|
import sqlite3
|
|
import os
|
|
|
|
# Prepare data
|
|
portfolio_df = pd.DataFrame(strategy.portfolio_value)
|
|
portfolio_df.set_index('date', inplace=True)
|
|
trades_df = pd.DataFrame(strategy.trades)
|
|
|
|
# Get benchmark data
|
|
db_path = strategy.config['paths']['database_path']
|
|
conn = sqlite3.connect(db_path)
|
|
|
|
qqq_data = pd.read_sql_query('''
|
|
SELECT date, close as qqq_close
|
|
FROM qqq
|
|
ORDER BY date
|
|
''', conn)
|
|
qqq_data['date'] = pd.to_datetime(qqq_data['date'])
|
|
qqq_data.set_index('date', inplace=True)
|
|
|
|
spy_data = pd.read_sql_query('''
|
|
SELECT date, spy_close
|
|
FROM fear_greed_data
|
|
ORDER BY date
|
|
''', conn)
|
|
spy_data['date'] = pd.to_datetime(spy_data['date'])
|
|
spy_data.set_index('date', inplace=True)
|
|
|
|
conn.close()
|
|
|
|
# Merge and align data
|
|
benchmark_data = pd.merge(qqq_data, spy_data, left_index=True, right_index=True, how='inner')
|
|
common_dates = portfolio_df.index.intersection(benchmark_data.index)
|
|
portfolio_df = portfolio_df.loc[common_dates]
|
|
benchmark_data = benchmark_data.loc[common_dates]
|
|
|
|
# Normalize benchmarks
|
|
start_value = strategy.initial_capital
|
|
benchmark_data['qqq_value'] = start_value * (benchmark_data['qqq_close'] / benchmark_data['qqq_close'].iloc[0])
|
|
benchmark_data['spy_value'] = start_value * (benchmark_data['spy_close'] / benchmark_data['spy_close'].iloc[0])
|
|
|
|
# Calculate metrics
|
|
strategy_metrics = strategy.calculate_performance_metrics(portfolio_df['value'])
|
|
qqq_metrics = strategy.calculate_performance_metrics(benchmark_data['qqq_value'])
|
|
spy_metrics = strategy.calculate_performance_metrics(benchmark_data['spy_value'])
|
|
|
|
# Find max drawdown year
|
|
def find_max_drawdown_year(returns):
|
|
peak = returns.expanding().max()
|
|
drawdown = (returns - peak) / peak
|
|
max_dd_date = drawdown.idxmin()
|
|
return max_dd_date.year
|
|
|
|
strategy_dd_year = find_max_drawdown_year(portfolio_df['value'])
|
|
qqq_dd_year = find_max_drawdown_year(benchmark_data['qqq_value'])
|
|
spy_dd_year = find_max_drawdown_year(benchmark_data['spy_value'])
|
|
|
|
# Create PDF with multiple pages
|
|
pdf_file = os.path.join(reports_dir, f'enhanced_strategy_report_{timestamp}.pdf')
|
|
|
|
with PdfPages(pdf_file) as pdf:
|
|
# Set global font to support text and better spacing
|
|
plt.rcParams['font.size'] = 10
|
|
plt.rcParams['axes.titlesize'] = 12
|
|
plt.rcParams['axes.labelsize'] = 10
|
|
plt.rcParams['xtick.labelsize'] = 8
|
|
plt.rcParams['ytick.labelsize'] = 8
|
|
plt.rcParams['legend.fontsize'] = 8
|
|
plt.rcParams['figure.titlesize'] = 14
|
|
|
|
# Page 1: Performance Comparison (Full Width)
|
|
fig1 = plt.figure(figsize=(8.5, 11))
|
|
fig1.suptitle('Enhanced Emotional Damage Strategy Report', fontsize=16, fontweight='bold', y=0.96)
|
|
|
|
# 1. Total Return Curve (Full width)
|
|
ax1 = plt.subplot(4, 1, 1)
|
|
ax1.plot(portfolio_df.index, portfolio_df['value'] / 1000,
|
|
label='Enhanced Strategy', linewidth=2, color='red')
|
|
ax1.plot(benchmark_data.index, benchmark_data['qqq_value'] / 1000,
|
|
label='QQQ', linewidth=2, color='blue')
|
|
ax1.plot(benchmark_data.index, benchmark_data['spy_value'] / 1000,
|
|
label='SPY', linewidth=2, color='green')
|
|
ax1.set_title('Portfolio Performance Comparison', fontsize=14, fontweight='bold', pad=25)
|
|
ax1.set_ylabel('Portfolio Value ($K)', fontsize=11)
|
|
ax1.legend(fontsize=10, loc='upper left')
|
|
ax1.grid(True, alpha=0.3)
|
|
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
|
|
for label in ax1.get_xticklabels():
|
|
label.set_rotation(45)
|
|
|
|
# 2. Performance Metrics Table (Full width)
|
|
ax2 = plt.subplot(4, 1, 2)
|
|
ax2.axis('off')
|
|
|
|
metrics_data = [
|
|
['Metric', 'Enhanced Strategy', 'QQQ', 'SPY'],
|
|
['Total Return', f"{strategy_metrics['total_return']:.1f}%",
|
|
f"{qqq_metrics['total_return']:.1f}%", f"{spy_metrics['total_return']:.1f}%"],
|
|
['Annual Return', f"{strategy_metrics['annual_return']:.1f}%",
|
|
f"{qqq_metrics['annual_return']:.1f}%", f"{spy_metrics['annual_return']:.1f}%"],
|
|
['Max Drawdown', f"{strategy_metrics['max_drawdown']:.1f}%",
|
|
f"{qqq_metrics['max_drawdown']:.1f}%", f"{spy_metrics['max_drawdown']:.1f}%"],
|
|
['Max DD Year', str(strategy_dd_year), str(qqq_dd_year), str(spy_dd_year)],
|
|
['Sharpe Ratio', f"{strategy_metrics['sharpe_ratio']:.2f}",
|
|
f"{qqq_metrics['sharpe_ratio']:.2f}", f"{spy_metrics['sharpe_ratio']:.2f}"],
|
|
['Total Trades', f"{len(strategy.trades)}", 'N/A', 'N/A']
|
|
]
|
|
|
|
table = ax2.table(cellText=metrics_data, cellLoc='center', loc='center')
|
|
table.auto_set_font_size(False)
|
|
table.set_fontsize(10)
|
|
table.scale(1.2, 2.0) # More height for readability
|
|
|
|
for i in range(len(metrics_data[0])):
|
|
table[(0, i)].set_facecolor('#40466e')
|
|
table[(0, i)].set_text_props(weight='bold', color='white')
|
|
|
|
ax2.set_title('Performance Metrics Comparison', fontsize=14, fontweight='bold', pad=25)
|
|
|
|
# 3. Strategy State Timeline (Full width)
|
|
ax3 = plt.subplot(4, 1, 3)
|
|
|
|
# Create state mapping and colors
|
|
state_colors = {
|
|
'QQQ_HOLD': 'blue',
|
|
'FEAR_TRANSITION': 'orange',
|
|
'CASH_WAIT': 'gray',
|
|
'GREED_TRANSITION': 'yellow',
|
|
'VOLATILE_STOCKS': 'red',
|
|
'QQQ_TRANSITION': 'green'
|
|
}
|
|
|
|
# Plot portfolio value with state colors
|
|
for state, color in state_colors.items():
|
|
state_data = portfolio_df[portfolio_df['state'] == state]
|
|
if not state_data.empty:
|
|
ax3.scatter(state_data.index, state_data['value'] / 1000,
|
|
c=color, s=2, alpha=0.8, label=state)
|
|
|
|
# Add stop-loss markers
|
|
stop_loss_trades = trades_df[trades_df['action'] == 'STOP_LOSS']
|
|
if not stop_loss_trades.empty:
|
|
for _, trade in stop_loss_trades.iterrows():
|
|
ax3.axvline(x=trade['date'], color='red', linestyle='--', alpha=0.8, linewidth=1)
|
|
|
|
ax3.set_title('Strategy State Timeline with Stop-Loss Events', fontsize=14, fontweight='bold', pad=25)
|
|
ax3.set_ylabel('Total Assets ($K)', fontsize=11)
|
|
ax3.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=8)
|
|
ax3.grid(True, alpha=0.3)
|
|
ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
|
|
for label in ax3.get_xticklabels():
|
|
label.set_rotation(45)
|
|
|
|
# 4. Annual Returns Comparison (Full width)
|
|
ax4 = plt.subplot(4, 1, 4)
|
|
years = list(strategy_metrics['annual_returns'].keys())
|
|
enhanced_returns = list(strategy_metrics['annual_returns'].values())
|
|
qqq_returns = [qqq_metrics['annual_returns'].get(year, 0) for year in years]
|
|
spy_returns = [spy_metrics['annual_returns'].get(year, 0) for year in years]
|
|
|
|
x = np.arange(len(years))
|
|
width = 0.25
|
|
|
|
ax4.bar(x - width, enhanced_returns, width, label='Enhanced Strategy', color='red', alpha=0.8)
|
|
ax4.bar(x, qqq_returns, width, label='QQQ', color='blue', alpha=0.8)
|
|
ax4.bar(x + width, spy_returns, width, label='SPY', color='green', alpha=0.8)
|
|
|
|
ax4.set_title('Annual Returns Comparison by Year', fontsize=14, fontweight='bold', pad=25)
|
|
ax4.set_ylabel('Annual Return (%)', fontsize=11)
|
|
ax4.set_xlabel('Year', fontsize=11)
|
|
ax4.set_xticks(x)
|
|
ax4.set_xticklabels(years)
|
|
for label in ax4.get_xticklabels():
|
|
label.set_rotation(90) # Vertical text for years
|
|
ax4.legend(fontsize=10)
|
|
ax4.grid(True, alpha=0.3, axis='y')
|
|
|
|
plt.subplots_adjust(left=0.1, right=0.85, top=0.90, bottom=0.08, hspace=0.6)
|
|
pdf.savefig(fig1, bbox_inches='tight', dpi=150)
|
|
plt.close()
|
|
|
|
# Page 2: Detailed Analysis
|
|
fig2 = plt.figure(figsize=(8.5, 11))
|
|
fig2.suptitle('Detailed Trading and Risk Analysis', fontsize=16, fontweight='bold', y=0.95)
|
|
|
|
# 5. Stop-Loss Analysis (Full width)
|
|
ax5 = plt.subplot(3, 1, 1)
|
|
|
|
if not stop_loss_trades.empty:
|
|
stop_loss_trades_copy = stop_loss_trades.copy()
|
|
stop_loss_trades_copy['year'] = stop_loss_trades_copy['date'].dt.year
|
|
stop_loss_by_year = stop_loss_trades_copy.groupby('year').size()
|
|
|
|
ax5.bar(stop_loss_by_year.index, stop_loss_by_year.values, color='red', alpha=0.8, width=0.6)
|
|
ax5.set_title('Stop-Loss Triggers by Year', fontsize=14, fontweight='bold', pad=30)
|
|
ax5.set_ylabel('Number of Stop-Loss Events', fontsize=11)
|
|
ax5.set_xlabel('Year', fontsize=11)
|
|
ax5.grid(True, alpha=0.3, axis='y')
|
|
for label in ax5.get_xticklabels():
|
|
label.set_rotation(45)
|
|
else:
|
|
ax5.text(0.5, 0.5, 'No Stop-Loss Events Triggered', ha='center', va='center',
|
|
transform=ax5.transAxes, fontsize=14, fontweight='bold')
|
|
ax5.set_title('Stop-Loss Analysis', fontsize=14, fontweight='bold', pad=30)
|
|
|
|
# 6. Trade Frequency Analysis (Full width)
|
|
ax6 = plt.subplot(3, 1, 2)
|
|
|
|
trades_df_copy = trades_df.copy()
|
|
trades_df_copy['year'] = trades_df_copy['date'].dt.year
|
|
trade_frequency = trades_df_copy.groupby('year').size()
|
|
|
|
ax6.bar(trade_frequency.index, trade_frequency.values, color='purple', alpha=0.8, width=0.6)
|
|
ax6.set_title('Trading Activity by Year', fontsize=14, fontweight='bold', pad=30)
|
|
ax6.set_ylabel('Number of Trades', fontsize=11)
|
|
ax6.set_xlabel('Year', fontsize=11)
|
|
ax6.grid(True, alpha=0.3, axis='y')
|
|
for label in ax6.get_xticklabels():
|
|
label.set_rotation(45)
|
|
|
|
# 7. Fear & Greed Index with Trading Signals (Full width)
|
|
ax7 = plt.subplot(3, 1, 3)
|
|
|
|
# Plot Fear & Greed Index
|
|
fg_data = portfolio_df['fg_index'].dropna()
|
|
ax7.plot(fg_data.index, fg_data.values, color='purple', alpha=0.8, linewidth=1.5)
|
|
ax7.axhline(y=25, color='red', linestyle='--', alpha=0.7, linewidth=2, label='Fear Threshold (25)')
|
|
ax7.axhline(y=75, color='green', linestyle='--', alpha=0.7, linewidth=2, label='Greed Threshold (75)')
|
|
ax7.fill_between(fg_data.index, 0, 25, alpha=0.2, color='red', label='Fear Zone')
|
|
ax7.fill_between(fg_data.index, 75, 100, alpha=0.2, color='green', label='Greed Zone')
|
|
|
|
# Add trade markers
|
|
buy_trades = trades_df[trades_df['action'].str.contains('BUY')]
|
|
sell_trades = trades_df[trades_df['action'].str.contains('SELL')]
|
|
|
|
if not buy_trades.empty:
|
|
ax7.scatter(buy_trades['date'], buy_trades['fg_index'],
|
|
color='darkgreen', s=15, alpha=0.8, marker='^', label='Buy Signals', zorder=5)
|
|
if not sell_trades.empty:
|
|
ax7.scatter(sell_trades['date'], sell_trades['fg_index'],
|
|
color='darkred', s=15, alpha=0.8, marker='v', label='Sell Signals', zorder=5)
|
|
|
|
ax7.set_title('Fear & Greed Index with Trading Signals', fontsize=14, fontweight='bold', pad=30)
|
|
ax7.set_ylabel('CNN Fear & Greed Index', fontsize=11)
|
|
ax7.set_xlabel('Date', fontsize=11)
|
|
ax7.set_ylim(0, 100)
|
|
ax7.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=8)
|
|
ax7.grid(True, alpha=0.3)
|
|
ax7.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
|
|
for label in ax7.get_xticklabels():
|
|
label.set_rotation(45)
|
|
|
|
plt.subplots_adjust(left=0.1, right=0.85, top=0.88, bottom=0.10, hspace=1.0)
|
|
pdf.savefig(fig2, bbox_inches='tight', dpi=150)
|
|
plt.close()
|
|
|
|
# Test PDF readability by checking file size and trying to read it
|
|
try:
|
|
file_size = os.path.getsize(pdf_file)
|
|
if file_size < 50000: # Less than 50KB might indicate issues
|
|
print(f"⚠️ Warning: PDF file size seems small ({file_size} bytes)")
|
|
else:
|
|
print(f"📈 PDF报告已保存: {pdf_file} (Size: {file_size:,} bytes)")
|
|
|
|
# Quick validation - ensure we can open the file
|
|
with open(pdf_file, 'rb') as f:
|
|
header = f.read(10)
|
|
if not header.startswith(b'%PDF'):
|
|
print(f"⚠️ Warning: Generated file may not be a valid PDF")
|
|
else:
|
|
print(f"✅ PDF file validation passed")
|
|
|
|
except Exception as e:
|
|
print(f"⚠️ Error validating PDF: {e}")
|
|
|
|
print(f"📈 PDF报告已保存: {pdf_file}")
|
|
|
|
|
|
def main():
    """Script entry point: backtest using ``config.json`` and write the reports.

    If the config file is missing, a message is printed and the function
    returns. Any exception raised during the run is printed together with
    its traceback instead of propagating.
    """
    config_file = 'config.json'

    if os.path.exists(config_file):
        try:
            strategy = ConfigurableEmotionalDamageStrategy(config_file)
            strategy.run_backtest()
            generate_reports(strategy)
        except Exception as e:
            print(f"❌ 运行失败: {e}")
            import traceback
            traceback.print_exc()
        return

    # Guard clause inverted: the missing-config path reports and exits.
    print(f"❌ 配置文件未找到: {config_file}")
    print("请确保config.json文件存在于当前目录")
if __name__ == '__main__':
    # Run only when executed as a script, never on import.
    main()