Files
docker-configs/backtest/strategy/emotional-damage/run_strategy_with_config.py
2025-07-19 00:00:01 -05:00

960 lines
42 KiB
Python

#!/usr/bin/env python3
import json
import sqlite3
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import warnings
import os
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
# Suppress all Python warnings process-wide. The warnings filter list is global,
# so this also hides warnings raised by pandas/numpy/matplotlib calls made later
# in this script.
# NOTE(review): consider narrowing this to specific categories/modules so real
# problems (e.g. deprecations) are not silently hidden.
warnings.filterwarnings('ignore')
class ConfigurableEmotionalDamageStrategy:
    """Fear & Greed driven rotation between QQQ, cash and volatile stocks.

    The backtest is a state machine driven by the CNN Fear & Greed index:

    * ``QQQ_HOLD`` -> ``FEAR_TRANSITION`` when the index drops below
      ``fear_threshold``; every position is then sold down to cash over
      ``transition_steps`` bars, ending in ``CASH_WAIT``.
    * ``CASH_WAIT`` -> ``GREED_TRANSITION`` when the index is back at or above
      ``fear_threshold`` and qualifying volatile stocks exist (otherwise
      straight to ``QQQ_TRANSITION``); the stocks are bought over
      ``transition_steps`` bars, ending in ``VOLATILE_STOCKS``.
    * ``VOLATILE_STOCKS`` -> ``QQQ_TRANSITION`` when the index exceeds
      ``greed_threshold``; stocks are rotated back into QQQ, ending in
      ``QQQ_HOLD``.

    Stock positions are also stopped out (into QQQ) when they fall
    ``stop_loss_threshold`` below their average buy price. All tunables come
    from a JSON config; prices come from the SQLite database at
    ``config['paths']['database_path']``.
    """

    def __init__(self, config_path='config.json'):
        """Load the config at *config_path* and reset all backtest state."""
        self.load_config(config_path)
        self.cash = self.config['strategy_params']['initial_capital']
        self.positions = {}        # ticker -> whole-share count
        self.portfolio_value = []  # one dict per bar, appended by run_backtest
        self.trades = []           # one dict per fill, appended by execute_trade
        # State machine
        self.state = 'QQQ_HOLD'
        self.current_step = 0
        self.target_allocation = {}
        self.last_fear_date = None
        # Gradual-transition bookkeeping
        self.transition_plan = {}
        self.transition_cash_pool = 0
        self.transition_stocks = []
        # ticker -> [total_cost, total_shares] of the buys made in the *current*
        # holding period; it is the stop-loss reference and is dropped when the
        # position is closed, so earlier round-trips do not distort it.
        self.cost_basis = {}

    def load_config(self, config_path):
        """Read the JSON config at *config_path* and expose its tunables as attributes.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if it is not valid JSON.
            KeyError: if a required section or key is missing.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        params = self.config['strategy_params']
        self.initial_capital = params['initial_capital']
        self.fear_threshold = params['fear_threshold']
        self.greed_threshold = params['greed_threshold']
        self.stop_loss_threshold = params['stop_loss_threshold']
        self.top_stocks_count = params['top_stocks_count']
        self.volatility_threshold = params['volatility_threshold']
        self.volatility_calculation_days = params['volatility_calculation_days']
        self.transition_steps = params['transition_steps']

        tech = self.config['technical_indicators']
        self.rsi_threshold = tech['rsi_threshold']
        self.required_indicators = tech['required_indicators']
        self.sma5_above_sma20 = tech['sma5_above_sma20']
        self.macd_convergence = tech['macd_convergence']

        self._log(f"✅ 配置已加载:")
        self._log(f" 初始资金: ${self.initial_capital:,}")
        self._log(f" 恐慌阈值: {self.fear_threshold}")
        self._log(f" 贪婪阈值: {self.greed_threshold}")
        self._log(f" 止损阈值: {self.stop_loss_threshold*100}%")
        self._log(f" 转换步数: {self.transition_steps}")
        self._log(f" 选股数量: {self.top_stocks_count}")
        self._log("")

    # ------------------------------------------------------------------ helpers

    def _log(self, message):
        """Print *message* only when console output is enabled in the config."""
        if self.config['output_settings']['show_console_output']:
            print(message)

    def _connect(self):
        """Open a new connection to the configured SQLite price database."""
        return sqlite3.connect(self.config['paths']['database_path'])

    @staticmethod
    def _table_name(ticker):
        """Return *ticker*'s price-table name quoted as an SQLite identifier.

        Quoting keeps tickers such as ``BRK-B`` (or SQL keywords) from breaking
        the query; ``"aapl"`` still names the same table as bare ``aapl``.
        """
        return '"' + ticker.lower().replace('"', '""') + '"'

    def _read_closes(self, ticker, clause, params):
        """Return ``date, close`` rows of *ticker*'s table filtered by *clause*.

        *clause* is the trailing WHERE/ORDER BY/LIMIT SQL and *params* its
        placeholder values. The connection is closed even if the query fails.
        """
        conn = self._connect()
        try:
            query = f'SELECT date, close FROM {self._table_name(ticker)} {clause}'
            return pd.read_sql_query(query, conn, params=params)
        finally:
            conn.close()

    # --------------------------------------------------------------------- data

    def get_data(self):
        """Load the Fear & Greed index, QQQ closes and the tradable ticker universe.

        Sets ``self.data`` (date-indexed, columns ``fear_greed_index`` and
        ``qqq_close``, inner-joined on date, clipped to the configured
        start/end dates, sorted) and ``self.available_tickers`` (tickers in
        ``ticker_list`` with more than ``min_ticker_records`` records).
        """
        db_path = self.config['paths']['database_path']
        self._log(f"连接数据库: {db_path}")
        settings = self.config['data_settings']

        conn = sqlite3.connect(db_path)
        try:
            fg_data = pd.read_sql_query('''
                SELECT date, fear_greed_index
                FROM fear_greed_index
                ORDER BY date
            ''', conn)
            qqq_data = pd.read_sql_query('''
                SELECT date, close as qqq_close
                FROM qqq
                ORDER BY date
            ''', conn)
            cursor = conn.execute(
                'SELECT ticker FROM ticker_list WHERE records > ?',
                (settings['min_ticker_records'],),
            )
            self.available_tickers = [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()

        fg_data['date'] = pd.to_datetime(fg_data['date'])
        fg_data.set_index('date', inplace=True)
        qqq_data['date'] = pd.to_datetime(qqq_data['date'])
        qqq_data.set_index('date', inplace=True)

        data = pd.merge(fg_data, qqq_data, left_index=True, right_index=True, how='inner')
        if settings['start_date']:
            data = data[data.index >= pd.to_datetime(settings['start_date'])]
        if settings['end_date']:
            data = data[data.index <= pd.to_datetime(settings['end_date'])]
        self.data = data.sort_index()

        if not self.data.empty:
            first = self.data.index.min().strftime('%Y-%m-%d')
            last = self.data.index.max().strftime('%Y-%m-%d')
            self._log(f"数据加载完成: {first}{last}")
        self._log(f"可用股票数量: {len(self.available_tickers)}")
        self._log("")

    def get_stock_price(self, ticker, date):
        """Return *ticker*'s most recent close on or before *date*, or None if absent."""
        conn = self._connect()
        try:
            row = conn.execute(
                f'SELECT close FROM {self._table_name(ticker)} '
                'WHERE date <= ? ORDER BY date DESC LIMIT 1',
                (date.strftime('%Y-%m-%d'),),
            ).fetchone()
        finally:
            conn.close()
        return row[0] if row else None

    # ---------------------------------------------------------------- selection

    def calculate_volatility(self, ticker, current_date):
        """Annualised close-to-close volatility over the configured look-back.

        The window is ``volatility_calculation_days`` calendar days ending at
        *current_date*; the std of daily returns is scaled by sqrt(252).
        Returns 0 with 10 or fewer rows or when the data cannot be read.
        """
        try:
            start_date = current_date - timedelta(days=self.volatility_calculation_days)
            df = self._read_closes(
                ticker,
                'WHERE date >= ? AND date <= ? ORDER BY date',
                (start_date.strftime('%Y-%m-%d'), current_date.strftime('%Y-%m-%d')),
            )
            if len(df) > 10:
                return df['close'].pct_change().std() * np.sqrt(252)
        except Exception:
            # Best effort: a missing or malformed table just disqualifies the ticker.
            pass
        return 0

    def check_technical_indicators(self, ticker, date):
        """Return True when at least ``required_indicators`` RSI/MACD/SMA checks pass.

        Uses the last 50 closes on or before *date*. Fewer than 20 rows, or any
        error reading/processing the data, yields False.
        """
        try:
            df = self._read_closes(
                ticker,
                'WHERE date <= ? ORDER BY date DESC LIMIT 50',
                (date.strftime('%Y-%m-%d'),),
            )
            if len(df) < 20:
                return False
            close = df.sort_values('date').reset_index(drop=True)['close']
            return self._count_passing_indicators(close) >= self.required_indicators
        except Exception:
            # Best effort: any data problem disqualifies the ticker.
            return False

    def _count_passing_indicators(self, close):
        """Count how many of the RSI, MACD and SMA conditions hold on the last bar."""
        tech = self.config['technical_indicators']
        ema = tech['ema_periods']
        sma = tech['sma_periods']

        # RSI from simple rolling means of gains and losses.
        delta = close.diff()
        gain = delta.where(delta > 0, 0).rolling(window=tech['rsi_period']).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=tech['rsi_period']).mean()
        rsi = 100 - (100 / (1 + gain / loss))

        # MACD line and its signal line.
        macd = close.ewm(span=ema['fast']).mean() - close.ewm(span=ema['slow']).mean()
        signal = macd.ewm(span=ema['signal']).mean()

        sma_fast = close.rolling(window=sma['fast']).mean()
        sma_slow = close.rolling(window=sma['slow']).mean()

        rsi_ok = rsi.iloc[-1] > self.rsi_threshold
        if self.macd_convergence and len(macd) >= 2:
            # "Converging": the MACD/signal gap narrowed versus the previous bar.
            macd_ok = (abs(macd.iloc[-1] - signal.iloc[-1])
                       < abs(macd.iloc[-2] - signal.iloc[-2]))
        else:
            macd_ok = macd.iloc[-1] > signal.iloc[-1]  # MACD above its signal line
        # NOTE: with the SMA check disabled it counts as a *passing* indicator.
        sma_ok = sma_fast.iloc[-1] > sma_slow.iloc[-1] if self.sma5_above_sma20 else True
        return sum([rsi_ok, macd_ok, sma_ok])

    def select_volatile_stocks(self, fear_start_date, fear_end_date):
        """Pick up to ``top_stocks_count`` tickers to buy after a fear episode.

        A ticker qualifies when it passes :meth:`check_technical_indicators` on
        *fear_end_date* and its volatility exceeds ``volatility_threshold``;
        qualifiers are ranked by volatility, highest first. *fear_start_date*
        is accepted for interface compatibility but is not used.
        """
        scored = []
        for ticker in self.available_tickers:
            if not self.check_technical_indicators(ticker, fear_end_date):
                continue
            vol = self.calculate_volatility(ticker, fear_end_date)
            if vol > self.volatility_threshold:
                scored.append((ticker, vol))
        scored.sort(key=lambda item: item[1], reverse=True)
        return [ticker for ticker, _ in scored[:self.top_stocks_count]]

    # --------------------------------------------------------------- accounting

    def execute_trade(self, date, action, ticker=None, shares=None, price=None, value=None):
        """Record a fill in ``self.trades`` (cash/positions must already be updated)."""
        fg_index = self.data.loc[date, 'fear_greed_index'] if date in self.data.index else None
        total_assets = self.calculate_portfolio_value(date)
        self.trades.append({
            'date': date,
            'action': action,
            'ticker': ticker,
            'shares': shares,
            'price': price,
            'value': value,
            'fg_index': fg_index,
            'cnn_fear_greed': fg_index,  # same as fg_index, clearer name for reports
            'cash_after': self.cash,
            'total_assets': total_assets,
            'portfolio_state': self.state
        })

    def calculate_portfolio_value(self, date):
        """Return cash plus the market value of all positions at *date*.

        QQQ is priced from ``self.data``; other tickers via
        :meth:`get_stock_price`, and tickers without a price count as 0.
        """
        total_value = self.cash
        for ticker, shares in self.positions.items():
            if ticker == 'QQQ':
                price = self.data.loc[date, 'qqq_close']
            else:
                price = self.get_stock_price(ticker, date)
            if price:
                total_value += shares * price
        return total_value

    def _record_stock_buy(self, ticker, shares, price):
        """Add a fill to *ticker*'s cost basis for the current holding period."""
        basis = self.cost_basis.setdefault(ticker, [0.0, 0])
        basis[0] += shares * price
        basis[1] += shares

    def _reduce_position(self, ticker, shares):
        """Remove *shares* of *ticker*; at zero drop the position and its cost basis."""
        self.positions[ticker] -= shares
        if self.positions[ticker] <= 0:
            del self.positions[ticker]
            self.cost_basis.pop(ticker, None)

    def check_stop_loss(self, date):
        """Stop out stocks trading ``stop_loss_threshold`` or more below their average cost.

        The reference is the average BUY_GRADUAL price of the *current* holding
        period (``self.cost_basis``). The whole position is sold and the
        proceeds go into whole QQQ shares. QQQ itself is never stopped out.
        """
        for ticker in list(self.positions):
            if ticker == 'QQQ':
                continue
            basis = self.cost_basis.get(ticker)
            if not basis or basis[1] <= 0:
                continue
            current_price = self.get_stock_price(ticker, date)
            if not current_price:
                continue

            avg_price = basis[0] / basis[1]
            loss_pct = (current_price - avg_price) / avg_price
            if not loss_pct <= -self.stop_loss_threshold:
                continue

            shares = self.positions[ticker]
            value = shares * current_price
            self.cash += value
            self._reduce_position(ticker, shares)
            self.execute_trade(date, 'STOP_LOSS', ticker, shares, current_price, value)

            qqq_price = self.data.loc[date, 'qqq_close']
            qqq_shares = int(value / qqq_price)
            if qqq_shares > 0:
                cost = qqq_shares * qqq_price
                self.positions['QQQ'] = self.positions.get('QQQ', 0) + qqq_shares
                self.cash -= cost
                self.execute_trade(date, 'BUY_QQQ_STOPLOSS', 'QQQ', qqq_shares, qqq_price, cost)
            self._log(f"{date.strftime('%Y-%m-%d')}: Stop loss triggered for {ticker}, loss: {loss_pct*100:.1f}%")

    # -------------------------------------------------------------- transitions

    def _non_qqq_value(self, date):
        """Market value at *date* of every non-QQQ position (unpriced tickers count 0)."""
        total = 0
        for ticker, shares in self.positions.items():
            if ticker == 'QQQ':
                continue
            price = self.get_stock_price(ticker, date)
            if price:
                total += shares * price
        return total

    def start_transition(self, date, target_type, stocks=None):
        """Snapshot what a new gradual transition should sell and/or invest.

        * ``'CASH'``: sell every current position, QQQ included.
        * ``'QQQ'``: sell every non-QQQ position; invest current cash plus their
          value at *date* into QQQ.
        * ``'VOLATILE'`` (non-empty *stocks*): invest current cash plus the
          value of any non-QQQ positions into *stocks*.
        """
        self.transition_plan = {'type': target_type, 'stocks': stocks, 'shares_sold': {}}
        if target_type == 'CASH':
            self.transition_plan['positions_to_sell'] = dict(self.positions)
        elif target_type == 'QQQ':
            self.transition_cash_pool = self.cash + self._non_qqq_value(date)
            self.transition_plan['total_cash_to_invest'] = self.transition_cash_pool
            self.transition_plan['positions_to_sell'] = {
                ticker: shares for ticker, shares in self.positions.items() if ticker != 'QQQ'
            }
        elif target_type == 'VOLATILE' and stocks:
            self.transition_plan['total_cash_to_invest'] = self.cash + self._non_qqq_value(date)

    def gradual_transition(self, date, target_type, stocks=None):
        """Execute one step of the active transition plan using whole shares.

        Each call moves about ``1 / transition_steps`` of the plan.
        ``run_backtest`` calls it with ``current_step`` = 0 .. ``transition_steps - 1``.
        The last step settles whatever integer rounding left outstanding, so the
        transition really completes.
        """
        step_size = 1.0 / self.transition_steps
        final_step = self.current_step >= self.transition_steps - 1
        if target_type == 'CASH':
            self._sell_planned_step(date, step_size, final_step)
        elif target_type == 'VOLATILE' and stocks:
            self._buy_volatile_step(date, stocks, step_size)
        elif target_type == 'QQQ':
            self._sell_planned_step(date, step_size, final_step)
            self._buy_qqq_step(date, step_size, final_step)

    def _sell_planned_step(self, date, step_size, final_step):
        """Sell this step's slice of every position listed in the transition plan.

        Non-final steps sell ``int(planned * step_size)`` shares. The final step
        sells everything still outstanding from the plan, so positions smaller
        than ``transition_steps`` shares and rounding leftovers are not stranded.
        """
        planned_sales = self.transition_plan.get('positions_to_sell', {})
        sold = self.transition_plan.setdefault('shares_sold', {})
        for ticker, planned in list(planned_sales.items()):
            if ticker not in self.positions:
                continue
            held = self.positions[ticker]
            per_step = int(planned * step_size)
            if final_step:
                shares_to_sell = min(held, planned - sold.get(ticker, 0))
            elif per_step <= held:
                shares_to_sell = per_step
            else:
                continue
            if shares_to_sell <= 0:
                continue
            price = self.get_stock_price(ticker, date)
            if not price:
                continue
            value = shares_to_sell * price
            self.cash += value
            sold[ticker] = sold.get(ticker, 0) + shares_to_sell
            self._reduce_position(ticker, shares_to_sell)
            self.execute_trade(date, 'SELL_GRADUAL', ticker, shares_to_sell, price, value)

    def _buy_volatile_step(self, date, stocks, step_size):
        """Spend one step's share of the planned cash on the next stock in *stocks*.

        Step *k* buys ``stocks[k]``; the last stock repeats when there are more
        steps than stocks, and stocks beyond ``transition_steps`` are not bought.
        """
        cash_this_step = self.transition_plan.get('total_cash_to_invest', 0) * step_size
        if not (cash_this_step > 0 and self.cash >= cash_this_step):
            return
        ticker = stocks[min(self.current_step, len(stocks) - 1)]
        price = self.get_stock_price(ticker, date)
        if not price:
            return
        shares = int(cash_this_step / price)
        if shares <= 0:
            return
        cost = shares * price
        self.positions[ticker] = self.positions.get(ticker, 0) + shares
        self.cash -= cost
        self._record_stock_buy(ticker, shares, price)
        self.execute_trade(date, 'BUY_GRADUAL', ticker, shares, price, cost)

    def _buy_qqq_step(self, date, step_size, final_step):
        """Buy this step's slice of QQQ; the final step sweeps all remaining cash.

        The sweep invests proceeds from the final remainder sells plus cash left
        by whole-share rounding or skipped steps. ``QQQ_HOLD`` therefore starts
        fully invested, like the initial purchase in :meth:`run_backtest`.
        """
        cash_this_step = self.transition_plan.get('total_cash_to_invest', 0) * step_size
        if final_step:
            budget = self.cash
        elif cash_this_step > 0 and self.cash >= cash_this_step:
            budget = cash_this_step
        else:
            return
        qqq_price = self.data.loc[date, 'qqq_close']
        qqq_shares = int(budget / qqq_price)
        if qqq_shares <= 0:
            return
        cost = qqq_shares * qqq_price
        self.positions['QQQ'] = self.positions.get('QQQ', 0) + qqq_shares
        self.cash -= cost
        self.execute_trade(date, 'BUY_GRADUAL', 'QQQ', qqq_shares, qqq_price, cost)

    def _enter_transition(self, new_state, date, target_type, stocks=None):
        """Switch to *new_state*, reset the step counter and snapshot a new plan."""
        self.state = new_state
        self.current_step = 0
        self.start_transition(date, target_type, stocks)

    def _advance_transition(self, date, target_type, stocks, done_state, done_message):
        """Run one transition step; switch to *done_state* once all steps have run."""
        self.gradual_transition(date, target_type, stocks)
        self.current_step += 1
        if self.current_step >= self.transition_steps:
            self.state = done_state
            self._log(f"{date.strftime('%Y-%m-%d')}: {done_message}")

    # ----------------------------------------------------------------- backtest

    def run_backtest(self):
        """Run the bar-by-bar backtest.

        Loads data and starts fully invested in whole QQQ shares. Then, for each
        bar, it checks stop losses, advances the state machine and appends the
        portfolio value to ``self.portfolio_value``.

        Raises:
            ValueError: if no dates remain after joining and date filtering.
        """
        self._log("🚀 开始运行Enhanced Emotional Damage Strategy...")
        self._log("")
        self.get_data()
        if self.data.empty:
            raise ValueError("No overlapping Fear & Greed / QQQ data in the configured date range")

        # Start with 100% QQQ (whole shares; the remainder stays in cash).
        first_date = self.data.index[0]
        qqq_price = self.data.loc[first_date, 'qqq_close']
        qqq_shares = int(self.initial_capital / qqq_price)
        self.positions['QQQ'] = qqq_shares
        self.cash = self.initial_capital - (qqq_shares * qqq_price)

        fear_start_date = None
        for date, row in self.data.iterrows():
            fg_index = row['fear_greed_index']
            day = date.strftime('%Y-%m-%d')
            self.check_stop_loss(date)

            if self.state == 'QQQ_HOLD':
                if fg_index < self.fear_threshold:
                    fear_start_date = date
                    self._enter_transition('FEAR_TRANSITION', date, 'CASH')
                    self._log(f"{day}: Fear threshold hit ({fg_index:.1f}), starting transition to cash")
            elif self.state == 'FEAR_TRANSITION':
                self._advance_transition(date, 'CASH', None, 'CASH_WAIT',
                                         'Transition to cash complete')
            elif self.state == 'CASH_WAIT':
                if fg_index >= self.fear_threshold and fear_start_date is not None:
                    top_stocks = self.select_volatile_stocks(fear_start_date, date)
                    if top_stocks:
                        self.transition_stocks = top_stocks
                        self._enter_transition('GREED_TRANSITION', date, 'VOLATILE', top_stocks)
                        self._log(f"{day}: Fear recovered, starting transition to volatile stocks: {top_stocks}")
                    else:
                        self._enter_transition('QQQ_TRANSITION', date, 'QQQ')
                        self._log(f"{day}: Fear recovered, no suitable stocks, returning to QQQ")
            elif self.state == 'GREED_TRANSITION':
                self._advance_transition(date, 'VOLATILE', self.transition_stocks, 'VOLATILE_STOCKS',
                                         'Transition to volatile stocks complete')
            elif self.state == 'VOLATILE_STOCKS':
                if fg_index > self.greed_threshold:
                    self._enter_transition('QQQ_TRANSITION', date, 'QQQ')
                    self._log(f"{day}: Greed threshold hit ({fg_index:.1f}), starting transition to QQQ")
            elif self.state == 'QQQ_TRANSITION':
                self._advance_transition(date, 'QQQ', None, 'QQQ_HOLD',
                                         'Transition to QQQ complete')

            self.portfolio_value.append({
                'date': date,
                'value': self.calculate_portfolio_value(date),
                'state': self.state,
                'fg_index': fg_index
            })

        self._log("")
        self._log(f"✅ 回测完成! 总交易数: {len(self.trades)}")
        self._log("")

    def calculate_performance_metrics(self, returns):
        """Summarise a date-indexed value series.

        Despite its name, *returns* holds portfolio values or prices, not returns.
        Returns a dict with:

        * ``total_return`` and ``annual_return``: percent; annualised assuming
          252 bars per year.
        * ``max_drawdown``: percent, zero or negative.
        * ``sharpe_ratio``: annualised, with no risk-free rate subtracted.
        * ``annual_returns``: each calendar year with 2+ bars mapped to its
          first-to-last return in percent.
        """
        growth = returns.iloc[-1] / returns.iloc[0]
        total_return = (growth - 1) * 100
        annual_return = (growth ** (252 / len(returns)) - 1) * 100

        running_peak = returns.expanding().max()
        max_drawdown = ((returns - running_peak) / running_peak).min() * 100

        daily_returns = returns.pct_change().dropna()
        sharpe_ratio = np.sqrt(252) * daily_returns.mean() / daily_returns.std()

        annual_rets = {}
        for year in returns.index.year.unique():
            year_data = returns[returns.index.year == year]
            if len(year_data) > 1:
                annual_rets[year] = (year_data.iloc[-1] / year_data.iloc[0] - 1) * 100

        return {
            'total_return': total_return,
            'annual_return': annual_return,
            'max_drawdown': max_drawdown,
            'sharpe_ratio': sharpe_ratio,
            'annual_returns': annual_rets
        }
