157 lines
4.9 KiB
Python
Executable File
157 lines
4.9 KiB
Python
Executable File
import sqlite3
|
||
import pandas as pd
|
||
import numpy as np
|
||
from tqdm import tqdm
|
||
|
||
def calculate_sma(data, window):
    """Return the simple moving average of ``data``.

    Args:
        data: Object exposing ``.rolling`` (used here with a pandas Series).
        window: Number of consecutive observations in each average.

    Returns:
        The rolling mean of ``data``. Positions that do not yet have
        ``window`` observations are NaN (pandas' default ``min_periods``).
    """
    rolling_view = data.rolling(window=window)
    return rolling_view.mean()
|
||
|
||
def calculate_rsi(data, window=14):
    """Return the Relative Strength Index based on simple rolling means.

    Args:
        data: Price series (anything exposing ``.diff`` / ``.rolling``).
        window: Number of observations averaged for gains and losses.

    Returns:
        ``100 - 100 / (1 + avg_gain / avg_loss)`` for every position.
        A window with gains but no losses yields 100 (division by zero
        produces ``inf``); a window with neither gains nor losses yields NaN.
    """
    change = data.diff()

    # Split each move into its upward and downward part. Entries that fail
    # the comparison -- including the leading NaN from diff() -- become 0,
    # so the first row counts as a "no move" observation.
    upward = change.where(change > 0, 0)
    downward = -change.where(change < 0, 0)

    avg_gain = upward.rolling(window=window).mean()
    avg_loss = downward.rolling(window=window).mean()

    relative_strength = avg_gain / avg_loss
    return 100 - (100 / (1 + relative_strength))
|
||
|
||
def calculate_bollinger_bands(data, window=20, num_std=2):
    """Return the Bollinger Bands of ``data``.

    Args:
        data: Price series.
        window: Look-back length for both the average and the deviation.
        num_std: How many rolling standard deviations the outer bands sit
            away from the middle band.

    Returns:
        A ``(upper_band, middle_band, lower_band)`` tuple, where the middle
        band is the simple moving average from ``calculate_sma``.
    """
    middle_band = calculate_sma(data, window)
    # Distance from the middle band to either outer band.
    band_offset = data.rolling(window=window).std() * num_std
    return middle_band + band_offset, middle_band, middle_band - band_offset
|
||
|
||
def calculate_macd(data, fast=12, slow=26, signal=9):
    """Return the MACD line, its signal line and the histogram.

    Args:
        data: Price series (anything exposing ``.ewm``).
        fast: Span of the fast exponential moving average.
        slow: Span of the slow exponential moving average.
        signal: Span of the EMA applied to the MACD line.

    Returns:
        A ``(macd_line, signal_line, histogram)`` tuple where
        ``macd_line = EMA(fast) - EMA(slow)``, ``signal_line`` is the EMA of
        ``macd_line`` and ``histogram = macd_line - signal_line``.
    """
    fast_ema = data.ewm(span=fast).mean()
    slow_ema = data.ewm(span=slow).mean()

    macd = fast_ema - slow_ema
    trigger = macd.ewm(span=signal).mean()

    return macd, trigger, macd - trigger
|
||
|
||
def calculate_volatility(data, window=20):
    """Return the rolling volatility (standard deviation of returns).

    Args:
        data: Price series.
        window: Number of return observations per standard deviation.

    Returns:
        The rolling standard deviation of ``data.pct_change()`` multiplied
        by ``sqrt(252)``, i.e. the annualised volatility.
    """
    annualisation_factor = np.sqrt(252)
    period_returns = data.pct_change()
    return period_returns.rolling(window=window).std() * annualisation_factor
