#!/usr/bin/env python3
"""Config-driven backtest of a Fear & Greed rotation strategy.

The strategy reads a Fear & Greed index, QQQ closes and per-ticker price
tables from a SQLite database, walks a small state machine
(QQQ -> cash -> volatile stocks -> QQQ) day by day, and finally writes
CSV / text / PDF reports according to the ``output_settings`` section of
the JSON configuration file.
"""

import json
import os
import pickle
import sqlite3
import warnings
from contextlib import closing
from datetime import datetime, timedelta

import numpy as np
import pandas as pd

# The plotting stack is only needed for the PDF report, which imports what it
# uses locally.  Guarding these keeps the backtest and the CSV/text reports
# usable on machines without matplotlib/seaborn installed.
try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None
try:
    import seaborn as sns
except ImportError:
    sns = None

warnings.filterwarnings('ignore')

# Keys of the dict appended by ``execute_trade``; used to build an empty but
# correctly-shaped DataFrame when a backtest produced no trades.
_TRADE_COLUMNS = [
    'date', 'action', 'ticker', 'shares', 'price', 'value', 'fg_index',
    'cnn_fear_greed', 'cash_after', 'total_assets', 'portfolio_state',
]


def _quote_identifier(name):
    """Return *name* as a double-quoted SQLite identifier.

    Table names cannot be bound as SQL parameters, so they are quoted
    instead.  Quoting makes tickers that collide with SQL keywords or
    contain punctuation usable as table names.
    """
    return '"' + str(name).replace('"', '""') + '"'


def _trades_frame(strategy):
    """Return the recorded trades as a DataFrame that always has columns."""
    if strategy.trades:
        return pd.DataFrame(strategy.trades)
    return pd.DataFrame(columns=_TRADE_COLUMNS)