def generate_reports(strategy):
"""Generate all reports based on config settings"""
config = strategy.config
# Create output directories
reports_dir = config['paths']['reports_dir']
results_dir = config['paths']['results_dir']
os.makedirs(reports_dir, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
# Save strategy object if requested
if config['output_settings']['save_strategy_object']:
strategy_file = os.path.join(results_dir, f'enhanced_strategy_{timestamp}.pkl')
with open(strategy_file, 'wb') as f:
pickle.dump(strategy, f)
print(f"📦 策略对象已保存: {strategy_file}")
# Generate CSV reports
if config['output_settings']['generate_csv'] or config['output_settings']['generate_detailed_trades']:
trades_df = pd.DataFrame(strategy.trades)
if config['output_settings']['generate_csv']:
csv_file = os.path.join(reports_dir, f'enhanced_trades_{timestamp}.csv')
trades_df.to_csv(csv_file, index=False)
print(f"📊 交易CSV已保存: {csv_file}")
if config['output_settings']['generate_detailed_trades']:
detailed_file = os.path.join(reports_dir, f'detailed_trades_{timestamp}.txt')
with open(detailed_file, 'w', encoding='utf-8') as f:
f.write("Enhanced Emotional Damage Strategy - Detailed Trades Report\n")
f.write("=" * 120 + "\n\n")
# Summary
f.write(f"📊 交易摘要:\n")
f.write(f"总交易数: {len(strategy.trades)}\n")
f.write(f"交易时间: {trades_df['date'].min().strftime('%Y-%m-%d')}{trades_df['date'].max().strftime('%Y-%m-%d')}\n")
# Trade types
action_counts = trades_df['action'].value_counts()
f.write(f"\n交易类型统计:\n")
for action, count in action_counts.items():
f.write(f" {action}: {count}\n")
# Detailed trades
f.write(f"\n📋 详细交易记录:\n")
f.write("=" * 150 + "\n")
f.write(f"{'No':>3s} {'Date':>10s} {'Action':>15s} {'Ticker':>5s} {'Shares':>8s} {'Price':>7s} {'Value':>12s} {'F&G':>4s} {'Cash':>12s} {'Assets':>12s} {'State':>15s}\n")
f.write("=" * 150 + "\n")
for i, trade in enumerate(strategy.trades, 1):
fg_str = f"{trade.get('fg_index', 0):.0f}" if trade.get('fg_index') else "N/A"
cash_str = f"${trade.get('cash_after', 0):,.0f}" if trade.get('cash_after') else "N/A"
assets_str = f"${trade.get('total_assets', 0):,.0f}" if trade.get('total_assets') else "N/A"
state_str = trade.get('portfolio_state', 'N/A')
f.write(f"{i:3d} {trade['date'].strftime('%Y-%m-%d'):>10s} {trade['action']:>15s} {trade['ticker']:>5s} "
f"{trade['shares']:>8.0f} ${trade['price']:>7.2f} ${trade['value']:>12,.0f} {fg_str:>4s} {cash_str:>12s} {assets_str:>12s} {state_str:>15s}\n")
print(f"📝 详细交易报告已保存: {detailed_file}")
# Generate PDF report
if config['output_settings']['generate_pdf']:
try:
generate_enhanced_pdf_report(strategy, reports_dir, timestamp)
except Exception as e:
print(f"⚠️ PDF生成失败: {e}")
import traceback
traceback.print_exc()
print("\n🎉 所有报告生成完成!")
def _load_benchmark_closes(db_path):
    """Return date-indexed ``qqq_close``/``spy_close`` columns, inner-joined on date.

    QQQ comes from the ``qqq`` table and SPY from ``fear_greed_data.spy_close``.
    The connection is closed even if a query fails.
    """
    conn = sqlite3.connect(db_path)
    try:
        qqq = pd.read_sql_query('''
            SELECT date, close as qqq_close
            FROM qqq
            ORDER BY date
        ''', conn)
        spy = pd.read_sql_query('''
            SELECT date, spy_close
            FROM fear_greed_data
            ORDER BY date
        ''', conn)
    finally:
        conn.close()
    qqq['date'] = pd.to_datetime(qqq['date'])
    spy['date'] = pd.to_datetime(spy['date'])
    return pd.merge(qqq.set_index('date'), spy.set_index('date'),
                    left_index=True, right_index=True, how='inner')


def _max_drawdown_year(values):
    """Return the calendar year in which the value series hits its deepest drawdown."""
    peak = values.expanding().max()
    return ((values - peak) / peak).idxmin().year


def _draw_overview_page(pdf, portfolio_df, benchmark_data, metrics, dd_years,
                        stop_loss_trades, total_trades):
    """Page 1: value curves, metrics table, state timeline and annual-return bars."""
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates

    strategy_metrics, qqq_metrics, spy_metrics = metrics
    fig = plt.figure(figsize=(8.5, 11))
    fig.suptitle('Enhanced Emotional Damage Strategy Report', fontsize=16, fontweight='bold', y=0.96)

    # 1. Value curves (in $K) for the strategy and both benchmarks.
    ax1 = fig.add_subplot(4, 1, 1)
    curves = ((portfolio_df['value'], 'Enhanced Strategy', 'red'),
              (benchmark_data['qqq_value'], 'QQQ', 'blue'),
              (benchmark_data['spy_value'], 'SPY', 'green'))
    for series, label, color in curves:
        ax1.plot(series.index, series / 1000, label=label, linewidth=2, color=color)
    ax1.set_title('Portfolio Performance Comparison', fontsize=14, fontweight='bold', pad=25)
    ax1.set_ylabel('Portfolio Value ($K)', fontsize=11)
    ax1.legend(fontsize=10, loc='upper left')
    ax1.grid(True, alpha=0.3)
    ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax1.tick_params(axis='x', labelrotation=45)

    # 2. Metrics table.
    ax2 = fig.add_subplot(4, 1, 2)
    ax2.axis('off')
    columns = (strategy_metrics, qqq_metrics, spy_metrics)
    metrics_data = [
        ['Metric', 'Enhanced Strategy', 'QQQ', 'SPY'],
        ['Total Return'] + [f"{m['total_return']:.1f}%" for m in columns],
        ['Annual Return'] + [f"{m['annual_return']:.1f}%" for m in columns],
        ['Max Drawdown'] + [f"{m['max_drawdown']:.1f}%" for m in columns],
        ['Max DD Year'] + [str(year) for year in dd_years],
        ['Sharpe Ratio'] + [f"{m['sharpe_ratio']:.2f}" for m in columns],
        ['Total Trades', f"{total_trades}", 'N/A', 'N/A'],
    ]
    table = ax2.table(cellText=metrics_data, cellLoc='center', loc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.2, 2.0)  # extra row height for readability
    for col in range(len(metrics_data[0])):
        table[(0, col)].set_facecolor('#40466e')
        table[(0, col)].set_text_props(weight='bold', color='white')
    ax2.set_title('Performance Metrics Comparison', fontsize=14, fontweight='bold', pad=25)

    # 3. Total assets coloured by strategy state, stop-loss dates as dashed lines.
    ax3 = fig.add_subplot(4, 1, 3)
    state_colors = {
        'QQQ_HOLD': 'blue',
        'FEAR_TRANSITION': 'orange',
        'CASH_WAIT': 'gray',
        'GREED_TRANSITION': 'yellow',
        'VOLATILE_STOCKS': 'red',
        'QQQ_TRANSITION': 'green'
    }
    for state, color in state_colors.items():
        state_data = portfolio_df[portfolio_df['state'] == state]
        if not state_data.empty:
            ax3.scatter(state_data.index, state_data['value'] / 1000,
                        c=color, s=2, alpha=0.8, label=state)
    for stop_date in stop_loss_trades['date']:
        ax3.axvline(x=stop_date, color='red', linestyle='--', alpha=0.8, linewidth=1)
    ax3.set_title('Strategy State Timeline with Stop-Loss Events', fontsize=14, fontweight='bold', pad=25)
    ax3.set_ylabel('Total Assets ($K)', fontsize=11)
    ax3.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=8)
    ax3.grid(True, alpha=0.3)
    ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax3.tick_params(axis='x', labelrotation=45)

    # 4. Annual returns per strategy year (benchmarks show 0 for missing years).
    ax4 = fig.add_subplot(4, 1, 4)
    years = list(strategy_metrics['annual_returns'])
    x = np.arange(len(years))
    width = 0.25
    bar_specs = ((strategy_metrics, 'Enhanced Strategy', 'red', -width),
                 (qqq_metrics, 'QQQ', 'blue', 0),
                 (spy_metrics, 'SPY', 'green', width))
    for m, label, color, offset in bar_specs:
        heights = [m['annual_returns'].get(year, 0) for year in years]
        ax4.bar(x + offset, heights, width, label=label, color=color, alpha=0.8)
    ax4.set_title('Annual Returns Comparison by Year', fontsize=14, fontweight='bold', pad=25)
    ax4.set_ylabel('Annual Return (%)', fontsize=11)
    ax4.set_xlabel('Year', fontsize=11)
    ax4.set_xticks(x)
    ax4.set_xticklabels(years, rotation=90)  # vertical year labels
    ax4.legend(fontsize=10)
    ax4.grid(True, alpha=0.3, axis='y')

    fig.subplots_adjust(left=0.1, right=0.85, top=0.90, bottom=0.08, hspace=0.6)
    pdf.savefig(fig, bbox_inches='tight', dpi=150)
    plt.close(fig)


