Files
docker-configs/backtest/calculate_indicators.py
2025-07-18 00:00:01 -05:00

157 lines
4.9 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sqlite3
from contextlib import closing

import numpy as np
import pandas as pd
from tqdm import tqdm
def calculate_sma(data, window):
    """Return the simple moving average of *data* over *window* observations.

    The first ``window - 1`` results are NaN: with an integer window, pandas'
    default ``min_periods`` equals the window size, so a value is only
    produced once the window is full.
    """
    roller = data.rolling(window=window)
    return roller.mean()
def calculate_rsi(data, window=14):
    """Return the Relative Strength Index (0-100) of *data*.

    Average gain and average loss are plain rolling means over *window*
    period-to-period changes (not Wilder's exponential smoothing).

    Notes:
        - The leading NaN produced by ``diff()`` is treated as a zero change,
          so the first non-NaN RSI appears at position ``window - 1``.
        - A window with no losses yields 100; a window with no change at all
          (zero gain and zero loss) yields NaN (0 / 0).
    """
    change = data.diff()
    ups = change.where(change > 0, 0)
    downs = -change.where(change < 0, 0)
    avg_gain = ups.rolling(window=window).mean()
    avg_loss = downs.rolling(window=window).mean()
    relative_strength = avg_gain / avg_loss
    return 100 - (100 / (1 + relative_strength))
def calculate_bollinger_bands(data, window=20, num_std=2):
    """Return ``(upper, middle, lower)`` Bollinger Bands for *data*.

    ``middle`` is the *window*-period simple moving average; ``upper`` and
    ``lower`` lie *num_std* rolling standard deviations (pandas default,
    sample std with ddof=1) above and below it.
    """
    rolling = data.rolling(window=window)
    middle = rolling.mean()
    band_width = num_std * rolling.std()
    return middle + band_width, middle, middle - band_width
def calculate_macd(data, fast=12, slow=26, signal=9):
    """Return ``(macd_line, signal_line, histogram)`` for *data*.

    Every exponential average is ``Series.ewm(span=...).mean()`` with pandas'
    default weighting (``adjust=True``, since no ``adjust`` is passed).
    ``macd_line`` is EMA(fast) - EMA(slow), ``signal_line`` is the EMA of the
    MACD line over *signal* periods, and ``histogram`` is their difference.
    """
    def ema(series, span):
        return series.ewm(span=span).mean()

    macd_line = ema(data, fast) - ema(data, slow)
    signal_line = ema(macd_line, signal)
    return macd_line, signal_line, macd_line - signal_line
def calculate_volatility(data, window=20):
    """Return the annualised rolling volatility of *data*.

    Volatility is the rolling standard deviation (window *window*) of simple
    period-over-period returns, scaled by sqrt(252) to annualise it
    (252 periods per year, i.e. assumes daily bars).
    """
    annualisation = np.sqrt(252)
    period_returns = data.pct_change()
    return period_returns.rolling(window=window).std() * annualisation
# Columns of a rewritten ticker table, in table order.
_TABLE_COLUMNS = (
    "date", "open", "high", "low", "close", "volume",
    "sma_5", "sma_20", "sma_200",
    "rsi", "bb_upper", "bb_middle", "bb_lower",
    "macd", "macd_signal", "macd_histogram",
    "volatility",
)


def _quote_ident(ticker):
    """Return the lower-cased *ticker* as a double-quoted SQLite identifier.

    Quoting lets tickers containing ``.``/``-`` or matching SQL keywords be
    used as table names; for ordinary names it refers to the same table as
    the unquoted lower-case name. Embedded double quotes are doubled.
    """
    return '"' + ticker.lower().replace('"', '""') + '"'


def _compute_indicators(df):
    """Add every indicator column to *df* in place, derived from ``df['close']``."""
    close = df["close"]
    df["sma_5"] = calculate_sma(close, 5)
    df["sma_20"] = calculate_sma(close, 20)
    df["sma_200"] = calculate_sma(close, 200)
    df["rsi"] = calculate_rsi(close)
    df["bb_upper"], df["bb_middle"], df["bb_lower"] = calculate_bollinger_bands(close)
    df["macd"], df["macd_signal"], df["macd_histogram"] = calculate_macd(close)
    df["volatility"] = calculate_volatility(close)
    return df


