Files
docker-configs/backtest/chat_history/chat_manager.py
2025-07-18 00:00:01 -05:00

272 lines
10 KiB
Python
Executable File

import html
import json
import os
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
class ChatManager:
    """Persist backtest chat sessions as JSON files and render them to HTML.

    Each session is stored as ``<base_dir>/sessions/<session_id>.json``.
    HTML reports are rendered from ``<base_dir>/chat_template.html`` and
    written to ``<base_dir>/sessions/<session_id>_chat_report.html``.
    """

    def __init__(self, base_dir="chat_history"):
        """Set up storage paths under *base_dir* and ensure the sessions dir exists.

        Args:
            base_dir: Root directory that holds the HTML template and the
                ``sessions`` sub-directory.
        """
        self.base_dir = base_dir
        self.template_path = os.path.join(base_dir, "chat_template.html")
        self.sessions_dir = os.path.join(base_dir, "sessions")
        # Create sessions directory if it doesn't exist
        os.makedirs(self.sessions_dir, exist_ok=True)

    def create_session(self, title: Optional[str] = None) -> str:
        """Create a new, empty chat session and return its ID.

        Args:
            title: Display title. When empty or None, a default title
                ("回测会话" = "backtest session" + current time) is used.

        Returns:
            The new session's UUID4 string.
        """
        session_id = str(uuid.uuid4())
        # One timestamp so a fresh session has created_at == updated_at.
        now = datetime.now()
        if not title:
            title = f"回测会话 {now.strftime('%Y-%m-%d %H:%M')}"
        session_data = {
            "session_id": session_id,
            "title": title,
            "created_at": now.isoformat(),
            "updated_at": now.isoformat(),
            "messages": [],
            "strategies": {},
            "backtest_results": {},
        }
        self._save_session(session_id, session_data)
        return session_id

    def add_message(self, session_id: str, content: str, is_user: bool = True,
                    strategy_info: Optional[Dict] = None,
                    backtest_results: Optional[List[Dict]] = None):
        """Append a message to a session and persist the session.

        Args:
            session_id: ID of an existing session.
            content: Message text (plain text; HTML-escaped when rendered).
            is_user: True for a user message, False for an assistant message.
            strategy_info: Optional strategy dict. It must contain a
                ``"name"`` key; the dict is also stored in the session's
                ``strategies`` map under that name (replacing earlier entries).
            backtest_results: Optional list of metric dicts (rendered via
                their ``"name"`` and ``"value"`` keys); also recorded in the
                session's ``backtest_results`` map.

        Raises:
            FileNotFoundError: If the session does not exist.
            KeyError: If *strategy_info* lacks a ``"name"`` key.
        """
        session_data = self._load_session(session_id)
        now = datetime.now()
        message = {
            "id": str(uuid.uuid4()),
            "content": content,
            "is_user": is_user,
            "timestamp": now.strftime("%Y-%m-%d %H:%M:%S"),
            "strategy_info": strategy_info,
            "backtest_results": backtest_results,
        }
        session_data["messages"].append(message)
        session_data["updated_at"] = now.isoformat()
        # If strategy info is provided, store it
        if strategy_info:
            session_data["strategies"][strategy_info["name"]] = strategy_info
        # If backtest results are provided, store them
        if backtest_results:
            # Sequential key; assumes entries are never removed from the map.
            result_id = f"result_{len(session_data['backtest_results'])}"
            session_data["backtest_results"][result_id] = {
                "timestamp": now.isoformat(),
                "results": backtest_results,
                "strategy": strategy_info["name"] if strategy_info else "Unknown",
            }
        self._save_session(session_id, session_data)

    def generate_html_report(self, session_id: str) -> str:
        """Render a session to HTML via the template and write it to disk.

        All session-supplied text (title, message content, strategy fields,
        metric names/values) is HTML-escaped before insertion, so chat
        content cannot inject markup or scripts into the report.

        The template placeholders ``{{session_title}}``, ``{{created_at}}``,
        ``{{message_count}}``, ``{{strategy_count}}``, ``{{generated_at}}``
        and the literal ``{{#messages}}{{/messages}}`` marker are replaced,
        in that order.

        Returns:
            Path of the written ``<session_id>_chat_report.html`` file.

        Raises:
            FileNotFoundError: If the session or the template file is missing.
        """
        session_data = self._load_session(session_id)
        with open(self.template_path, 'r', encoding='utf-8') as f:
            template = f.read()

        messages_html = "".join(
            self._render_message(msg) for msg in session_data["messages"]
        )
        # Insertion order matters: messages go in last so message text is
        # never re-scanned for the other placeholders.
        replacements = {
            "{{session_title}}": html.escape(str(session_data["title"])),
            "{{created_at}}": datetime.fromisoformat(
                session_data["created_at"]).strftime("%Y-%m-%d %H:%M"),
            "{{message_count}}": str(len(session_data["messages"])),
            "{{strategy_count}}": str(len(session_data["strategies"])),
            "{{generated_at}}": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "{{#messages}}{{/messages}}": messages_html,
        }
        html_content = template
        for placeholder, value in replacements.items():
            html_content = html_content.replace(placeholder, value)

        # Save HTML file next to the session JSON
        html_filename = f"{session_id}_chat_report.html"
        html_path = os.path.join(self.sessions_dir, html_filename)
        with open(html_path, 'w', encoding='utf-8') as f:
            f.write(html_content)
        return html_path

    def _render_message(self, msg: Dict) -> str:
        """Return the HTML fragment for one stored message, escaping all text."""
        is_user = msg["is_user"]
        message_class = "user-message" if is_user else "assistant-message"
        avatar_class = "user-avatar" if is_user else "assistant-avatar"
        avatar_text = "U" if is_user else "A"
        role_text = "用户" if is_user else "助手"  # "User" / "Assistant"
        content = self._format_content(msg["content"])

        parts = [f'''
<div class="message {message_class}">
    <div class="avatar {avatar_class}">{avatar_text}</div>
    <div class="message-content">
        <h3>{role_text}</h3>
        <div class="message-text">{content}</div>
''']

        strategy = msg.get("strategy_info")
        if strategy:
            params = json.dumps(strategy.get("parameters", {}),
                                ensure_ascii=False, indent=2)
            parts.append(f'''
        <div class="strategy-info">
            <h4>策略信息: {html.escape(str(strategy["name"]))}</h4>
            <p><strong>描述:</strong> {html.escape(str(strategy.get("description", "无描述")))}</p>
            <p><strong>参数:</strong> {html.escape(params)}</p>
        </div>
''')

        results = msg.get("backtest_results")
        if results:
            parts.append('''
        <div class="backtest-results">
            <h4>回测结果</h4>
''')
            for result in results:
                parts.append(f'''
            <div class="metric">
                <span>{html.escape(str(result["name"]))}:</span>
                <span>{html.escape(str(result["value"]))}</span>
            </div>
''')
            parts.append('</div>')

        parts.append(f'''
        <div class="timestamp">{html.escape(str(msg["timestamp"]))}</div>
    </div>
</div>
''')
        return "".join(parts)

    def _format_content(self, content: str) -> str:
        """Convert plain message text into simple, safe HTML.

        The text is HTML-escaped first; then newlines become ``<br>``, and
        runs of lines indented by 4+ spaces or a tab are wrapped in a
        ``<div class="code-block">`` with their indentation stripped.
        """
        formatted_lines = []
        in_code_block = False
        # Escape before splitting so a literal "<br>" typed by the user is
        # displayed as text instead of being treated as a line break.
        for line in html.escape(content).split('\n'):
            is_code = line.startswith(('    ', '\t'))
            if is_code and not in_code_block:
                formatted_lines.append('<div class="code-block">')
            elif not is_code and in_code_block:
                formatted_lines.append('</div>')
            in_code_block = is_code
            formatted_lines.append(line.strip() if is_code else line)
        if in_code_block:
            formatted_lines.append('</div>')
        return '<br>'.join(formatted_lines)

    def _session_path(self, session_id: str) -> str:
        """Return the JSON file path for *session_id*.

        Raises:
            ValueError: If *session_id* contains a path separator, which
                would let it address files outside ``sessions_dir``.
        """
        name = str(session_id)
        if os.path.basename(name) != name:
            raise ValueError(f"Invalid session id: {name!r}")
        return os.path.join(self.sessions_dir, f"{name}.json")

    def _load_session(self, session_id: str) -> Dict:
        """Load session data from its JSON file.

        Raises:
            FileNotFoundError: If no file exists for *session_id*.
            ValueError: If *session_id* is path-like or the JSON is invalid.
        """
        session_file = self._session_path(session_id)
        if not os.path.exists(session_file):
            raise FileNotFoundError(f"Session {session_id} not found")
        with open(session_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _save_session(self, session_id: str, session_data: Dict):
        """Write session data to its JSON file atomically.

        Data goes to a temporary file first and is then moved into place
        with ``os.replace``, so a failed serialization never leaves a
        truncated session file behind.
        """
        session_file = self._session_path(session_id)
        tmp_file = f"{session_file}.tmp"
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(session_data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_file, session_file)
        except BaseException:
            # Best-effort cleanup of the partial temp file; re-raise the cause.
            try:
                os.remove(tmp_file)
            except OSError:
                pass
            raise

    def list_sessions(self) -> List[Dict]:
        """Summarize all readable sessions, most recently updated first.

        Files that cannot be read, are not valid JSON, or lack the expected
        keys are skipped rather than aborting the listing.
        """
        sessions = []
        for filename in os.listdir(self.sessions_dir):
            if not filename.endswith('.json'):
                continue
            session_id = filename[:-len('.json')]
            try:
                session_data = self._load_session(session_id)
                summary = {
                    "session_id": session_id,
                    "title": session_data["title"],
                    "created_at": session_data["created_at"],
                    "updated_at": session_data["updated_at"],
                    "message_count": len(session_data["messages"]),
                    "strategy_count": len(session_data["strategies"]),
                }
            except (OSError, ValueError, KeyError, TypeError):
                # Corrupt or incomplete session file: skip it.
                continue
            sessions.append(summary)
        # Sort by updated time, most recent first (ISO strings sort chronologically)
        sessions.sort(key=lambda s: s["updated_at"], reverse=True)
        return sessions
# Example usage and smoke test: build a demo session, render it, list sessions.
if __name__ == "__main__":
    manager = ChatManager()

    # Create a test session
    demo_id = manager.create_session("测试回测策略")

    # Demo conversation as (content, extra add_message keyword arguments).
    demo_messages = [
        ("你好,我想测试一个移动平均策略", {"is_user": True}),
        (
            "好的,我来帮你实现一个简单的移动平均策略。",
            {
                "is_user": False,
                "strategy_info": {
                    "name": "简单移动平均策略",
                    "description": "基于短期和长期移动平均线的交叉信号",
                    "parameters": {
                        "short_window": 20,
                        "long_window": 50,
                        "symbol": "AAPL",
                    },
                },
            },
        ),
        (
            "策略回测完成,以下是结果:",
            {
                "is_user": False,
                "backtest_results": [
                    {"name": "总收益率", "value": "15.6%"},
                    {"name": "年化收益率", "value": "12.3%"},
                    {"name": "最大回撤", "value": "-8.2%"},
                    {"name": "夏普比率", "value": "1.45"},
                ],
            },
        ),
    ]
    for text, options in demo_messages:
        manager.add_message(demo_id, text, **options)

    # Render the HTML report
    report_path = manager.generate_html_report(demo_id)
    print(f"HTML report generated: {report_path}")

    # Show every stored session
    all_sessions = manager.list_sessions()
    print(f"Found {len(all_sessions)} sessions")
    for entry in all_sessions:
        print(f"- {entry['title']} ({entry['message_count']} messages)")