|
||
|
||
def add_indicators_to_ticker(ticker, db_path='data/stock_data.db'):
    """Compute all technical indicators for one ticker and rewrite its table.

    The ticker's table (named after the lower-cased ticker) is read in full,
    the indicator columns are computed from ``close``, and the table is then
    dropped, recreated with the extended schema and repopulated.

    The drop/create/insert sequence runs inside a single explicit
    transaction: if anything fails, the transaction is rolled back and the
    original table is left untouched. The connection is always closed.

    Args:
        ticker: Ticker symbol; its lower-cased form is the table name.
        db_path: Path of the SQLite database file. Defaults to the location
            used by the rest of this script.

    Returns:
        The number of rows written for this ticker.

    Raises:
        Exception: Any error raised by pandas or sqlite3 (missing table,
            missing column, constraint violation, ...) is propagated after
            the rollback.
    """
    # Quote the identifier so symbols containing characters such as '.' or
    # '-' form valid SQL and cannot inject additional statements.
    table = '"' + ticker.lower().replace('"', '""') + '"'

    # Column order shared by the INSERT statement and the row tuples.
    columns = (
        'date', 'open', 'high', 'low', 'close', 'volume',
        'sma_5', 'sma_20', 'sma_200',
        'rsi', 'bb_upper', 'bb_middle', 'bb_lower',
        'macd', 'macd_signal', 'macd_histogram',
        'volatility',
    )

    conn = sqlite3.connect(db_path)
    try:
        # Load the existing price history.
        df = pd.read_sql_query(f'SELECT * FROM {table} ORDER BY date', conn)
        df['date'] = pd.to_datetime(df['date'])
        df.set_index('date', inplace=True)

        # Compute every indicator from the closing price.
        close = df['close']
        df['sma_5'] = calculate_sma(close, 5)
        df['sma_20'] = calculate_sma(close, 20)
        df['sma_200'] = calculate_sma(close, 200)
        df['rsi'] = calculate_rsi(close)
        df['bb_upper'], df['bb_middle'], df['bb_lower'] = calculate_bollinger_bands(close)
        df['macd'], df['macd_signal'], df['macd_histogram'] = calculate_macd(close)
        df['volatility'] = calculate_volatility(close)

        df.reset_index(inplace=True)

        # Build all rows up front so a bad value is detected before the
        # database is modified. NaN floats are stored by SQLite as NULL.
        rows = [
            (day.strftime('%Y-%m-%d'), *values)
            for day, *values in df[list(columns)].itertuples(index=False, name=None)
        ]

        cursor = conn.cursor()
        # sqlite3 only opens implicit transactions for DML statements, so
        # without an explicit BEGIN the DROP/CREATE below would be committed
        # immediately and a failed insert would lose the original data.
        if not conn.in_transaction:
            cursor.execute('BEGIN')
        try:
            # Replace the old table with one that includes the indicators.
            cursor.execute(f'DROP TABLE IF EXISTS {table}')
            cursor.execute(f'''
                CREATE TABLE {table} (
                    date DATE PRIMARY KEY,
                    open REAL,
                    high REAL,
                    low REAL,
                    close REAL,
                    volume INTEGER,
                    sma_5 REAL,
                    sma_20 REAL,
                    sma_200 REAL,
                    rsi REAL,
                    bb_upper REAL,
                    bb_middle REAL,
                    bb_lower REAL,
                    macd REAL,
                    macd_signal REAL,
                    macd_histogram REAL,
                    volatility REAL
                )
            ''')
            placeholders = ', '.join('?' for _ in columns)
            cursor.executemany(
                f'INSERT INTO {table} ({", ".join(columns)}) VALUES ({placeholders})',
                rows,
            )
        except Exception:
            conn.rollback()
            raise
        conn.commit()
    finally:
        conn.close()

    return len(df)
|
||
|
||
def main():
    """Compute indicators for every listed ticker, then print a sanity check.

    Steps:
        1. Read all ticker symbols from the ``ticker_list`` table.
        2. Call ``add_indicators_to_ticker`` for each one; a failure for one
           ticker is reported and does not stop the remaining tickers.
        3. Print the five most recent rows of the ``aapl`` table.

    Raises:
        Exception: Errors while reading ``ticker_list`` or the ``aapl``
            sample (for example when that table does not exist) propagate
            to the caller; the database connection is closed either way.
    """
    db_path = 'data/stock_data.db'

    print("开始为所有ticker计算技术指标...")

    # Fetch every ticker symbol.
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute('SELECT ticker FROM ticker_list ORDER BY ticker')
        tickers = [row[0] for row in cursor.fetchall()]
    finally:
        conn.close()

    print(f"找到 {len(tickers)} 个tickers")

    total_records = 0

    for ticker in tqdm(tickers, desc="计算技术指标"):
        # Best effort per ticker: report the failure and carry on.
        # tqdm.write keeps the messages from garbling the progress bar.
        try:
            records = add_indicators_to_ticker(ticker)
        except Exception as e:
            tqdm.write(f" {ticker}: Error - {e}")
        else:
            total_records += records
            tqdm.write(f" {ticker}: {records} records with indicators")

    print(f"\n完成!总共处理了 {total_records} 条记录")

    # Verify the result by sampling the latest AAPL rows.
    print("\n验证指标计算结果...")
    conn = sqlite3.connect(db_path)
    try:
        df_sample = pd.read_sql_query('''
            SELECT date, close, sma_20, rsi, bb_upper, bb_lower, macd, volatility
            FROM aapl
            ORDER BY date DESC
            LIMIT 5
        ''', conn)
    finally:
        conn.close()

    print("AAPL最新5天数据(包含指标):")
    print(df_sample.to_string(index=False))
|
||
|
||
# Run the indicator pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()