def _replace_table(conn, table, df):
    """Drop and recreate *table* holding the rows of *df*, in one transaction."""
    # Select the output columns before touching the database so that a missing
    # column raises KeyError while the old table is still intact.
    frame = df[list(_TABLE_COLUMNS)].assign(
        date=df["date"].dt.strftime("%Y-%m-%d")
    )
    # itertuples yields plain Python scalars, which sqlite3 can bind directly;
    # NaN floats are stored by SQLite as NULL.
    records = list(frame.itertuples(index=False, name=None))
    column_list = ", ".join(_TABLE_COLUMNS)
    placeholders = ", ".join(["?"] * len(_TABLE_COLUMNS))
    create_sql = f"""
        CREATE TABLE {table} (
            date DATE PRIMARY KEY,
            open REAL,
            high REAL,
            low REAL,
            close REAL,
            volume INTEGER,
            sma_5 REAL,
            sma_20 REAL,
            sma_200 REAL,
            rsi REAL,
            bb_upper REAL,
            bb_middle REAL,
            bb_lower REAL,
            macd REAL,
            macd_signal REAL,
            macd_histogram REAL,
            volatility REAL
        )
    """
    insert_sql = f"INSERT INTO {table} ({column_list}) VALUES ({placeholders})"
    # Python's sqlite3 module does not implicitly open a transaction before DDL,
    # so without an explicit BEGIN the DROP/CREATE would be committed at once
    # and a failure during the INSERTs would leave the ticker's table empty.
    # ``with conn`` commits on success and rolls back on any exception.
    with conn:
        conn.execute("BEGIN")
        conn.execute(f"DROP TABLE IF EXISTS {table}")
        conn.execute(create_sql)
        conn.executemany(insert_sql, records)


def add_indicators_to_ticker(ticker, *, db_path="data/stock_data.db"):
    """Recompute technical indicators for one ticker and rewrite its table.

    Reads the ticker's table ordered by ``date``. From ``close`` it computes
    the following indicators:

    - SMA(5/20/200);
    - RSI;
    - Bollinger Bands;
    - MACD;
    - annualised volatility.

    It then replaces the table with the date/OHLCV columns plus the
    indicators. Any other columns in the old table are discarded. The
    replacement runs in a single transaction, so on failure the previous
    table is left unchanged.

    Args:
        ticker: Ticker symbol; its lower-cased form is the table name.
        db_path: SQLite database file (defaults to ``data/stock_data.db``).

    Returns:
        The number of rows written.

    Raises:
        KeyError: if the table lacks one of the required date/OHLCV columns.
        pandas.errors.DatabaseError: if the table cannot be read.
        sqlite3.Error: if rewriting the table fails (changes are rolled back).
    """
    table = _quote_ident(ticker)
    with closing(sqlite3.connect(db_path)) as conn:
        df = pd.read_sql_query(f"SELECT * FROM {table} ORDER BY date", conn)
        df["date"] = pd.to_datetime(df["date"])
        _compute_indicators(df)
        _replace_table(conn, table, df)
    return len(df)
def main():
    """Compute indicators for every ticker in ``ticker_list``, then show a sample.

    Tickers are read from the ``ticker_list`` table of ``data/stock_data.db``.
    Each one is processed with ``add_indicators_to_ticker``. A failure is
    reported and the loop continues with the next ticker.

    Finally, the five most recent AAPL rows are printed as a sanity check.
    If the ``aapl`` table or its indicator columns cannot be read, the check
    is skipped with a message instead of crashing.
    """
    print("开始为所有ticker计算技术指标...")

    # Fetch the ticker symbols to process.
    with closing(sqlite3.connect('data/stock_data.db')) as conn:
        cursor = conn.execute('SELECT ticker FROM ticker_list ORDER BY ticker')
        tickers = [row[0] for row in cursor.fetchall()]
    print(f"找到 {len(tickers)} 个tickers")

    total_records = 0
    for ticker in tqdm(tickers, desc="计算技术指标"):
        try:
            records = add_indicators_to_ticker(ticker)
        except Exception as e:
            # Per-ticker boundary: report the failure and keep going.
            # tqdm.write keeps the progress bar intact (plain print garbles it).
            tqdm.write(f" {ticker}: Error - {e}")
            continue
        total_records += records
        tqdm.write(f" {ticker}: {records} records with indicators")

    print(f"\n完成!总共处理了 {total_records} 条记录")

    # Sanity check: show the latest AAPL rows together with their indicators.
    print("\n验证指标计算结果...")
    with closing(sqlite3.connect('data/stock_data.db')) as conn:
        try:
            df_sample = pd.read_sql_query('''
            SELECT date, close, sma_20, rsi, bb_upper, bb_lower, macd, volatility
            FROM aapl
            ORDER BY date DESC
            LIMIT 5
            ''', conn)
        except pd.errors.DatabaseError as e:
            # e.g. AAPL missing from ticker_list or its rewrite failed.
            print(f"无法读取AAPL指标数据: {e}")
            return
    print("AAPL最新5天数据包含指标:")
    print(df_sample.to_string(index=False))
if __name__ == "__main__":
    # Run the batch job only when executed as a script, not when imported.
    main()