def _draw_activity_page(pdf, portfolio_df, trades_df, stop_loss_trades,
                        fear_threshold, greed_threshold):
    """Page 2: stop-loss and trade counts per year, and F&G index with trade markers."""
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates

    fig = plt.figure(figsize=(8.5, 11))
    fig.suptitle('Detailed Trading and Risk Analysis', fontsize=16, fontweight='bold', y=0.95)

    # 5. Stop-loss events per calendar year.
    ax5 = fig.add_subplot(3, 1, 1)
    if stop_loss_trades.empty:
        ax5.text(0.5, 0.5, 'No Stop-Loss Events Triggered', ha='center', va='center',
                 transform=ax5.transAxes, fontsize=14, fontweight='bold')
        ax5.set_title('Stop-Loss Analysis', fontsize=14, fontweight='bold', pad=30)
    else:
        per_year = stop_loss_trades['date'].dt.year.value_counts().sort_index()
        ax5.bar(per_year.index, per_year.values, color='red', alpha=0.8, width=0.6)
        ax5.set_title('Stop-Loss Triggers by Year', fontsize=14, fontweight='bold', pad=30)
        ax5.set_ylabel('Number of Stop-Loss Events', fontsize=11)
        ax5.set_xlabel('Year', fontsize=11)
        ax5.grid(True, alpha=0.3, axis='y')
        ax5.tick_params(axis='x', labelrotation=45)

    # 6. All trades per calendar year.
    ax6 = fig.add_subplot(3, 1, 2)
    per_year = trades_df['date'].dt.year.value_counts().sort_index()
    ax6.bar(per_year.index, per_year.values, color='purple', alpha=0.8, width=0.6)
    ax6.set_title('Trading Activity by Year', fontsize=14, fontweight='bold', pad=30)
    ax6.set_ylabel('Number of Trades', fontsize=11)
    ax6.set_xlabel('Year', fontsize=11)
    ax6.grid(True, alpha=0.3, axis='y')
    ax6.tick_params(axis='x', labelrotation=45)

    # 7. Fear & Greed index with the *configured* thresholds and trade markers.
    ax7 = fig.add_subplot(3, 1, 3)
    fg_data = portfolio_df['fg_index'].dropna()
    ax7.plot(fg_data.index, fg_data.values, color='purple', alpha=0.8, linewidth=1.5)
    ax7.axhline(y=fear_threshold, color='red', linestyle='--', alpha=0.7, linewidth=2,
                label=f'Fear Threshold ({fear_threshold:g})')
    ax7.axhline(y=greed_threshold, color='green', linestyle='--', alpha=0.7, linewidth=2,
                label=f'Greed Threshold ({greed_threshold:g})')
    ax7.fill_between(fg_data.index, 0, fear_threshold, alpha=0.2, color='red', label='Fear Zone')
    ax7.fill_between(fg_data.index, greed_threshold, 100, alpha=0.2, color='green', label='Greed Zone')
    actions = trades_df['action'].astype(str)
    markers = (('BUY', 'darkgreen', '^', 'Buy Signals'),
               ('SELL', 'darkred', 'v', 'Sell Signals'))
    for keyword, color, marker, label in markers:
        subset = trades_df[actions.str.contains(keyword)]
        if not subset.empty:
            ax7.scatter(subset['date'], subset['fg_index'], color=color, s=15, alpha=0.8,
                        marker=marker, label=label, zorder=5)
    ax7.set_title('Fear & Greed Index with Trading Signals', fontsize=14, fontweight='bold', pad=30)
    ax7.set_ylabel('CNN Fear & Greed Index', fontsize=11)
    ax7.set_xlabel('Date', fontsize=11)
    ax7.set_ylim(0, 100)
    ax7.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=8)
    ax7.grid(True, alpha=0.3)
    ax7.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax7.tick_params(axis='x', labelrotation=45)

    fig.subplots_adjust(left=0.1, right=0.85, top=0.88, bottom=0.10, hspace=1.0)
    pdf.savefig(fig, bbox_inches='tight', dpi=150)
    plt.close(fig)