class ConfigurableEmotionalDamageStrategy:
    """Fear & Greed driven rotation between QQQ, cash and volatile stocks.

    States used by ``run_backtest``: ``QQQ_HOLD``, ``FEAR_TRANSITION``,
    ``CASH_WAIT``, ``GREED_TRANSITION``, ``VOLATILE_STOCKS`` and
    ``QQQ_TRANSITION``.
    """

    def __init__(self, config_path='config.json'):
        """Load the configuration and reset all portfolio state.

        Args:
            config_path: Path of the JSON configuration file.
        """
        self.load_config(config_path)
        self.cash = self.config['strategy_params']['initial_capital']
        self.positions = {}          # ticker -> integer share count
        self.portfolio_value = []    # one dict per simulated day
        self.trades = []             # one dict per executed trade

        # State management
        self.state = 'QQQ_HOLD'
        self.current_step = 0
        self.target_allocation = {}
        self.last_fear_date = None

        # For gradual transitions
        self.transition_plan = {}
        self.transition_cash_pool = 0
        # Defined up front so the attribute exists before the first
        # GREED_TRANSITION assigns it.
        self.transition_stocks = []

    @property
    def _verbose(self):
        """Whether progress messages should be printed to the console."""
        return self.config['output_settings']['show_console_output']

    def _connect(self):
        """Open the configured database; the context manager closes it."""
        return closing(sqlite3.connect(self.config['paths']['database_path']))

    def load_config(self, config_path):
        """Load the JSON configuration and mirror its values as attributes.

        Args:
            config_path: Path of the JSON configuration file.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If a required section or key is missing.
        """
        # JSON is UTF-8 by specification; do not depend on the locale.
        with open(config_path, 'r', encoding='utf-8') as f:
            self.config = json.load(f)

        # Strategy parameters as attributes for easy access
        params = self.config['strategy_params']
        self.initial_capital = params['initial_capital']
        self.fear_threshold = params['fear_threshold']
        self.greed_threshold = params['greed_threshold']
        self.stop_loss_threshold = params['stop_loss_threshold']
        self.top_stocks_count = params['top_stocks_count']
        self.volatility_threshold = params['volatility_threshold']
        self.volatility_calculation_days = params['volatility_calculation_days']
        self.transition_steps = params['transition_steps']

        # Technical indicators
        tech = self.config['technical_indicators']
        self.rsi_threshold = tech['rsi_threshold']
        self.required_indicators = tech['required_indicators']
        self.sma5_above_sma20 = tech['sma5_above_sma20']
        self.macd_convergence = tech['macd_convergence']

        if self._verbose:
            print(f"✅ 配置已加载:")
            print(f" 初始资金: ${self.initial_capital:,}")
            print(f" 恐慌阈值: {self.fear_threshold}")
            print(f" 贪婪阈值: {self.greed_threshold}")
            print(f" 止损阈值: {self.stop_loss_threshold*100}%")
            print(f" 转换步数: {self.transition_steps}")
            print(f" 选股数量: {self.top_stocks_count}")
            print("")

    def get_data(self):
        """Load the Fear & Greed index, QQQ closes and the ticker universe.

        Populates ``self.data`` (inner join of index and QQQ close, indexed
        by date, optionally clipped to the configured date range) and
        ``self.available_tickers``.
        """
        db_path = self.config['paths']['database_path']
        if self._verbose:
            print(f"连接数据库: {db_path}")

        min_records = self.config['data_settings']['min_ticker_records']
        with self._connect() as conn:
            fg_data = pd.read_sql_query('''
                SELECT date, fear_greed_index
                FROM fear_greed_index
                ORDER BY date
            ''', conn)
            qqq_data = pd.read_sql_query('''
                SELECT date, close as qqq_close
                FROM qqq
                ORDER BY date
            ''', conn)
            # Bound parameter instead of string interpolation.
            ticker_rows = conn.execute(
                'SELECT ticker FROM ticker_list WHERE records > ?',
                (min_records,),
            ).fetchall()

        fg_data['date'] = pd.to_datetime(fg_data['date'])
        fg_data.set_index('date', inplace=True)
        qqq_data['date'] = pd.to_datetime(qqq_data['date'])
        qqq_data.set_index('date', inplace=True)
        self.available_tickers = [row[0] for row in ticker_rows]

        # Keep only days present in both series.
        self.data = pd.merge(fg_data, qqq_data, left_index=True,
                             right_index=True, how='inner')

        # Apply date filters if specified
        settings = self.config['data_settings']
        if settings['start_date']:
            start_date = pd.to_datetime(settings['start_date'])
            self.data = self.data[self.data.index >= start_date]
        if settings['end_date']:
            end_date = pd.to_datetime(settings['end_date'])
            self.data = self.data[self.data.index <= end_date]
        self.data.sort_index(inplace=True)

        if self._verbose:
            print(f"数据加载完成: {self.data.index.min().strftime('%Y-%m-%d')} 到 {self.data.index.max().strftime('%Y-%m-%d')}")
            print(f"可用股票数量: {len(self.available_tickers)}")
            print("")

    def get_stock_price(self, ticker, date):
        """Return the latest close of *ticker* on or before *date*.

        Args:
            ticker: Symbol; its lower-cased form is used as the table name.
            date: Object with ``strftime`` (e.g. ``pandas.Timestamp``).

        Returns:
            The close value, or ``None`` when the table has no row on or
            before *date*.

        Raises:
            sqlite3.Error: If the ticker's table does not exist.
        """
        query = f'''
            SELECT close
            FROM {_quote_identifier(ticker.lower())}
            WHERE date <= ?
            ORDER BY date DESC
            LIMIT 1
        '''
        with self._connect() as conn:
            row = conn.execute(query, (date.strftime('%Y-%m-%d'),)).fetchone()
        return row[0] if row else None

    def calculate_volatility(self, ticker, current_date):
        """Return annualised volatility of *ticker* over the configured window.

        The window is ``volatility_calculation_days`` calendar days ending at
        *current_date*.  Returns ``0`` when there are 10 or fewer rows, or
        when the query fails (for instance a missing table), so such tickers
        never pass the volatility filter.
        """
        start_date = current_date - timedelta(days=self.volatility_calculation_days)
        query = f'''
            SELECT date, close
            FROM {_quote_identifier(ticker.lower())}
            WHERE date >= ? AND date <= ?
            ORDER BY date
        '''
        with self._connect() as conn:
            try:
                df = pd.read_sql_query(query, conn, params=(
                    start_date.strftime('%Y-%m-%d'),
                    current_date.strftime('%Y-%m-%d'),
                ))
                if len(df) > 10:
                    daily_returns = df['close'].pct_change()
                    return daily_returns.std() * np.sqrt(252)
            except Exception:
                # Deliberate best effort: an unreadable ticker is treated as
                # "no volatility" rather than aborting stock selection.
                pass
        return 0

    def check_technical_indicators(self, ticker, date):
        """Check RSI, MACD and SMA conditions for *ticker* as of *date*.

        Uses the 50 most recent closes on or before *date*.

        Returns:
            ``True`` when at least ``required_indicators`` of the three
            conditions hold; ``False`` when fewer than 20 rows are available
            or when anything fails while loading or computing.
        """
        tech = self.config['technical_indicators']
        query = f'''
            SELECT date, close
            FROM {_quote_identifier(ticker.lower())}
            WHERE date <= ?
            ORDER BY date DESC
            LIMIT 50
        '''
        with self._connect() as conn:
            try:
                df = pd.read_sql_query(
                    query, conn, params=(date.strftime('%Y-%m-%d'),))
                if len(df) < 20:
                    return False

                # Rows arrive newest first; indicators need oldest first.
                df = df.sort_values('date')
                df.reset_index(drop=True, inplace=True)
                close = df['close']

                # RSI from simple rolling means of gains and losses
                rsi_period = tech['rsi_period']
                delta = close.diff()
                gain = (delta.where(delta > 0, 0)).rolling(window=rsi_period).mean()
                loss = (-delta.where(delta < 0, 0)).rolling(window=rsi_period).mean()
                rs = gain / loss
                rsi = 100 - (100 / (1 + rs))

                # MACD line and its signal line
                ema_cfg = tech['ema_periods']
                ema_fast = close.ewm(span=ema_cfg['fast']).mean()
                ema_slow = close.ewm(span=ema_cfg['slow']).mean()
                macd = ema_fast - ema_slow
                signal = macd.ewm(span=ema_cfg['signal']).mean()

                # Simple moving averages
                sma_cfg = tech['sma_periods']
                sma_fast = close.rolling(window=sma_cfg['fast']).mean()
                sma_slow = close.rolling(window=sma_cfg['slow']).mean()

                latest_macd = macd.iloc[-1]
                latest_signal = signal.iloc[-1]

                # RSI condition
                rsi_ok = rsi.iloc[-1] > self.rsi_threshold

                # MACD condition
                if self.macd_convergence and len(macd) >= 2 and len(signal) >= 2:
                    prev_gap = abs(macd.iloc[-2] - signal.iloc[-2])
                    current_gap = abs(latest_macd - latest_signal)
                    macd_ok = current_gap < prev_gap  # lines are converging
                else:
                    macd_ok = latest_macd > latest_signal  # MACD above signal

                # SMA condition (skipped, i.e. counted as passed, if disabled)
                if self.sma5_above_sma20:
                    sma_ok = sma_fast.iloc[-1] > sma_slow.iloc[-1]
                else:
                    sma_ok = True

                score = sum([rsi_ok, macd_ok, sma_ok])
                return bool(score >= self.required_indicators)
            except Exception:
                # Deliberate best effort: a ticker that cannot be evaluated
                # simply does not qualify.
                return False

    def select_volatile_stocks(self, fear_start_date, fear_end_date):
        """Pick the most volatile tickers that pass the technical screen.

        Args:
            fear_start_date: Accepted for interface compatibility; not used
                by the selection.
            fear_end_date: Date as of which indicators and volatility are
                evaluated.

        Returns:
            Up to ``top_stocks_count`` tickers, highest volatility first.
        """
        qualified = []
        for ticker in self.available_tickers:
            if not self.check_technical_indicators(ticker, fear_end_date):
                continue
            vol = self.calculate_volatility(ticker, fear_end_date)
            if vol > self.volatility_threshold:
                qualified.append((ticker, vol))

        qualified.sort(key=lambda item: item[1], reverse=True)
        return [ticker for ticker, _ in qualified[:self.top_stocks_count]]

    def execute_trade(self, date, action, ticker=None, shares=None, price=None, value=None):
        """Append a trade record; cash and positions must already be updated.

        The record captures the Fear & Greed value for *date* (``None`` if
        the date is not in ``self.data``), the cash balance, the total
        portfolio value and the current state.
        """
        fg_index = self.data.loc[date, 'fear_greed_index'] if date in self.data.index else None
        total_assets = self.calculate_portfolio_value(date)

        self.trades.append({
            'date': date,
            'action': action,
            'ticker': ticker,
            'shares': shares,
            'price': price,
            'value': value,
            'fg_index': fg_index,
            'cnn_fear_greed': fg_index,  # same value under a clearer name
            'cash_after': self.cash,
            'total_assets': total_assets,
            'portfolio_state': self.state,
        })

    def calculate_portfolio_value(self, date):
        """Return cash plus the value of all positions as of *date*.

        QQQ is priced from ``self.data``; other tickers via
        ``get_stock_price``.  Positions without a price are left out.
        """
        total_value = self.cash
        for ticker, shares in self.positions.items():
            if ticker == 'QQQ':
                price = self.data.loc[date, 'qqq_close']
            else:
                price = self.get_stock_price(ticker, date)
            if price:
                total_value += shares * price
        return total_value

    def check_stop_loss(self, date):
        """Liquidate non-QQQ positions whose loss reached the stop threshold.

        A triggered position is sold in full and the proceeds are put into
        whole QQQ shares on the same day.
        """
        for ticker in list(self.positions.keys()):
            if ticker == 'QQQ':
                continue

            current_price = self.get_stock_price(ticker, date)
            if not current_price:
                continue

            # Average buy price over every recorded BUY_GRADUAL of the ticker.
            # NOTE(review): this includes buys from earlier, already closed
            # holding periods of the same ticker - confirm that is intended.
            buy_trades = [t for t in self.trades
                          if t['ticker'] == ticker and t['action'] in ['BUY_GRADUAL']]
            if not buy_trades:
                continue

            total_cost = sum(t['price'] * t['shares'] for t in buy_trades)
            total_shares = sum(t['shares'] for t in buy_trades)
            avg_price = total_cost / total_shares
            loss_pct = (current_price - avg_price) / avg_price
            if loss_pct > -self.stop_loss_threshold:
                continue

            # Sell the whole position ...
            shares = self.positions[ticker]
            value = shares * current_price
            self.cash += value
            del self.positions[ticker]
            self.execute_trade(date, 'STOP_LOSS', ticker, shares, current_price, value)

            # ... and move the proceeds into whole QQQ shares.
            qqq_price = self.data.loc[date, 'qqq_close']
            qqq_shares = int(value / qqq_price)
            if qqq_shares > 0:
                actual_qqq_value = qqq_shares * qqq_price
                self.positions['QQQ'] = self.positions.get('QQQ', 0) + qqq_shares
                self.cash -= actual_qqq_value
                self.execute_trade(date, 'BUY_QQQ_STOPLOSS', 'QQQ', qqq_shares,
                                   qqq_price, actual_qqq_value)

            if self._verbose:
                print(f"{date.strftime('%Y-%m-%d')}: Stop loss triggered for {ticker}, loss: {loss_pct*100:.1f}%")

    def _non_qqq_market_value(self, date):
        """Return the priced value of all non-QQQ positions as of *date*."""
        total = 0
        for ticker in self.positions:
            if ticker == 'QQQ':
                continue
            price = self.get_stock_price(ticker, date)
            if price:
                total += self.positions[ticker] * price
        return total

    def start_transition(self, date, target_type, stocks=None):
        """Snapshot what a multi-step transition will sell and invest.

        Args:
            date: Day on which the transition starts.
            target_type: ``'CASH'``, ``'QQQ'`` or ``'VOLATILE'``.
            stocks: Target tickers; only used for ``'VOLATILE'``.
        """
        self.transition_plan = {'type': target_type, 'stocks': stocks}

        if target_type == 'CASH':
            # Everything currently held is scheduled for sale.
            self.transition_plan['positions_to_sell'] = dict(self.positions)

        elif target_type == 'QQQ':
            self.transition_cash_pool = self.cash + self._non_qqq_market_value(date)
            self.transition_plan['total_cash_to_invest'] = self.transition_cash_pool
            self.transition_plan['positions_to_sell'] = {
                ticker: shares for ticker, shares in self.positions.items()
                if ticker != 'QQQ'
            }

        elif target_type == 'VOLATILE' and stocks:
            total_available_cash = self.cash + self._non_qqq_market_value(date)
            self.transition_plan['total_cash_to_invest'] = total_available_cash

    def _sell_planned_slice(self, date, step_size):
        """Sell one step's share of every position in the transition plan.

        NOTE(review): the per-step amount is truncated to whole shares, so a
        few shares can remain after the last step - confirm that is intended.
        """
        planned = self.transition_plan.get('positions_to_sell', {})
        for ticker in list(planned):
            if ticker not in self.positions:
                continue
            shares_to_sell = int(planned[ticker] * step_size)
            if shares_to_sell <= 0 or shares_to_sell > self.positions[ticker]:
                continue
            price = self.get_stock_price(ticker, date)
            if not price:
                continue

            value = shares_to_sell * price
            self.cash += value
            self.positions[ticker] -= shares_to_sell
            if self.positions[ticker] <= 0:
                del self.positions[ticker]
            self.execute_trade(date, 'SELL_GRADUAL', ticker, shares_to_sell, price, value)

    def _buy_whole_shares(self, date, ticker, price, budget):
        """Buy as many whole shares of *ticker* as *budget* allows."""
        shares = int(budget / price)
        if shares <= 0:
            return
        actual_value = shares * price
        self.positions[ticker] = self.positions.get(ticker, 0) + shares
        self.cash -= actual_value
        self.execute_trade(date, 'BUY_GRADUAL', ticker, shares, price, actual_value)

    def gradual_transition(self, date, target_type, stocks=None):
        """Execute one step of the active transition using integer shares.

        Each step handles ``1 / transition_steps`` of the plan.  A buy is
        skipped for the step when available cash is below that step's
        budget.

        Args:
            date: Trading day of the step.
            target_type: ``'CASH'``, ``'QQQ'`` or ``'VOLATILE'``.
            stocks: Target tickers for ``'VOLATILE'``; step ``i`` buys
                ``stocks[i]``, and steps past the end of the list reuse the
                last ticker.
        """
        step_size = 1.0 / self.transition_steps

        if target_type == 'CASH':
            self._sell_planned_slice(date, step_size)

        elif target_type == 'VOLATILE' and stocks:
            cash_this_step = self.transition_plan.get('total_cash_to_invest', 0) * step_size
            if cash_this_step > 0 and self.cash >= cash_this_step:
                ticker = stocks[min(self.current_step, len(stocks) - 1)]
                price = self.get_stock_price(ticker, date)
                if price:
                    self._buy_whole_shares(date, ticker, price, cash_this_step)

        elif target_type == 'QQQ':
            # Sell first so the proceeds can fund this step's QQQ purchase.
            self._sell_planned_slice(date, step_size)

            cash_this_step = self.transition_plan.get('total_cash_to_invest', 0) * step_size
            if cash_this_step > 0 and self.cash >= cash_this_step:
                qqq_price = self.data.loc[date, 'qqq_close']
                self._buy_whole_shares(date, 'QQQ', qqq_price, cash_this_step)

    def run_backtest(self):
        """Run the day-by-day simulation over the loaded data.

        Starts fully invested in whole QQQ shares, checks stop losses every
        day, advances the state machine, and appends one entry per day to
        ``self.portfolio_value``.
        """
        verbose = self._verbose
        if verbose:
            print("🚀 开始运行Enhanced Emotional Damage Strategy...")
            print("")

        self.get_data()

        # Start with 100% QQQ (whole shares; the remainder stays in cash)
        first_date = self.data.index[0]
        qqq_price = self.data.loc[first_date, 'qqq_close']
        qqq_shares = int(self.initial_capital / qqq_price)
        self.positions['QQQ'] = qqq_shares
        self.cash = self.initial_capital - (qqq_shares * qqq_price)

        fear_start_date = None

        for date, row in self.data.iterrows():
            fg_index = row['fear_greed_index']
            day = date.strftime('%Y-%m-%d') if verbose else None

            self.check_stop_loss(date)

            if self.state == 'QQQ_HOLD':
                if fg_index < self.fear_threshold:
                    fear_start_date = date
                    self.state = 'FEAR_TRANSITION'
                    self.current_step = 0
                    self.start_transition(date, 'CASH')
                    if verbose:
                        print(f"{day}: Fear threshold hit ({fg_index:.1f}), starting transition to cash")

            elif self.state == 'FEAR_TRANSITION':
                self.gradual_transition(date, 'CASH')
                self.current_step += 1
                if self.current_step >= self.transition_steps:
                    self.state = 'CASH_WAIT'
                    if verbose:
                        print(f"{day}: Transition to cash complete")

            elif self.state == 'CASH_WAIT':
                if fg_index >= self.fear_threshold and fear_start_date:
                    top_stocks = self.select_volatile_stocks(fear_start_date, date)
                    self.current_step = 0
                    if top_stocks:
                        self.state = 'GREED_TRANSITION'
                        self.transition_stocks = top_stocks
                        self.start_transition(date, 'VOLATILE', top_stocks)
                        if verbose:
                            print(f"{day}: Fear recovered, starting transition to volatile stocks: {top_stocks}")
                    else:
                        self.state = 'QQQ_TRANSITION'
                        self.start_transition(date, 'QQQ')
                        if verbose:
                            print(f"{day}: Fear recovered, no suitable stocks, returning to QQQ")

            elif self.state == 'GREED_TRANSITION':
                self.gradual_transition(date, 'VOLATILE', self.transition_stocks)
                self.current_step += 1
                if self.current_step >= self.transition_steps:
                    self.state = 'VOLATILE_STOCKS'
                    if verbose:
                        print(f"{day}: Transition to volatile stocks complete")

            elif self.state == 'VOLATILE_STOCKS':
                if fg_index > self.greed_threshold:
                    self.state = 'QQQ_TRANSITION'
                    self.current_step = 0
                    self.start_transition(date, 'QQQ')
                    if verbose:
                        print(f"{day}: Greed threshold hit ({fg_index:.1f}), starting transition to QQQ")

            elif self.state == 'QQQ_TRANSITION':
                self.gradual_transition(date, 'QQQ')
                self.current_step += 1
                if self.current_step >= self.transition_steps:
                    self.state = 'QQQ_HOLD'
                    if verbose:
                        print(f"{day}: Transition to QQQ complete")

            # Record portfolio value
            self.portfolio_value.append({
                'date': date,
                'value': self.calculate_portfolio_value(date),
                'state': self.state,
                'fg_index': fg_index,
            })

        if verbose:
            print("")
            print(f"✅ 回测完成! 总交易数: {len(self.trades)}")
            print("")

    def calculate_performance_metrics(self, returns):
        """Compute summary statistics for a value series.

        Args:
            returns: ``pandas.Series`` of portfolio (or benchmark) values
                with a ``DatetimeIndex``; despite the name these are levels,
                not period returns.

        Returns:
            Dict with ``total_return``, ``annual_return`` and
            ``max_drawdown`` in percent, ``sharpe_ratio`` (annualised with
            252 periods, no risk-free rate) and ``annual_returns`` mapping
            each year with more than one observation to its percent return.
        """
        growth = returns.iloc[-1] / returns.iloc[0]
        total_return = (growth - 1) * 100
        annual_return = (growth ** (252 / len(returns)) - 1) * 100

        # Max drawdown relative to the running peak
        peak = returns.expanding().max()
        drawdown = (returns - peak) / peak
        max_drawdown = drawdown.min() * 100

        # Sharpe ratio
        daily_returns = returns.pct_change().dropna()
        sharpe_ratio = np.sqrt(252) * daily_returns.mean() / daily_returns.std()

        # Returns by calendar year (first to last observation of the year)
        annual_rets = {}
        for year in returns.index.year.unique():
            year_data = returns[returns.index.year == year]
            if len(year_data) > 1:
                annual_rets[year] = (year_data.iloc[-1] / year_data.iloc[0] - 1) * 100

        return {
            'total_return': total_return,
            'annual_return': annual_return,
            'max_drawdown': max_drawdown,
            'sharpe_ratio': sharpe_ratio,
            'annual_returns': annual_rets,
        }


