272 lines
10 KiB
Python
Executable File
272 lines
10 KiB
Python
Executable File
import html
import json
import os
import uuid
from datetime import datetime
from typing import List, Dict, Optional, Any
|
|
|
|
class ChatManager:
    """Persist chat sessions as JSON files and render them as HTML reports.

    Each session is stored in ``<base_dir>/sessions/<session_id>.json`` and
    holds the message list plus the strategies and backtest results that were
    attached to messages. Reports are rendered from the template file
    ``<base_dir>/chat_template.html``.
    """

    def __init__(self, base_dir="chat_history"):
        """Record the storage paths and make sure the sessions directory exists.

        Args:
            base_dir: Directory that contains ``chat_template.html`` and the
                ``sessions`` sub-directory.
        """
        self.base_dir = base_dir
        self.template_path = os.path.join(base_dir, "chat_template.html")
        self.sessions_dir = os.path.join(base_dir, "sessions")

        # Create sessions directory if it doesn't exist
        os.makedirs(self.sessions_dir, exist_ok=True)

    def create_session(self, title: Optional[str] = None) -> str:
        """Create a new, empty chat session and return its session ID.

        Args:
            title: Display title. When empty or ``None`` a default title
                containing the current date and time is generated.

        Returns:
            The new session's UUID4 string.
        """
        session_id = str(uuid.uuid4())
        # One timestamp for the whole operation so created_at == updated_at
        # on a brand-new session.
        now = datetime.now()
        if not title:
            title = f"回测会话 {now.strftime('%Y-%m-%d %H:%M')}"

        stamp = now.isoformat()
        session_data = {
            "session_id": session_id,
            "title": title,
            "created_at": stamp,
            "updated_at": stamp,
            "messages": [],
            "strategies": {},
            "backtest_results": {},
        }

        self._save_session(session_id, session_data)
        return session_id

    def add_message(self, session_id: str, content: str, is_user: bool = True,
                    strategy_info: Optional[Dict] = None,
                    backtest_results: Optional[List[Dict]] = None):
        """Append a message to a session and persist the session.

        Args:
            session_id: ID of an existing session.
            content: Message text.
            is_user: ``True`` for a user message, ``False`` for an assistant one.
            strategy_info: Optional strategy description; it must contain a
                ``"name"`` key and is also indexed under that name in the
                session's ``strategies`` mapping.
            backtest_results: Optional list of result entries; it is also
                recorded in the session's ``backtest_results`` mapping.

        Raises:
            FileNotFoundError: If the session does not exist.
            KeyError: If ``strategy_info`` is given without a ``"name"`` key.
        """
        session_data = self._load_session(session_id)
        now = datetime.now()

        message = {
            "id": str(uuid.uuid4()),
            "content": content,
            "is_user": is_user,
            "timestamp": now.strftime("%Y-%m-%d %H:%M:%S"),
            "strategy_info": strategy_info,
            "backtest_results": backtest_results,
        }

        session_data["messages"].append(message)
        session_data["updated_at"] = now.isoformat()

        # If strategy info is provided, store it (keyed by name, so a later
        # strategy with the same name replaces the earlier entry).
        if strategy_info:
            session_data["strategies"][strategy_info["name"]] = strategy_info

        # If backtest results are provided, store them under a sequential key.
        if backtest_results:
            result_id = f"result_{len(session_data['backtest_results'])}"
            session_data["backtest_results"][result_id] = {
                "timestamp": now.isoformat(),
                "results": backtest_results,
                "strategy": strategy_info["name"] if strategy_info else "Unknown",
            }

        self._save_session(session_id, session_data)

    def generate_html_report(self, session_id: str) -> str:
        """Render a session to HTML and write it next to the session file.

        The template's ``{{session_title}}``, ``{{created_at}}``,
        ``{{message_count}}``, ``{{strategy_count}}``, ``{{generated_at}}``
        and ``{{#messages}}{{/messages}}`` placeholders are substituted.
        All session-derived text is HTML-escaped before insertion.

        Args:
            session_id: ID of an existing session.

        Returns:
            Path of the written ``<session_id>_chat_report.html`` file.

        Raises:
            FileNotFoundError: If the session or the template file is missing.
        """
        session_data = self._load_session(session_id)

        # Read HTML template
        with open(self.template_path, 'r', encoding='utf-8') as f:
            template = f.read()

        messages_html = "".join(
            self._render_message(msg) for msg in session_data["messages"]
        )

        created_at = datetime.fromisoformat(session_data["created_at"])
        # Order matters: the messages placeholder is substituted last so that
        # placeholder-looking text inside messages is never expanded.
        substitutions = [
            ("{{session_title}}", html.escape(str(session_data["title"]))),
            ("{{created_at}}", created_at.strftime("%Y-%m-%d %H:%M")),
            ("{{message_count}}", str(len(session_data["messages"]))),
            ("{{strategy_count}}", str(len(session_data["strategies"]))),
            ("{{generated_at}}", datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
            ("{{#messages}}{{/messages}}", messages_html),
        ]
        html_content = template
        for placeholder, value in substitutions:
            html_content = html_content.replace(placeholder, value)

        # Save HTML file
        html_filename = f"{session_id}_chat_report.html"
        html_path = os.path.join(self.sessions_dir, html_filename)

        with open(html_path, 'w', encoding='utf-8') as f:
            f.write(html_content)

        return html_path

    def _render_message(self, msg: Dict[str, Any]) -> str:
        """Return the HTML fragment for one stored message dict."""
        is_user = msg["is_user"]
        message_class = "user-message" if is_user else "assistant-message"
        avatar_class = "user-avatar" if is_user else "assistant-avatar"
        avatar_text = "U" if is_user else "A"
        role_text = "用户" if is_user else "助手"

        # Format content (escaping plus basic markdown-like formatting)
        content = self._format_content(msg["content"])

        parts = [f'''
        <div class="message {message_class}">
            <div class="avatar {avatar_class}">{avatar_text}</div>
            <div class="message-content">
                <h3>{role_text}</h3>
                <div class="message-text">{content}</div>
        ''']

        # Add strategy info if present
        strategy = msg.get("strategy_info")
        if strategy:
            name = html.escape(str(strategy["name"]))
            description = html.escape(str(strategy.get("description", "无描述")))
            parameters = html.escape(
                json.dumps(strategy.get("parameters", {}), ensure_ascii=False, indent=2)
            )
            parts.append(f'''
                <div class="strategy-info">
                    <h4>策略信息: {name}</h4>
                    <p><strong>描述:</strong> {description}</p>
                    <p><strong>参数:</strong> {parameters}</p>
                </div>
            ''')

        # Add backtest results if present
        results = msg.get("backtest_results")
        if results:
            parts.append('''
                <div class="backtest-results">
                    <h4>回测结果</h4>
            ''')
            for result in results:
                metric_name = html.escape(str(result["name"]))
                metric_value = html.escape(str(result["value"]))
                parts.append(f'''
                    <div class="metric">
                        <span>{metric_name}:</span>
                        <span>{metric_value}</span>
                    </div>
                ''')
            parts.append('</div>')

        timestamp = html.escape(str(msg["timestamp"]))
        parts.append(f'''
                <div class="timestamp">{timestamp}</div>
            </div>
        </div>
        ''')

        return "".join(parts)

    def _format_content(self, content: str) -> str:
        """Convert plain message text to an HTML fragment.

        The text is HTML-escaped, newlines become ``<br>``, and each run of
        lines indented by four spaces or a tab is wrapped in a
        ``<div class="code-block">`` element with the indentation stripped.

        Args:
            content: Raw message text.

        Returns:
            The HTML fragment.
        """
        # Escape first so message text can never inject markup; spaces, tabs
        # and newlines are unaffected, so the indentation test below still works.
        escaped = html.escape(content)

        formatted_lines = []
        in_code_block = False

        for line in escaped.split('\n'):
            # Simple code block detection (lines starting with 4+ spaces or a tab)
            if line.startswith(('    ', '\t')):
                if not in_code_block:
                    formatted_lines.append('<div class="code-block">')
                    in_code_block = True
                formatted_lines.append(line.strip())
            else:
                if in_code_block:
                    formatted_lines.append('</div>')
                    in_code_block = False
                formatted_lines.append(line)

        # Close a code block that runs to the end of the message.
        if in_code_block:
            formatted_lines.append('</div>')

        return '<br>'.join(formatted_lines)

    def _load_session(self, session_id: str) -> Dict:
        """Load and return the session dict stored for ``session_id``.

        Raises:
            FileNotFoundError: If no JSON file exists for the session.
        """
        session_file = os.path.join(self.sessions_dir, f"{session_id}.json")

        # EAFP: open directly instead of an exists() check followed by open().
        try:
            with open(session_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            raise FileNotFoundError(f"Session {session_id} not found") from None

    def _save_session(self, session_id: str, session_data: Dict):
        """Write ``session_data`` as UTF-8 JSON to the session's file."""
        session_file = os.path.join(self.sessions_dir, f"{session_id}.json")

        with open(session_file, 'w', encoding='utf-8') as f:
            json.dump(session_data, f, ensure_ascii=False, indent=2)

    def list_sessions(self) -> List[Dict]:
        """Return summaries of all readable sessions, most recently updated first.

        Session files that cannot be read, parsed, or that lack the expected
        keys are skipped rather than aborting the listing.

        Returns:
            A list of dicts with ``session_id``, ``title``, ``created_at``,
            ``updated_at``, ``message_count`` and ``strategy_count``.
        """
        sessions = []

        for filename in os.listdir(self.sessions_dir):
            if not filename.endswith('.json'):
                continue

            session_id = filename[:-5]  # Remove .json extension
            try:
                session_data = self._load_session(session_id)
                summary = {
                    "session_id": session_id,
                    "title": session_data["title"],
                    "created_at": session_data["created_at"],
                    "updated_at": session_data["updated_at"],
                    "message_count": len(session_data["messages"]),
                    "strategy_count": len(session_data["strategies"]),
                }
            except (OSError, ValueError, KeyError, TypeError):
                # Best effort: unreadable, non-JSON or structurally unexpected
                # files are ignored. ValueError covers JSON and decoding errors.
                continue
            sessions.append(summary)

        # Sort by updated time, most recent first
        sessions.sort(key=lambda entry: entry["updated_at"], reverse=True)
        return sessions
|
|
|
|
# Demo / smoke test: build a sample session, render it, then list all sessions.
if __name__ == "__main__":
    # Sample payloads attached to the assistant messages below.
    demo_strategy = {
        "name": "简单移动平均策略",
        "description": "基于短期和长期移动平均线的交叉信号",
        "parameters": {
            "short_window": 20,
            "long_window": 50,
            "symbol": "AAPL",
        },
    }
    demo_metrics = [
        {"name": "总收益率", "value": "15.6%"},
        {"name": "年化收益率", "value": "12.3%"},
        {"name": "最大回撤", "value": "-8.2%"},
        {"name": "夏普比率", "value": "1.45"},
    ]

    manager = ChatManager()

    # Start a fresh session with an explicit title.
    demo_session = manager.create_session("测试回测策略")

    # One user message, then two assistant replies: the first carries the
    # strategy description, the second carries the backtest metrics.
    manager.add_message(demo_session, "你好,我想测试一个移动平均策略", is_user=True)
    manager.add_message(
        demo_session,
        "好的,我来帮你实现一个简单的移动平均策略。",
        is_user=False,
        strategy_info=demo_strategy,
    )
    manager.add_message(
        demo_session,
        "策略回测完成,以下是结果:",
        is_user=False,
        backtest_results=demo_metrics,
    )

    # Render the session to an HTML file and report where it was written.
    report_path = manager.generate_html_report(demo_session)
    print(f"HTML report generated: {report_path}")

    # Show a one-line summary for every stored session.
    all_sessions = manager.list_sessions()
    print(f"Found {len(all_sessions)} sessions")
    for entry in all_sessions:
        print(f"- {entry['title']} ({entry['message_count']} messages)")