def _report_pdf_file(pdf_file):
    """Sanity-check the written PDF (size, ``%PDF`` header) and print its path once."""
    size_note = ""
    try:
        file_size = os.path.getsize(pdf_file)
        size_note = f" (Size: {file_size:,} bytes)"
        if file_size < 50000:  # heuristic: an unusually small report may signal a problem
            print(f"⚠️ Warning: PDF file size seems small ({file_size} bytes)")
        with open(pdf_file, 'rb') as f:
            header = f.read(10)
        if header.startswith(b'%PDF'):
            print(f"✅ PDF file validation passed")
        else:
            print(f"⚠️ Warning: Generated file may not be a valid PDF")
    except OSError as e:
        print(f"⚠️ Error validating PDF: {e}")
    print(f"📈 PDF报告已保存: {pdf_file}{size_note}")


def generate_enhanced_pdf_report(strategy, reports_dir, timestamp):
    """Render the two-page PDF performance report for a finished backtest.

    Page 1 compares the strategy with QQQ and SPY, both rebased to
    ``strategy.initial_capital``. It shows value curves, a metrics table, the
    state timeline with stop-loss markers, and annual returns. Page 2 shows
    stop-loss and trade counts per year and the Fear & Greed index, drawn with
    the strategy's configured thresholds and buy/sell markers. The output file
    is ``<reports_dir>/enhanced_strategy_report_<timestamp>.pdf``. Matplotlib
    style overrides apply only while the report is drawn.
    """
    import matplotlib.pyplot as plt
    from matplotlib.backends.backend_pdf import PdfPages

    portfolio_df = pd.DataFrame(strategy.portfolio_value).set_index('date')
    trades_df = pd.DataFrame(strategy.trades)
    if trades_df.empty:
        # Keep the columns the plots filter on so a trade-free run still renders.
        trades_df = pd.DataFrame(columns=['date', 'action', 'fg_index'])
    trades_df['date'] = pd.to_datetime(trades_df['date'])

    # Align strategy and benchmarks on common dates, then rebase the benchmarks.
    benchmark_data = _load_benchmark_closes(strategy.config['paths']['database_path'])
    common_dates = portfolio_df.index.intersection(benchmark_data.index)
    portfolio_df = portfolio_df.loc[common_dates]
    benchmark_data = benchmark_data.loc[common_dates].copy()
    start_value = strategy.initial_capital
    for name in ('qqq', 'spy'):
        closes = benchmark_data[f'{name}_close']
        benchmark_data[f'{name}_value'] = start_value * (closes / closes.iloc[0])

    series = (portfolio_df['value'], benchmark_data['qqq_value'], benchmark_data['spy_value'])
    metrics = tuple(strategy.calculate_performance_metrics(s) for s in series)
    dd_years = tuple(_max_drawdown_year(s) for s in series)
    stop_loss_trades = trades_df[trades_df['action'] == 'STOP_LOSS']

    pdf_file = os.path.join(reports_dir, f'enhanced_strategy_report_{timestamp}.pdf')
    style = {
        'font.size': 10,
        'axes.titlesize': 12,
        'axes.labelsize': 10,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'legend.fontsize': 8,
        'figure.titlesize': 14,
    }
    with plt.rc_context(style), PdfPages(pdf_file) as pdf:
        _draw_overview_page(pdf, portfolio_df, benchmark_data, metrics, dd_years,
                            stop_loss_trades, len(strategy.trades))
        _draw_activity_page(pdf, portfolio_df, trades_df, stop_loss_trades,
                            strategy.fear_threshold, strategy.greed_threshold)
    _report_pdf_file(pdf_file)
def main():
    """Run the backtest configured by ./config.json and write every enabled report.

    Prints a hint and returns if the config file is missing. Any exception from
    the backtest or report generation is caught at this top-level boundary and
    printed with its traceback.
    """
    config_file = 'config.json'
    config_missing = not os.path.exists(config_file)
    if config_missing:
        for line in (f"❌ 配置文件未找到: {config_file}",
                     "请确保config.json文件存在于当前目录"):
            print(line)
        return
    try:
        runner = ConfigurableEmotionalDamageStrategy(config_file)
        runner.run_backtest()
        generate_reports(runner)
    except Exception as err:
        import traceback
        print(f"❌ 运行失败: {err}")
        traceback.print_exc()


# Script entry point; importing this module does not start a backtest.
if __name__ == "__main__":
    main()