def _write_detailed_trades(strategy, trades_df, detailed_file):
    """Write the human-readable trade listing to *detailed_file*."""
    with open(detailed_file, 'w', encoding='utf-8') as f:
        f.write("Enhanced Emotional Damage Strategy - Detailed Trades Report\n")
        f.write("=" * 120 + "\n\n")

        # Summary
        f.write(f"📊 交易摘要:\n")
        f.write(f"总交易数: {len(strategy.trades)}\n")
        if trades_df.empty:
            # No trades means there is no date range to report.
            f.write("交易时间: N/A\n")
        else:
            f.write(f"交易时间: {trades_df['date'].min().strftime('%Y-%m-%d')} 到 {trades_df['date'].max().strftime('%Y-%m-%d')}\n")

        # Trade types
        action_counts = trades_df['action'].value_counts()
        f.write(f"\n交易类型统计:\n")
        for action, count in action_counts.items():
            f.write(f" {action}: {count}\n")

        # Detailed trades
        f.write(f"\n📋 详细交易记录:\n")
        f.write("=" * 150 + "\n")
        f.write(f"{'No':>3s} {'Date':>10s} {'Action':>15s} {'Ticker':>5s} {'Shares':>8s} {'Price':>7s} {'Value':>12s} {'F&G':>4s} {'Cash':>12s} {'Assets':>12s} {'State':>15s}\n")
        f.write("=" * 150 + "\n")

        for i, trade in enumerate(strategy.trades, 1):
            # Compare against None so that a legitimate 0 is still printed.
            fg = trade.get('fg_index')
            cash = trade.get('cash_after')
            assets = trade.get('total_assets')
            fg_str = f"{fg:.0f}" if fg is not None else "N/A"
            cash_str = f"${cash:,.0f}" if cash is not None else "N/A"
            assets_str = f"${assets:,.0f}" if assets is not None else "N/A"
            state_str = trade.get('portfolio_state', 'N/A')

            f.write(f"{i:3d} {trade['date'].strftime('%Y-%m-%d'):>10s} {trade['action']:>15s} {trade['ticker']:>5s} "
                    f"{trade['shares']:>8.0f} ${trade['price']:>7.2f} ${trade['value']:>12,.0f} {fg_str:>4s} {cash_str:>12s} {assets_str:>12s} {state_str:>15s}\n")


