Automated backup - 20250719_000001
This commit is contained in:
960
backtest/strategy/emotional-damage/run_strategy_with_config.py
Normal file
960
backtest/strategy/emotional-damage/run_strategy_with_config.py
Normal file
@@ -0,0 +1,960 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
import warnings
|
||||
import os
|
||||
import pickle
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
class ConfigurableEmotionalDamageStrategy:
|
||||
def __init__(self, config_path='config.json'):
|
||||
"""Initialize strategy with config file"""
|
||||
self.load_config(config_path)
|
||||
self.cash = self.config['strategy_params']['initial_capital']
|
||||
self.positions = {}
|
||||
self.portfolio_value = []
|
||||
self.trades = []
|
||||
|
||||
# State management
|
||||
self.state = 'QQQ_HOLD'
|
||||
self.current_step = 0
|
||||
self.target_allocation = {}
|
||||
self.last_fear_date = None
|
||||
|
||||
# For gradual transitions
|
||||
self.transition_plan = {}
|
||||
self.transition_cash_pool = 0
|
||||
|
||||
def load_config(self, config_path):
|
||||
"""Load configuration from JSON file"""
|
||||
with open(config_path, 'r') as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
# Set strategy parameters as attributes for easy access
|
||||
params = self.config['strategy_params']
|
||||
self.initial_capital = params['initial_capital']
|
||||
self.fear_threshold = params['fear_threshold']
|
||||
self.greed_threshold = params['greed_threshold']
|
||||
self.stop_loss_threshold = params['stop_loss_threshold']
|
||||
self.top_stocks_count = params['top_stocks_count']
|
||||
self.volatility_threshold = params['volatility_threshold']
|
||||
self.volatility_calculation_days = params['volatility_calculation_days']
|
||||
self.transition_steps = params['transition_steps']
|
||||
|
||||
# Technical indicators
|
||||
tech = self.config['technical_indicators']
|
||||
self.rsi_threshold = tech['rsi_threshold']
|
||||
self.required_indicators = tech['required_indicators']
|
||||
self.sma5_above_sma20 = tech['sma5_above_sma20']
|
||||
self.macd_convergence = tech['macd_convergence']
|
||||
|
||||
if self.config['output_settings']['show_console_output']:
|
||||
print(f"✅ 配置已加载:")
|
||||
print(f" 初始资金: ${self.initial_capital:,}")
|
||||
print(f" 恐慌阈值: {self.fear_threshold}")
|
||||
print(f" 贪婪阈值: {self.greed_threshold}")
|
||||
print(f" 止损阈值: {self.stop_loss_threshold*100}%")
|
||||
print(f" 转换步数: {self.transition_steps}")
|
||||
print(f" 选股数量: {self.top_stocks_count}")
|
||||
print("")
|
||||
|
||||
def get_data(self):
|
||||
"""Load Fear & Greed Index and stock data"""
|
||||
db_path = self.config['paths']['database_path']
|
||||
if self.config['output_settings']['show_console_output']:
|
||||
print(f"连接数据库: {db_path}")
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
|
||||
# Get Fear & Greed Index
|
||||
fg_data = pd.read_sql_query('''
|
||||
SELECT date, fear_greed_index
|
||||
FROM fear_greed_index
|
||||
ORDER BY date
|
||||
''', conn)
|
||||
fg_data['date'] = pd.to_datetime(fg_data['date'])
|
||||
fg_data.set_index('date', inplace=True)
|
||||
|
||||
# Get real QQQ price data
|
||||
qqq_data = pd.read_sql_query('''
|
||||
SELECT date, close as qqq_close
|
||||
FROM qqq
|
||||
ORDER BY date
|
||||
''', conn)
|
||||
qqq_data['date'] = pd.to_datetime(qqq_data['date'])
|
||||
qqq_data.set_index('date', inplace=True)
|
||||
|
||||
# Get available tickers
|
||||
min_records = self.config['data_settings']['min_ticker_records']
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(f'SELECT ticker FROM ticker_list WHERE records > {min_records}')
|
||||
self.available_tickers = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
conn.close()
|
||||
|
||||
# Merge data
|
||||
self.data = pd.merge(fg_data, qqq_data, left_index=True, right_index=True, how='inner')
|
||||
|
||||
# Apply date filters if specified
|
||||
if self.config['data_settings']['start_date']:
|
||||
start_date = pd.to_datetime(self.config['data_settings']['start_date'])
|
||||
self.data = self.data[self.data.index >= start_date]
|
||||
|
||||
if self.config['data_settings']['end_date']:
|
||||
end_date = pd.to_datetime(self.config['data_settings']['end_date'])
|
||||
self.data = self.data[self.data.index <= end_date]
|
||||
|
||||
self.data.sort_index(inplace=True)
|
||||
|
||||
if self.config['output_settings']['show_console_output']:
|
||||
print(f"数据加载完成: {self.data.index.min().strftime('%Y-%m-%d')} 到 {self.data.index.max().strftime('%Y-%m-%d')}")
|
||||
print(f"可用股票数量: {len(self.available_tickers)}")
|
||||
print("")
|
||||
|
||||
def get_stock_price(self, ticker, date):
|
||||
"""Get stock price for a specific ticker and date"""
|
||||
db_path = self.config['paths']['database_path']
|
||||
conn = sqlite3.connect(db_path)
|
||||
|
||||
query = f'''
|
||||
SELECT close FROM {ticker.lower()}
|
||||
WHERE date <= ?
|
||||
ORDER BY date DESC
|
||||
LIMIT 1
|
||||
'''
|
||||
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, (date.strftime('%Y-%m-%d'),))
|
||||
result = cursor.fetchone()
|
||||
conn.close()
|
||||
|
||||
return result[0] if result else None
|
||||
|
||||
def calculate_volatility(self, ticker, current_date):
|
||||
"""Calculate historical volatility over configured period"""
|
||||
db_path = self.config['paths']['database_path']
|
||||
conn = sqlite3.connect(db_path)
|
||||
|
||||
try:
|
||||
start_date = current_date - timedelta(days=self.volatility_calculation_days)
|
||||
|
||||
query = f'''
|
||||
SELECT date, close FROM {ticker.lower()}
|
||||
WHERE date >= ? AND date <= ?
|
||||
ORDER BY date
|
||||
'''
|
||||
|
||||
df = pd.read_sql_query(query, conn, params=(
|
||||
start_date.strftime('%Y-%m-%d'),
|
||||
current_date.strftime('%Y-%m-%d')
|
||||
))
|
||||
|
||||
if len(df) > 10:
|
||||
df['returns'] = df['close'].pct_change()
|
||||
volatility = df['returns'].std() * np.sqrt(252)
|
||||
conn.close()
|
||||
return volatility
|
||||
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
conn.close()
|
||||
return 0
|
||||
|
||||
def check_technical_indicators(self, ticker, date):
|
||||
"""Check RSI, MACD, and SMA technical indicators"""
|
||||
db_path = self.config['paths']['database_path']
|
||||
conn = sqlite3.connect(db_path)
|
||||
|
||||
try:
|
||||
query = f'''
|
||||
SELECT date, close FROM {ticker.lower()}
|
||||
WHERE date <= ?
|
||||
ORDER BY date DESC
|
||||
LIMIT 50
|
||||
'''
|
||||
|
||||
df = pd.read_sql_query(query, conn, params=(date.strftime('%Y-%m-%d'),))
|
||||
|
||||
if len(df) < 20:
|
||||
conn.close()
|
||||
return False
|
||||
|
||||
df = df.sort_values('date')
|
||||
df.reset_index(drop=True, inplace=True)
|
||||
|
||||
# Calculate RSI
|
||||
rsi_period = self.config['technical_indicators']['rsi_period']
|
||||
delta = df['close'].diff()
|
||||
gain = (delta.where(delta > 0, 0)).rolling(window=rsi_period).mean()
|
||||
loss = (-delta.where(delta < 0, 0)).rolling(window=rsi_period).mean()
|
||||
rs = gain / loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
|
||||
# Calculate MACD
|
||||
ema_fast = self.config['technical_indicators']['ema_periods']['fast']
|
||||
ema_slow = self.config['technical_indicators']['ema_periods']['slow']
|
||||
ema_signal = self.config['technical_indicators']['ema_periods']['signal']
|
||||
|
||||
ema12 = df['close'].ewm(span=ema_fast).mean()
|
||||
ema26 = df['close'].ewm(span=ema_slow).mean()
|
||||
macd = ema12 - ema26
|
||||
signal = macd.ewm(span=ema_signal).mean()
|
||||
|
||||
# Calculate SMA
|
||||
sma_fast = self.config['technical_indicators']['sma_periods']['fast']
|
||||
sma_slow = self.config['technical_indicators']['sma_periods']['slow']
|
||||
sma5 = df['close'].rolling(window=sma_fast).mean()
|
||||
sma20 = df['close'].rolling(window=sma_slow).mean()
|
||||
|
||||
# Check conditions
|
||||
latest_rsi = rsi.iloc[-1]
|
||||
latest_macd = macd.iloc[-1]
|
||||
latest_signal = signal.iloc[-1]
|
||||
latest_sma5 = sma5.iloc[-1]
|
||||
latest_sma20 = sma20.iloc[-1]
|
||||
|
||||
# RSI condition
|
||||
rsi_ok = latest_rsi > self.rsi_threshold
|
||||
|
||||
# MACD condition
|
||||
if self.macd_convergence and len(macd) >= 2 and len(signal) >= 2:
|
||||
prev_macd = macd.iloc[-2]
|
||||
prev_signal = signal.iloc[-2]
|
||||
prev_diff = abs(prev_macd - prev_signal)
|
||||
current_diff = abs(latest_macd - latest_signal)
|
||||
macd_ok = current_diff < prev_diff # Lines are converging
|
||||
else:
|
||||
macd_ok = latest_macd > latest_signal # Traditional golden cross
|
||||
|
||||
# SMA condition
|
||||
if self.sma5_above_sma20:
|
||||
sma_ok = latest_sma5 > latest_sma20
|
||||
else:
|
||||
sma_ok = True # Skip SMA check if disabled
|
||||
|
||||
# Check if enough indicators are positive
|
||||
conditions = [rsi_ok, macd_ok, sma_ok]
|
||||
score = sum(conditions)
|
||||
|
||||
conn.close()
|
||||
return score >= self.required_indicators
|
||||
|
||||
except Exception as e:
|
||||
conn.close()
|
||||
return False
|
||||
|
||||
def select_volatile_stocks(self, fear_start_date, fear_end_date):
|
||||
"""Select stocks using technical indicators, then sort by volatility"""
|
||||
qualified_stocks = []
|
||||
|
||||
for ticker in self.available_tickers:
|
||||
if self.check_technical_indicators(ticker, fear_end_date):
|
||||
vol = self.calculate_volatility(ticker, fear_end_date)
|
||||
|
||||
if vol > self.volatility_threshold:
|
||||
qualified_stocks.append((ticker, vol))
|
||||
|
||||
# Sort by volatility and select top stocks
|
||||
qualified_stocks.sort(key=lambda x: x[1], reverse=True)
|
||||
top_stocks = [ticker for ticker, vol in qualified_stocks[:self.top_stocks_count]]
|
||||
|
||||
return top_stocks
|
||||
|
||||
def execute_trade(self, date, action, ticker=None, shares=None, price=None, value=None):
|
||||
"""Execute and record a trade"""
|
||||
fg_index = self.data.loc[date, 'fear_greed_index'] if date in self.data.index else None
|
||||
total_assets = self.calculate_portfolio_value(date)
|
||||
|
||||
self.trades.append({
|
||||
'date': date,
|
||||
'action': action,
|
||||
'ticker': ticker,
|
||||
'shares': shares,
|
||||
'price': price,
|
||||
'value': value,
|
||||
'fg_index': fg_index,
|
||||
'cnn_fear_greed': fg_index, # Same as fg_index but with clearer name
|
||||
'cash_after': self.cash,
|
||||
'total_assets': total_assets,
|
||||
'portfolio_state': self.state
|
||||
})
|
||||
|
||||
def calculate_portfolio_value(self, date):
|
||||
"""Calculate total portfolio value"""
|
||||
total_value = self.cash
|
||||
|
||||
for ticker, shares in self.positions.items():
|
||||
if ticker == 'QQQ':
|
||||
price = self.data.loc[date, 'qqq_close']
|
||||
else:
|
||||
price = self.get_stock_price(ticker, date)
|
||||
|
||||
if price:
|
||||
total_value += shares * price
|
||||
|
||||
return total_value
|
||||
|
||||
def check_stop_loss(self, date):
|
||||
"""Check stop loss threshold"""
|
||||
for ticker in list(self.positions.keys()):
|
||||
if ticker == 'QQQ':
|
||||
continue
|
||||
|
||||
current_price = self.get_stock_price(ticker, date)
|
||||
if not current_price:
|
||||
continue
|
||||
|
||||
# Find average buy price
|
||||
buy_trades = [t for t in self.trades
|
||||
if t['ticker'] == ticker and t['action'] in ['BUY_GRADUAL']]
|
||||
if buy_trades:
|
||||
total_cost = sum(t['price'] * t['shares'] for t in buy_trades)
|
||||
total_shares = sum(t['shares'] for t in buy_trades)
|
||||
avg_price = total_cost / total_shares
|
||||
|
||||
loss_pct = (current_price - avg_price) / avg_price
|
||||
if loss_pct <= -self.stop_loss_threshold:
|
||||
# Sell and buy QQQ
|
||||
shares = self.positions[ticker]
|
||||
value = shares * current_price
|
||||
self.cash += value
|
||||
del self.positions[ticker]
|
||||
|
||||
self.execute_trade(date, 'STOP_LOSS', ticker, shares, current_price, value)
|
||||
|
||||
# Buy QQQ with integer shares
|
||||
qqq_price = self.data.loc[date, 'qqq_close']
|
||||
qqq_shares = int(value / qqq_price)
|
||||
|
||||
if qqq_shares > 0:
|
||||
actual_qqq_value = qqq_shares * qqq_price
|
||||
self.positions['QQQ'] = self.positions.get('QQQ', 0) + qqq_shares
|
||||
self.cash -= actual_qqq_value
|
||||
self.execute_trade(date, 'BUY_QQQ_STOPLOSS', 'QQQ', qqq_shares, qqq_price, actual_qqq_value)
|
||||
|
||||
if self.config['output_settings']['show_console_output']:
|
||||
print(f"{date.strftime('%Y-%m-%d')}: Stop loss triggered for {ticker}, loss: {loss_pct*100:.1f}%")
|
||||
|
||||
def start_transition(self, date, target_type, stocks=None):
|
||||
"""Initialize transition plan"""
|
||||
self.transition_plan = {'type': target_type, 'stocks': stocks}
|
||||
|
||||
if target_type == 'CASH':
|
||||
self.transition_plan['positions_to_sell'] = {}
|
||||
for ticker in self.positions:
|
||||
self.transition_plan['positions_to_sell'][ticker] = self.positions[ticker]
|
||||
|
||||
elif target_type == 'QQQ':
|
||||
cash_from_positions = 0
|
||||
for ticker in self.positions:
|
||||
if ticker != 'QQQ':
|
||||
price = self.get_stock_price(ticker, date)
|
||||
if price:
|
||||
cash_from_positions += self.positions[ticker] * price
|
||||
|
||||
self.transition_cash_pool = self.cash + cash_from_positions
|
||||
self.transition_plan['total_cash_to_invest'] = self.transition_cash_pool
|
||||
self.transition_plan['positions_to_sell'] = {}
|
||||
for ticker in self.positions:
|
||||
if ticker != 'QQQ':
|
||||
self.transition_plan['positions_to_sell'][ticker] = self.positions[ticker]
|
||||
|
||||
elif target_type == 'VOLATILE' and stocks:
|
||||
cash_from_positions = 0
|
||||
for ticker in self.positions:
|
||||
if ticker != 'QQQ':
|
||||
price = self.get_stock_price(ticker, date)
|
||||
if price:
|
||||
cash_from_positions += self.positions[ticker] * price
|
||||
|
||||
total_available_cash = self.cash + cash_from_positions
|
||||
self.transition_plan['total_cash_to_invest'] = total_available_cash
|
||||
|
||||
def gradual_transition(self, date, target_type, stocks=None):
|
||||
"""Handle gradual transitions with integer shares"""
|
||||
step_size = 1.0 / self.transition_steps
|
||||
|
||||
if target_type == 'CASH':
|
||||
for ticker in list(self.transition_plan.get('positions_to_sell', {})):
|
||||
if ticker in self.positions:
|
||||
total_shares_to_sell = self.transition_plan['positions_to_sell'][ticker]
|
||||
shares_to_sell = int(total_shares_to_sell * step_size)
|
||||
|
||||
if shares_to_sell > 0 and shares_to_sell <= self.positions[ticker]:
|
||||
price = self.get_stock_price(ticker, date)
|
||||
if price:
|
||||
value = shares_to_sell * price
|
||||
self.cash += value
|
||||
self.positions[ticker] -= shares_to_sell
|
||||
if self.positions[ticker] <= 0:
|
||||
del self.positions[ticker]
|
||||
self.execute_trade(date, 'SELL_GRADUAL', ticker, shares_to_sell, price, value)
|
||||
|
||||
elif target_type == 'VOLATILE' and stocks:
|
||||
total_cash = self.transition_plan.get('total_cash_to_invest', 0)
|
||||
cash_this_step = total_cash * step_size
|
||||
|
||||
if cash_this_step > 0 and self.cash >= cash_this_step:
|
||||
current_step_index = min(self.current_step, len(stocks) - 1)
|
||||
ticker = stocks[current_step_index]
|
||||
|
||||
price = self.get_stock_price(ticker, date)
|
||||
if price and cash_this_step > 0:
|
||||
shares = int(cash_this_step / price)
|
||||
if shares > 0:
|
||||
actual_value = shares * price
|
||||
self.positions[ticker] = self.positions.get(ticker, 0) + shares
|
||||
self.cash -= actual_value
|
||||
self.execute_trade(date, 'BUY_GRADUAL', ticker, shares, price, actual_value)
|
||||
|
||||
elif target_type == 'QQQ':
|
||||
# Sell positions gradually
|
||||
for ticker in list(self.transition_plan.get('positions_to_sell', {})):
|
||||
if ticker in self.positions:
|
||||
total_shares_to_sell = self.transition_plan['positions_to_sell'][ticker]
|
||||
shares_to_sell = int(total_shares_to_sell * step_size)
|
||||
|
||||
if shares_to_sell > 0 and shares_to_sell <= self.positions[ticker]:
|
||||
price = self.get_stock_price(ticker, date)
|
||||
if price:
|
||||
value = shares_to_sell * price
|
||||
self.cash += value
|
||||
self.positions[ticker] -= shares_to_sell
|
||||
if self.positions[ticker] <= 0:
|
||||
del self.positions[ticker]
|
||||
self.execute_trade(date, 'SELL_GRADUAL', ticker, shares_to_sell, price, value)
|
||||
|
||||
# Buy QQQ
|
||||
total_cash = self.transition_plan.get('total_cash_to_invest', 0)
|
||||
cash_this_step = total_cash * step_size
|
||||
|
||||
if cash_this_step > 0 and self.cash >= cash_this_step:
|
||||
qqq_price = self.data.loc[date, 'qqq_close']
|
||||
qqq_shares = int(cash_this_step / qqq_price)
|
||||
|
||||
if qqq_shares > 0:
|
||||
actual_value = qqq_shares * qqq_price
|
||||
self.positions['QQQ'] = self.positions.get('QQQ', 0) + qqq_shares
|
||||
self.cash -= actual_value
|
||||
self.execute_trade(date, 'BUY_GRADUAL', 'QQQ', qqq_shares, qqq_price, actual_value)
|
||||
|
||||
def run_backtest(self):
|
||||
"""Run the strategy backtest"""
|
||||
if self.config['output_settings']['show_console_output']:
|
||||
print("🚀 开始运行Enhanced Emotional Damage Strategy...")
|
||||
print("")
|
||||
|
||||
self.get_data()
|
||||
|
||||
# Start with 100% QQQ
|
||||
first_date = self.data.index[0]
|
||||
qqq_price = self.data.loc[first_date, 'qqq_close']
|
||||
qqq_shares = int(self.initial_capital / qqq_price)
|
||||
self.positions['QQQ'] = qqq_shares
|
||||
self.cash = self.initial_capital - (qqq_shares * qqq_price)
|
||||
|
||||
fear_start_date = None
|
||||
|
||||
for date, row in self.data.iterrows():
|
||||
fg_index = row['fear_greed_index']
|
||||
|
||||
# Check stop loss
|
||||
self.check_stop_loss(date)
|
||||
|
||||
if self.state == 'QQQ_HOLD':
|
||||
if fg_index < self.fear_threshold:
|
||||
fear_start_date = date
|
||||
self.state = 'FEAR_TRANSITION'
|
||||
self.current_step = 0
|
||||
self.start_transition(date, 'CASH')
|
||||
if self.config['output_settings']['show_console_output']:
|
||||
print(f"{date.strftime('%Y-%m-%d')}: Fear threshold hit ({fg_index:.1f}), starting transition to cash")
|
||||
|
||||
elif self.state == 'FEAR_TRANSITION':
|
||||
self.gradual_transition(date, 'CASH')
|
||||
self.current_step += 1
|
||||
|
||||
if self.current_step >= self.transition_steps:
|
||||
self.state = 'CASH_WAIT'
|
||||
if self.config['output_settings']['show_console_output']:
|
||||
print(f"{date.strftime('%Y-%m-%d')}: Transition to cash complete")
|
||||
|
||||
elif self.state == 'CASH_WAIT':
|
||||
if fg_index >= self.fear_threshold and fear_start_date:
|
||||
top_stocks = self.select_volatile_stocks(fear_start_date, date)
|
||||
|
||||
if top_stocks:
|
||||
self.state = 'GREED_TRANSITION'
|
||||
self.current_step = 0
|
||||
self.transition_stocks = top_stocks
|
||||
self.start_transition(date, 'VOLATILE', top_stocks)
|
||||
if self.config['output_settings']['show_console_output']:
|
||||
print(f"{date.strftime('%Y-%m-%d')}: Fear recovered, starting transition to volatile stocks: {top_stocks}")
|
||||
else:
|
||||
self.state = 'QQQ_TRANSITION'
|
||||
self.current_step = 0
|
||||
self.start_transition(date, 'QQQ')
|
||||
if self.config['output_settings']['show_console_output']:
|
||||
print(f"{date.strftime('%Y-%m-%d')}: Fear recovered, no suitable stocks, returning to QQQ")
|
||||
|
||||
elif self.state == 'GREED_TRANSITION':
|
||||
self.gradual_transition(date, 'VOLATILE', self.transition_stocks)
|
||||
self.current_step += 1
|
||||
|
||||
if self.current_step >= self.transition_steps:
|
||||
self.state = 'VOLATILE_STOCKS'
|
||||
if self.config['output_settings']['show_console_output']:
|
||||
print(f"{date.strftime('%Y-%m-%d')}: Transition to volatile stocks complete")
|
||||
|
||||
elif self.state == 'VOLATILE_STOCKS':
|
||||
if fg_index > self.greed_threshold:
|
||||
self.state = 'QQQ_TRANSITION'
|
||||
self.current_step = 0
|
||||
self.start_transition(date, 'QQQ')
|
||||
if self.config['output_settings']['show_console_output']:
|
||||
print(f"{date.strftime('%Y-%m-%d')}: Greed threshold hit ({fg_index:.1f}), starting transition to QQQ")
|
||||
|
||||
elif self.state == 'QQQ_TRANSITION':
|
||||
self.gradual_transition(date, 'QQQ')
|
||||
self.current_step += 1
|
||||
|
||||
if self.current_step >= self.transition_steps:
|
||||
self.state = 'QQQ_HOLD'
|
||||
if self.config['output_settings']['show_console_output']:
|
||||
print(f"{date.strftime('%Y-%m-%d')}: Transition to QQQ complete")
|
||||
|
||||
# Record portfolio value
|
||||
portfolio_value = self.calculate_portfolio_value(date)
|
||||
self.portfolio_value.append({
|
||||
'date': date,
|
||||
'value': portfolio_value,
|
||||
'state': self.state,
|
||||
'fg_index': fg_index
|
||||
})
|
||||
|
||||
if self.config['output_settings']['show_console_output']:
|
||||
print("")
|
||||
print(f"✅ 回测完成! 总交易数: {len(self.trades)}")
|
||||
print("")
|
||||
|
||||
def calculate_performance_metrics(self, returns):
|
||||
"""Calculate performance metrics"""
|
||||
total_return = (returns.iloc[-1] / returns.iloc[0] - 1) * 100
|
||||
annual_return = ((returns.iloc[-1] / returns.iloc[0]) ** (252 / len(returns)) - 1) * 100
|
||||
|
||||
# Calculate max drawdown
|
||||
peak = returns.expanding().max()
|
||||
drawdown = (returns - peak) / peak
|
||||
max_drawdown = drawdown.min() * 100
|
||||
|
||||
# Calculate Sharpe ratio
|
||||
daily_returns = returns.pct_change().dropna()
|
||||
sharpe_ratio = np.sqrt(252) * daily_returns.mean() / daily_returns.std()
|
||||
|
||||
# Annual returns by year
|
||||
annual_rets = {}
|
||||
for year in returns.index.year.unique():
|
||||
year_data = returns[returns.index.year == year]
|
||||
if len(year_data) > 1:
|
||||
year_return = (year_data.iloc[-1] / year_data.iloc[0] - 1) * 100
|
||||
annual_rets[year] = year_return
|
||||
|
||||
return {
|
||||
'total_return': total_return,
|
||||
'annual_return': annual_return,
|
||||
'max_drawdown': max_drawdown,
|
||||
'sharpe_ratio': sharpe_ratio,
|
||||
'annual_returns': annual_rets
|
||||
}
|
||||
|
||||
|
||||
def generate_reports(strategy):
|
||||
"""Generate all reports based on config settings"""
|
||||
config = strategy.config
|
||||
|
||||
# Create output directories
|
||||
reports_dir = config['paths']['reports_dir']
|
||||
results_dir = config['paths']['results_dir']
|
||||
os.makedirs(reports_dir, exist_ok=True)
|
||||
os.makedirs(results_dir, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
|
||||
# Save strategy object if requested
|
||||
if config['output_settings']['save_strategy_object']:
|
||||
strategy_file = os.path.join(results_dir, f'enhanced_strategy_{timestamp}.pkl')
|
||||
with open(strategy_file, 'wb') as f:
|
||||
pickle.dump(strategy, f)
|
||||
print(f"📦 策略对象已保存: {strategy_file}")
|
||||
|
||||
# Generate CSV reports
|
||||
if config['output_settings']['generate_csv'] or config['output_settings']['generate_detailed_trades']:
|
||||
trades_df = pd.DataFrame(strategy.trades)
|
||||
|
||||
if config['output_settings']['generate_csv']:
|
||||
csv_file = os.path.join(reports_dir, f'enhanced_trades_{timestamp}.csv')
|
||||
trades_df.to_csv(csv_file, index=False)
|
||||
print(f"📊 交易CSV已保存: {csv_file}")
|
||||
|
||||
if config['output_settings']['generate_detailed_trades']:
|
||||
detailed_file = os.path.join(reports_dir, f'detailed_trades_{timestamp}.txt')
|
||||
with open(detailed_file, 'w', encoding='utf-8') as f:
|
||||
f.write("Enhanced Emotional Damage Strategy - Detailed Trades Report\n")
|
||||
f.write("=" * 120 + "\n\n")
|
||||
|
||||
# Summary
|
||||
f.write(f"📊 交易摘要:\n")
|
||||
f.write(f"总交易数: {len(strategy.trades)}\n")
|
||||
f.write(f"交易时间: {trades_df['date'].min().strftime('%Y-%m-%d')} 到 {trades_df['date'].max().strftime('%Y-%m-%d')}\n")
|
||||
|
||||
# Trade types
|
||||
action_counts = trades_df['action'].value_counts()
|
||||
f.write(f"\n交易类型统计:\n")
|
||||
for action, count in action_counts.items():
|
||||
f.write(f" {action}: {count}\n")
|
||||
|
||||
# Detailed trades
|
||||
f.write(f"\n📋 详细交易记录:\n")
|
||||
f.write("=" * 150 + "\n")
|
||||
f.write(f"{'No':>3s} {'Date':>10s} {'Action':>15s} {'Ticker':>5s} {'Shares':>8s} {'Price':>7s} {'Value':>12s} {'F&G':>4s} {'Cash':>12s} {'Assets':>12s} {'State':>15s}\n")
|
||||
f.write("=" * 150 + "\n")
|
||||
|
||||
for i, trade in enumerate(strategy.trades, 1):
|
||||
fg_str = f"{trade.get('fg_index', 0):.0f}" if trade.get('fg_index') else "N/A"
|
||||
cash_str = f"${trade.get('cash_after', 0):,.0f}" if trade.get('cash_after') else "N/A"
|
||||
assets_str = f"${trade.get('total_assets', 0):,.0f}" if trade.get('total_assets') else "N/A"
|
||||
state_str = trade.get('portfolio_state', 'N/A')
|
||||
f.write(f"{i:3d} {trade['date'].strftime('%Y-%m-%d'):>10s} {trade['action']:>15s} {trade['ticker']:>5s} "
|
||||
f"{trade['shares']:>8.0f} ${trade['price']:>7.2f} ${trade['value']:>12,.0f} {fg_str:>4s} {cash_str:>12s} {assets_str:>12s} {state_str:>15s}\n")
|
||||
|
||||
print(f"📝 详细交易报告已保存: {detailed_file}")
|
||||
|
||||
# Generate PDF report
|
||||
if config['output_settings']['generate_pdf']:
|
||||
try:
|
||||
generate_enhanced_pdf_report(strategy, reports_dir, timestamp)
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ PDF生成失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n🎉 所有报告生成完成!")
|
||||
|
||||
|
||||
def generate_enhanced_pdf_report(strategy, reports_dir, timestamp):
|
||||
"""Generate comprehensive PDF report with enhanced layout and proper spacing"""
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.dates as mdates
|
||||
from matplotlib.backends.backend_pdf import PdfPages
|
||||
import seaborn as sns
|
||||
import sqlite3
|
||||
import os
|
||||
|
||||
# Prepare data
|
||||
portfolio_df = pd.DataFrame(strategy.portfolio_value)
|
||||
portfolio_df.set_index('date', inplace=True)
|
||||
trades_df = pd.DataFrame(strategy.trades)
|
||||
|
||||
# Get benchmark data
|
||||
db_path = strategy.config['paths']['database_path']
|
||||
conn = sqlite3.connect(db_path)
|
||||
|
||||
qqq_data = pd.read_sql_query('''
|
||||
SELECT date, close as qqq_close
|
||||
FROM qqq
|
||||
ORDER BY date
|
||||
''', conn)
|
||||
qqq_data['date'] = pd.to_datetime(qqq_data['date'])
|
||||
qqq_data.set_index('date', inplace=True)
|
||||
|
||||
spy_data = pd.read_sql_query('''
|
||||
SELECT date, spy_close
|
||||
FROM fear_greed_data
|
||||
ORDER BY date
|
||||
''', conn)
|
||||
spy_data['date'] = pd.to_datetime(spy_data['date'])
|
||||
spy_data.set_index('date', inplace=True)
|
||||
|
||||
conn.close()
|
||||
|
||||
# Merge and align data
|
||||
benchmark_data = pd.merge(qqq_data, spy_data, left_index=True, right_index=True, how='inner')
|
||||
common_dates = portfolio_df.index.intersection(benchmark_data.index)
|
||||
portfolio_df = portfolio_df.loc[common_dates]
|
||||
benchmark_data = benchmark_data.loc[common_dates]
|
||||
|
||||
# Normalize benchmarks
|
||||
start_value = strategy.initial_capital
|
||||
benchmark_data['qqq_value'] = start_value * (benchmark_data['qqq_close'] / benchmark_data['qqq_close'].iloc[0])
|
||||
benchmark_data['spy_value'] = start_value * (benchmark_data['spy_close'] / benchmark_data['spy_close'].iloc[0])
|
||||
|
||||
# Calculate metrics
|
||||
strategy_metrics = strategy.calculate_performance_metrics(portfolio_df['value'])
|
||||
qqq_metrics = strategy.calculate_performance_metrics(benchmark_data['qqq_value'])
|
||||
spy_metrics = strategy.calculate_performance_metrics(benchmark_data['spy_value'])
|
||||
|
||||
# Find max drawdown year
|
||||
def find_max_drawdown_year(returns):
|
||||
peak = returns.expanding().max()
|
||||
drawdown = (returns - peak) / peak
|
||||
max_dd_date = drawdown.idxmin()
|
||||
return max_dd_date.year
|
||||
|
||||
strategy_dd_year = find_max_drawdown_year(portfolio_df['value'])
|
||||
qqq_dd_year = find_max_drawdown_year(benchmark_data['qqq_value'])
|
||||
spy_dd_year = find_max_drawdown_year(benchmark_data['spy_value'])
|
||||
|
||||
# Create PDF with multiple pages
|
||||
pdf_file = os.path.join(reports_dir, f'enhanced_strategy_report_{timestamp}.pdf')
|
||||
|
||||
with PdfPages(pdf_file) as pdf:
|
||||
# Set global font to support text and better spacing
|
||||
plt.rcParams['font.size'] = 10
|
||||
plt.rcParams['axes.titlesize'] = 12
|
||||
plt.rcParams['axes.labelsize'] = 10
|
||||
plt.rcParams['xtick.labelsize'] = 8
|
||||
plt.rcParams['ytick.labelsize'] = 8
|
||||
plt.rcParams['legend.fontsize'] = 8
|
||||
plt.rcParams['figure.titlesize'] = 14
|
||||
|
||||
# Page 1: Performance Comparison (Full Width)
|
||||
fig1 = plt.figure(figsize=(8.5, 11))
|
||||
fig1.suptitle('Enhanced Emotional Damage Strategy Report', fontsize=16, fontweight='bold', y=0.96)
|
||||
|
||||
# 1. Total Return Curve (Full width)
|
||||
ax1 = plt.subplot(4, 1, 1)
|
||||
ax1.plot(portfolio_df.index, portfolio_df['value'] / 1000,
|
||||
label='Enhanced Strategy', linewidth=2, color='red')
|
||||
ax1.plot(benchmark_data.index, benchmark_data['qqq_value'] / 1000,
|
||||
label='QQQ', linewidth=2, color='blue')
|
||||
ax1.plot(benchmark_data.index, benchmark_data['spy_value'] / 1000,
|
||||
label='SPY', linewidth=2, color='green')
|
||||
ax1.set_title('Portfolio Performance Comparison', fontsize=14, fontweight='bold', pad=25)
|
||||
ax1.set_ylabel('Portfolio Value ($K)', fontsize=11)
|
||||
ax1.legend(fontsize=10, loc='upper left')
|
||||
ax1.grid(True, alpha=0.3)
|
||||
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
|
||||
for label in ax1.get_xticklabels():
|
||||
label.set_rotation(45)
|
||||
|
||||
# 2. Performance Metrics Table (Full width)
|
||||
ax2 = plt.subplot(4, 1, 2)
|
||||
ax2.axis('off')
|
||||
|
||||
metrics_data = [
|
||||
['Metric', 'Enhanced Strategy', 'QQQ', 'SPY'],
|
||||
['Total Return', f"{strategy_metrics['total_return']:.1f}%",
|
||||
f"{qqq_metrics['total_return']:.1f}%", f"{spy_metrics['total_return']:.1f}%"],
|
||||
['Annual Return', f"{strategy_metrics['annual_return']:.1f}%",
|
||||
f"{qqq_metrics['annual_return']:.1f}%", f"{spy_metrics['annual_return']:.1f}%"],
|
||||
['Max Drawdown', f"{strategy_metrics['max_drawdown']:.1f}%",
|
||||
f"{qqq_metrics['max_drawdown']:.1f}%", f"{spy_metrics['max_drawdown']:.1f}%"],
|
||||
['Max DD Year', str(strategy_dd_year), str(qqq_dd_year), str(spy_dd_year)],
|
||||
['Sharpe Ratio', f"{strategy_metrics['sharpe_ratio']:.2f}",
|
||||
f"{qqq_metrics['sharpe_ratio']:.2f}", f"{spy_metrics['sharpe_ratio']:.2f}"],
|
||||
['Total Trades', f"{len(strategy.trades)}", 'N/A', 'N/A']
|
||||
]
|
||||
|
||||
table = ax2.table(cellText=metrics_data, cellLoc='center', loc='center')
|
||||
table.auto_set_font_size(False)
|
||||
table.set_fontsize(10)
|
||||
table.scale(1.2, 2.0) # More height for readability
|
||||
|
||||
for i in range(len(metrics_data[0])):
|
||||
table[(0, i)].set_facecolor('#40466e')
|
||||
table[(0, i)].set_text_props(weight='bold', color='white')
|
||||
|
||||
ax2.set_title('Performance Metrics Comparison', fontsize=14, fontweight='bold', pad=25)
|
||||
|
||||
# 3. Strategy State Timeline (Full width)
|
||||
ax3 = plt.subplot(4, 1, 3)
|
||||
|
||||
# Create state mapping and colors
|
||||
state_colors = {
|
||||
'QQQ_HOLD': 'blue',
|
||||
'FEAR_TRANSITION': 'orange',
|
||||
'CASH_WAIT': 'gray',
|
||||
'GREED_TRANSITION': 'yellow',
|
||||
'VOLATILE_STOCKS': 'red',
|
||||
'QQQ_TRANSITION': 'green'
|
||||
}
|
||||
|
||||
# Plot portfolio value with state colors
|
||||
for state, color in state_colors.items():
|
||||
state_data = portfolio_df[portfolio_df['state'] == state]
|
||||
if not state_data.empty:
|
||||
ax3.scatter(state_data.index, state_data['value'] / 1000,
|
||||
c=color, s=2, alpha=0.8, label=state)
|
||||
|
||||
# Add stop-loss markers
|
||||
stop_loss_trades = trades_df[trades_df['action'] == 'STOP_LOSS']
|
||||
if not stop_loss_trades.empty:
|
||||
for _, trade in stop_loss_trades.iterrows():
|
||||
ax3.axvline(x=trade['date'], color='red', linestyle='--', alpha=0.8, linewidth=1)
|
||||
|
||||
ax3.set_title('Strategy State Timeline with Stop-Loss Events', fontsize=14, fontweight='bold', pad=25)
|
||||
ax3.set_ylabel('Total Assets ($K)', fontsize=11)
|
||||
ax3.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=8)
|
||||
ax3.grid(True, alpha=0.3)
|
||||
ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
|
||||
for label in ax3.get_xticklabels():
|
||||
label.set_rotation(45)
|
||||
|
||||
# 4. Annual Returns Comparison (Full width)
|
||||
ax4 = plt.subplot(4, 1, 4)
|
||||
years = list(strategy_metrics['annual_returns'].keys())
|
||||
enhanced_returns = list(strategy_metrics['annual_returns'].values())
|
||||
qqq_returns = [qqq_metrics['annual_returns'].get(year, 0) for year in years]
|
||||
spy_returns = [spy_metrics['annual_returns'].get(year, 0) for year in years]
|
||||
|
||||
x = np.arange(len(years))
|
||||
width = 0.25
|
||||
|
||||
ax4.bar(x - width, enhanced_returns, width, label='Enhanced Strategy', color='red', alpha=0.8)
|
||||
ax4.bar(x, qqq_returns, width, label='QQQ', color='blue', alpha=0.8)
|
||||
ax4.bar(x + width, spy_returns, width, label='SPY', color='green', alpha=0.8)
|
||||
|
||||
ax4.set_title('Annual Returns Comparison by Year', fontsize=14, fontweight='bold', pad=25)
|
||||
ax4.set_ylabel('Annual Return (%)', fontsize=11)
|
||||
ax4.set_xlabel('Year', fontsize=11)
|
||||
ax4.set_xticks(x)
|
||||
ax4.set_xticklabels(years)
|
||||
for label in ax4.get_xticklabels():
|
||||
label.set_rotation(90) # Vertical text for years
|
||||
ax4.legend(fontsize=10)
|
||||
ax4.grid(True, alpha=0.3, axis='y')
|
||||
|
||||
plt.subplots_adjust(left=0.1, right=0.85, top=0.90, bottom=0.08, hspace=0.6)
|
||||
pdf.savefig(fig1, bbox_inches='tight', dpi=150)
|
||||
plt.close()
|
||||
|
||||
# Page 2: Detailed Analysis
|
||||
fig2 = plt.figure(figsize=(8.5, 11))
|
||||
fig2.suptitle('Detailed Trading and Risk Analysis', fontsize=16, fontweight='bold', y=0.95)
|
||||
|
||||
# 5. Stop-Loss Analysis (Full width)
|
||||
ax5 = plt.subplot(3, 1, 1)
|
||||
|
||||
if not stop_loss_trades.empty:
|
||||
stop_loss_trades_copy = stop_loss_trades.copy()
|
||||
stop_loss_trades_copy['year'] = stop_loss_trades_copy['date'].dt.year
|
||||
stop_loss_by_year = stop_loss_trades_copy.groupby('year').size()
|
||||
|
||||
ax5.bar(stop_loss_by_year.index, stop_loss_by_year.values, color='red', alpha=0.8, width=0.6)
|
||||
ax5.set_title('Stop-Loss Triggers by Year', fontsize=14, fontweight='bold', pad=30)
|
||||
ax5.set_ylabel('Number of Stop-Loss Events', fontsize=11)
|
||||
ax5.set_xlabel('Year', fontsize=11)
|
||||
ax5.grid(True, alpha=0.3, axis='y')
|
||||
for label in ax5.get_xticklabels():
|
||||
label.set_rotation(45)
|
||||
else:
|
||||
ax5.text(0.5, 0.5, 'No Stop-Loss Events Triggered', ha='center', va='center',
|
||||
transform=ax5.transAxes, fontsize=14, fontweight='bold')
|
||||
ax5.set_title('Stop-Loss Analysis', fontsize=14, fontweight='bold', pad=30)
|
||||
|
||||
# 6. Trade Frequency Analysis (Full width)
|
||||
ax6 = plt.subplot(3, 1, 2)
|
||||
|
||||
trades_df_copy = trades_df.copy()
|
||||
trades_df_copy['year'] = trades_df_copy['date'].dt.year
|
||||
trade_frequency = trades_df_copy.groupby('year').size()
|
||||
|
||||
ax6.bar(trade_frequency.index, trade_frequency.values, color='purple', alpha=0.8, width=0.6)
|
||||
ax6.set_title('Trading Activity by Year', fontsize=14, fontweight='bold', pad=30)
|
||||
ax6.set_ylabel('Number of Trades', fontsize=11)
|
||||
ax6.set_xlabel('Year', fontsize=11)
|
||||
ax6.grid(True, alpha=0.3, axis='y')
|
||||
for label in ax6.get_xticklabels():
|
||||
label.set_rotation(45)
|
||||
|
||||
# 7. Fear & Greed Index with Trading Signals (Full width)
|
||||
ax7 = plt.subplot(3, 1, 3)
|
||||
|
||||
# Plot Fear & Greed Index
|
||||
fg_data = portfolio_df['fg_index'].dropna()
|
||||
ax7.plot(fg_data.index, fg_data.values, color='purple', alpha=0.8, linewidth=1.5)
|
||||
ax7.axhline(y=25, color='red', linestyle='--', alpha=0.7, linewidth=2, label='Fear Threshold (25)')
|
||||
ax7.axhline(y=75, color='green', linestyle='--', alpha=0.7, linewidth=2, label='Greed Threshold (75)')
|
||||
ax7.fill_between(fg_data.index, 0, 25, alpha=0.2, color='red', label='Fear Zone')
|
||||
ax7.fill_between(fg_data.index, 75, 100, alpha=0.2, color='green', label='Greed Zone')
|
||||
|
||||
# Add trade markers
|
||||
buy_trades = trades_df[trades_df['action'].str.contains('BUY')]
|
||||
sell_trades = trades_df[trades_df['action'].str.contains('SELL')]
|
||||
|
||||
if not buy_trades.empty:
|
||||
ax7.scatter(buy_trades['date'], buy_trades['fg_index'],
|
||||
color='darkgreen', s=15, alpha=0.8, marker='^', label='Buy Signals', zorder=5)
|
||||
if not sell_trades.empty:
|
||||
ax7.scatter(sell_trades['date'], sell_trades['fg_index'],
|
||||
color='darkred', s=15, alpha=0.8, marker='v', label='Sell Signals', zorder=5)
|
||||
|
||||
ax7.set_title('Fear & Greed Index with Trading Signals', fontsize=14, fontweight='bold', pad=30)
|
||||
ax7.set_ylabel('CNN Fear & Greed Index', fontsize=11)
|
||||
ax7.set_xlabel('Date', fontsize=11)
|
||||
ax7.set_ylim(0, 100)
|
||||
ax7.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=8)
|
||||
ax7.grid(True, alpha=0.3)
|
||||
ax7.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
|
||||
for label in ax7.get_xticklabels():
|
||||
label.set_rotation(45)
|
||||
|
||||
plt.subplots_adjust(left=0.1, right=0.85, top=0.88, bottom=0.10, hspace=1.0)
|
||||
pdf.savefig(fig2, bbox_inches='tight', dpi=150)
|
||||
plt.close()
|
||||
|
||||
# Test PDF readability by checking file size and trying to read it
|
||||
try:
|
||||
file_size = os.path.getsize(pdf_file)
|
||||
if file_size < 50000: # Less than 50KB might indicate issues
|
||||
print(f"⚠️ Warning: PDF file size seems small ({file_size} bytes)")
|
||||
else:
|
||||
print(f"📈 PDF报告已保存: {pdf_file} (Size: {file_size:,} bytes)")
|
||||
|
||||
# Quick validation - ensure we can open the file
|
||||
with open(pdf_file, 'rb') as f:
|
||||
header = f.read(10)
|
||||
if not header.startswith(b'%PDF'):
|
||||
print(f"⚠️ Warning: Generated file may not be a valid PDF")
|
||||
else:
|
||||
print(f"✅ PDF file validation passed")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Error validating PDF: {e}")
|
||||
|
||||
print(f"📈 PDF报告已保存: {pdf_file}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to run strategy with config"""
|
||||
config_file = 'config.json'
|
||||
|
||||
if not os.path.exists(config_file):
|
||||
print(f"❌ 配置文件未找到: {config_file}")
|
||||
print("请确保config.json文件存在于当前目录")
|
||||
return
|
||||
|
||||
try:
|
||||
# Initialize and run strategy
|
||||
strategy = ConfigurableEmotionalDamageStrategy(config_file)
|
||||
strategy.run_backtest()
|
||||
|
||||
# Generate reports
|
||||
generate_reports(strategy)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 运行失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user