def generate_reports(strategy):
    """Write every report enabled in ``output_settings``.

    Creates the configured report/result directories, then optionally
    pickles the strategy, writes the trades CSV, the detailed text report
    and the PDF report.  A PDF failure is printed and does not abort the
    other reports.
    """
    config = strategy.config
    output = config['output_settings']

    # Create output directories
    reports_dir = config['paths']['reports_dir']
    results_dir = config['paths']['results_dir']
    os.makedirs(reports_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Save strategy object if requested
    if output['save_strategy_object']:
        strategy_file = os.path.join(results_dir, f'enhanced_strategy_{timestamp}.pkl')
        with open(strategy_file, 'wb') as f:
            pickle.dump(strategy, f)
        print(f"📦 策略对象已保存: {strategy_file}")

    # Generate CSV / text reports
    if output['generate_csv'] or output['generate_detailed_trades']:
        trades_df = _trades_frame(strategy)

        if output['generate_csv']:
            csv_file = os.path.join(reports_dir, f'enhanced_trades_{timestamp}.csv')
            trades_df.to_csv(csv_file, index=False)
            print(f"📊 交易CSV已保存: {csv_file}")

        if output['generate_detailed_trades']:
            detailed_file = os.path.join(reports_dir, f'detailed_trades_{timestamp}.txt')
            _write_detailed_trades(strategy, trades_df, detailed_file)
            print(f"📝 详细交易报告已保存: {detailed_file}")

    # Generate PDF report
    if output['generate_pdf']:
        try:
            generate_enhanced_pdf_report(strategy, reports_dir, timestamp)
        except Exception as e:
            print(f"⚠️ PDF生成失败: {e}")
            import traceback
            traceback.print_exc()

    print("\n🎉 所有报告生成完成!")


def generate_enhanced_pdf_report(strategy, reports_dir, timestamp):
    """Render the two-page PDF report comparing the strategy to QQQ and SPY.

    Args:
        strategy: A strategy whose ``run_backtest`` has been executed.
        reports_dir: Existing directory to write the PDF into.
        timestamp: String embedded in the output file name.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    import matplotlib.pyplot as plt
    import matplotlib.dates as mdates
    from matplotlib.backends.backend_pdf import PdfPages

    def rotate_tick_labels(ax, degrees):
        for label in ax.get_xticklabels():
            label.set_rotation(degrees)

    def find_max_drawdown_year(values):
        peak = values.expanding().max()
        drawdown = (values - peak) / peak
        return drawdown.idxmin().year

    # Prepare data
    portfolio_df = pd.DataFrame(strategy.portfolio_value)
    portfolio_df.set_index('date', inplace=True)
    trades_df = _trades_frame(strategy)

    # With no trades the filter columns hold no usable values, so skip the
    # filtering and reuse the empty frame for every subset.
    if trades_df.empty:
        stop_loss_trades = buy_trades = sell_trades = trades_df
    else:
        stop_loss_trades = trades_df[trades_df['action'] == 'STOP_LOSS']
        buy_trades = trades_df[trades_df['action'].str.contains('BUY')]
        sell_trades = trades_df[trades_df['action'].str.contains('SELL')]

    # Get benchmark data
    db_path = strategy.config['paths']['database_path']
    with closing(sqlite3.connect(db_path)) as conn:
        qqq_data = pd.read_sql_query('''
            SELECT date, close as qqq_close
            FROM qqq
            ORDER BY date
        ''', conn)
        spy_data = pd.read_sql_query('''
            SELECT date, spy_close
            FROM fear_greed_data
            ORDER BY date
        ''', conn)

    qqq_data['date'] = pd.to_datetime(qqq_data['date'])
    qqq_data.set_index('date', inplace=True)
    spy_data['date'] = pd.to_datetime(spy_data['date'])
    spy_data.set_index('date', inplace=True)

    # Merge and align data
    benchmark_data = pd.merge(qqq_data, spy_data, left_index=True,
                              right_index=True, how='inner')
    common_dates = portfolio_df.index.intersection(benchmark_data.index)
    portfolio_df = portfolio_df.loc[common_dates]
    benchmark_data = benchmark_data.loc[common_dates]

    # Normalize benchmarks to the strategy's starting capital
    start_value = strategy.initial_capital
    benchmark_data['qqq_value'] = start_value * (
        benchmark_data['qqq_close'] / benchmark_data['qqq_close'].iloc[0])
    benchmark_data['spy_value'] = start_value * (
        benchmark_data['spy_close'] / benchmark_data['spy_close'].iloc[0])

    # Calculate metrics
    strategy_metrics = strategy.calculate_performance_metrics(portfolio_df['value'])
    qqq_metrics = strategy.calculate_performance_metrics(benchmark_data['qqq_value'])
    spy_metrics = strategy.calculate_performance_metrics(benchmark_data['spy_value'])

    strategy_dd_year = find_max_drawdown_year(portfolio_df['value'])
    qqq_dd_year = find_max_drawdown_year(benchmark_data['qqq_value'])
    spy_dd_year = find_max_drawdown_year(benchmark_data['spy_value'])

    # Create PDF with multiple pages
    pdf_file = os.path.join(reports_dir, f'enhanced_strategy_report_{timestamp}.pdf')

    with PdfPages(pdf_file) as pdf:
        # Global font sizes for consistent spacing
        plt.rcParams['font.size'] = 10
        plt.rcParams['axes.titlesize'] = 12
        plt.rcParams['axes.labelsize'] = 10
        plt.rcParams['xtick.labelsize'] = 8
        plt.rcParams['ytick.labelsize'] = 8
        plt.rcParams['legend.fontsize'] = 8
        plt.rcParams['figure.titlesize'] = 14

        # Page 1: Performance Comparison
        fig1 = plt.figure(figsize=(8.5, 11))
        fig1.suptitle('Enhanced Emotional Damage Strategy Report',
                      fontsize=16, fontweight='bold', y=0.96)

        # 1. Total Return Curve
        ax1 = plt.subplot(4, 1, 1)
        ax1.plot(portfolio_df.index, portfolio_df['value'] / 1000,
                 label='Enhanced Strategy', linewidth=2, color='red')
        ax1.plot(benchmark_data.index, benchmark_data['qqq_value'] / 1000,
                 label='QQQ', linewidth=2, color='blue')
        ax1.plot(benchmark_data.index, benchmark_data['spy_value'] / 1000,
                 label='SPY', linewidth=2, color='green')
        ax1.set_title('Portfolio Performance Comparison', fontsize=14, fontweight='bold', pad=25)
        ax1.set_ylabel('Portfolio Value ($K)', fontsize=11)
        ax1.legend(fontsize=10, loc='upper left')
        ax1.grid(True, alpha=0.3)
        ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
        rotate_tick_labels(ax1, 45)

        # 2. Performance Metrics Table
        ax2 = plt.subplot(4, 1, 2)
        ax2.axis('off')

        metrics_data = [
            ['Metric', 'Enhanced Strategy', 'QQQ', 'SPY'],
            ['Total Return', f"{strategy_metrics['total_return']:.1f}%",
             f"{qqq_metrics['total_return']:.1f}%", f"{spy_metrics['total_return']:.1f}%"],
            ['Annual Return', f"{strategy_metrics['annual_return']:.1f}%",
             f"{qqq_metrics['annual_return']:.1f}%", f"{spy_metrics['annual_return']:.1f}%"],
            ['Max Drawdown', f"{strategy_metrics['max_drawdown']:.1f}%",
             f"{qqq_metrics['max_drawdown']:.1f}%", f"{spy_metrics['max_drawdown']:.1f}%"],
            ['Max DD Year', str(strategy_dd_year), str(qqq_dd_year), str(spy_dd_year)],
            ['Sharpe Ratio', f"{strategy_metrics['sharpe_ratio']:.2f}",
             f"{qqq_metrics['sharpe_ratio']:.2f}", f"{spy_metrics['sharpe_ratio']:.2f}"],
            ['Total Trades', f"{len(strategy.trades)}", 'N/A', 'N/A'],
        ]

        table = ax2.table(cellText=metrics_data, cellLoc='center', loc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1.2, 2.0)  # more height for readability
        for col in range(len(metrics_data[0])):
            table[(0, col)].set_facecolor('#40466e')
            table[(0, col)].set_text_props(weight='bold', color='white')
        ax2.set_title('Performance Metrics Comparison', fontsize=14, fontweight='bold', pad=25)

        # 3. Strategy State Timeline
        ax3 = plt.subplot(4, 1, 3)
        state_colors = {
            'QQQ_HOLD': 'blue',
            'FEAR_TRANSITION': 'orange',
            'CASH_WAIT': 'gray',
            'GREED_TRANSITION': 'yellow',
            'VOLATILE_STOCKS': 'red',
            'QQQ_TRANSITION': 'green',
        }
        for state, color in state_colors.items():
            state_data = portfolio_df[portfolio_df['state'] == state]
            if not state_data.empty:
                ax3.scatter(state_data.index, state_data['value'] / 1000,
                            c=color, s=2, alpha=0.8, label=state)

        # Stop-loss markers
        if not stop_loss_trades.empty:
            for _, trade in stop_loss_trades.iterrows():
                ax3.axvline(x=trade['date'], color='red', linestyle='--',
                            alpha=0.8, linewidth=1)

        ax3.set_title('Strategy State Timeline with Stop-Loss Events',
                      fontsize=14, fontweight='bold', pad=25)
        ax3.set_ylabel('Total Assets ($K)', fontsize=11)
        ax3.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=8)
        ax3.grid(True, alpha=0.3)
        ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
        rotate_tick_labels(ax3, 45)

        # 4. Annual Returns Comparison
        ax4 = plt.subplot(4, 1, 4)
        years = list(strategy_metrics['annual_returns'].keys())
        enhanced_returns = list(strategy_metrics['annual_returns'].values())
        qqq_returns = [qqq_metrics['annual_returns'].get(year, 0) for year in years]
        spy_returns = [spy_metrics['annual_returns'].get(year, 0) for year in years]

        x = np.arange(len(years))
        width = 0.25
        ax4.bar(x - width, enhanced_returns, width, label='Enhanced Strategy', color='red', alpha=0.8)
        ax4.bar(x, qqq_returns, width, label='QQQ', color='blue', alpha=0.8)
        ax4.bar(x + width, spy_returns, width, label='SPY', color='green', alpha=0.8)
        ax4.set_title('Annual Returns Comparison by Year', fontsize=14, fontweight='bold', pad=25)
        ax4.set_ylabel('Annual Return (%)', fontsize=11)
        ax4.set_xlabel('Year', fontsize=11)
        ax4.set_xticks(x)
        ax4.set_xticklabels(years)
        rotate_tick_labels(ax4, 90)  # vertical text for years
        ax4.legend(fontsize=10)
        ax4.grid(True, alpha=0.3, axis='y')

        plt.subplots_adjust(left=0.1, right=0.85, top=0.90, bottom=0.08, hspace=0.6)
        pdf.savefig(fig1, bbox_inches='tight', dpi=150)
        plt.close(fig1)

        # Page 2: Detailed Analysis
        fig2 = plt.figure(figsize=(8.5, 11))
        fig2.suptitle('Detailed Trading and Risk Analysis',
                      fontsize=16, fontweight='bold', y=0.95)

        # 5. Stop-Loss Analysis
        ax5 = plt.subplot(3, 1, 1)
        if not stop_loss_trades.empty:
            stop_loss_by_year = stop_loss_trades.groupby(
                stop_loss_trades['date'].dt.year).size()
            ax5.bar(stop_loss_by_year.index, stop_loss_by_year.values,
                    color='red', alpha=0.8, width=0.6)
            ax5.set_title('Stop-Loss Triggers by Year', fontsize=14, fontweight='bold', pad=30)
            ax5.set_ylabel('Number of Stop-Loss Events', fontsize=11)
            ax5.set_xlabel('Year', fontsize=11)
            ax5.grid(True, alpha=0.3, axis='y')
            rotate_tick_labels(ax5, 45)
        else:
            ax5.text(0.5, 0.5, 'No Stop-Loss Events Triggered', ha='center', va='center',
                     transform=ax5.transAxes, fontsize=14, fontweight='bold')
            ax5.set_title('Stop-Loss Analysis', fontsize=14, fontweight='bold', pad=30)

        # 6. Trade Frequency Analysis
        ax6 = plt.subplot(3, 1, 2)
        if not trades_df.empty:
            trade_frequency = trades_df.groupby(trades_df['date'].dt.year).size()
            ax6.bar(trade_frequency.index, trade_frequency.values,
                    color='purple', alpha=0.8, width=0.6)
            ax6.set_ylabel('Number of Trades', fontsize=11)
            ax6.set_xlabel('Year', fontsize=11)
            ax6.grid(True, alpha=0.3, axis='y')
            rotate_tick_labels(ax6, 45)
        else:
            ax6.text(0.5, 0.5, 'No Trades Executed', ha='center', va='center',
                     transform=ax6.transAxes, fontsize=14, fontweight='bold')
        ax6.set_title('Trading Activity by Year', fontsize=14, fontweight='bold', pad=30)

        # 7. Fear & Greed Index with Trading Signals
        ax7 = plt.subplot(3, 1, 3)
        fg_data = portfolio_df['fg_index'].dropna()
        # Draw the thresholds the strategy actually traded on, not fixed
        # 25/75 values.
        fear_level = strategy.fear_threshold
        greed_level = strategy.greed_threshold
        ax7.plot(fg_data.index, fg_data.values, color='purple', alpha=0.8, linewidth=1.5)
        ax7.axhline(y=fear_level, color='red', linestyle='--', alpha=0.7, linewidth=2,
                    label=f'Fear Threshold ({fear_level})')
        ax7.axhline(y=greed_level, color='green', linestyle='--', alpha=0.7, linewidth=2,
                    label=f'Greed Threshold ({greed_level})')
        ax7.fill_between(fg_data.index, 0, fear_level, alpha=0.2, color='red', label='Fear Zone')
        ax7.fill_between(fg_data.index, greed_level, 100, alpha=0.2, color='green', label='Greed Zone')

        # Trade markers
        if not buy_trades.empty:
            ax7.scatter(buy_trades['date'], buy_trades['fg_index'], color='darkgreen',
                        s=15, alpha=0.8, marker='^', label='Buy Signals', zorder=5)
        if not sell_trades.empty:
            ax7.scatter(sell_trades['date'], sell_trades['fg_index'], color='darkred',
                        s=15, alpha=0.8, marker='v', label='Sell Signals', zorder=5)

        ax7.set_title('Fear & Greed Index with Trading Signals',
                      fontsize=14, fontweight='bold', pad=30)
        ax7.set_ylabel('CNN Fear & Greed Index', fontsize=11)
        ax7.set_xlabel('Date', fontsize=11)
        ax7.set_ylim(0, 100)
        ax7.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=8)
        ax7.grid(True, alpha=0.3)
        ax7.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
        rotate_tick_labels(ax7, 45)

        plt.subplots_adjust(left=0.1, right=0.85, top=0.88, bottom=0.10, hspace=1.0)
        pdf.savefig(fig2, bbox_inches='tight', dpi=150)
        plt.close(fig2)

    # Sanity-check the written file: size and PDF magic bytes
    try:
        file_size = os.path.getsize(pdf_file)
        if file_size < 50000:  # less than 50KB might indicate issues
            print(f"⚠️ Warning: PDF file size seems small ({file_size} bytes)")
        else:
            print(f"📈 PDF报告已保存: {pdf_file} (Size: {file_size:,} bytes)")

        with open(pdf_file, 'rb') as f:
            header = f.read(10)
        if not header.startswith(b'%PDF'):
            print(f"⚠️ Warning: Generated file may not be a valid PDF")
        else:
            print(f"✅ PDF file validation passed")
    except Exception as e:
        print(f"⚠️ Error validating PDF: {e}")
        print(f"📈 PDF报告已保存: {pdf_file}")


def main():
    """Run the backtest and reports using ``config.json`` in the CWD."""
    config_file = 'config.json'

    if not os.path.exists(config_file):
        print(f"❌ 配置文件未找到: {config_file}")
        print("请确保config.json文件存在于当前目录")
        return

    try:
        # Initialize and run strategy
        strategy = ConfigurableEmotionalDamageStrategy(config_file)
        strategy.run_backtest()

        # Generate reports
        generate_reports(strategy)
    except Exception as e:
        print(f"❌ 运行失败: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()