import html
import json
import os
import re
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional


def _escape(value: Any) -> str:
    """Return ``value`` converted to text and HTML-escaped (quotes included)."""
    return html.escape(str(value))


class ChatManager:
    """Persist backtest chat sessions as JSON files and render them to HTML.

    Layout on disk (all paths relative to ``base_dir``):

    * ``chat_template.html`` - report template read by ``generate_html_report``
    * ``sessions/<session_id>.json`` - one JSON file per session
    * ``sessions/<session_id>_chat_report.html`` - generated HTML reports
    """

    def __init__(self, base_dir: str = "chat_history"):
        self.base_dir = base_dir
        self.template_path = os.path.join(base_dir, "chat_template.html")
        self.sessions_dir = os.path.join(base_dir, "sessions")
        # Create the sessions directory (and base_dir) if it doesn't exist yet.
        os.makedirs(self.sessions_dir, exist_ok=True)

    def create_session(self, title: Optional[str] = None) -> str:
        """Create and persist a new, empty chat session.

        Args:
            title: Display title; when falsy a timestamped default is used.

        Returns:
            The new session's ID (a random UUID4 string).
        """
        now = datetime.now()
        session_id = str(uuid.uuid4())
        if not title:
            title = f"回测会话 {now.strftime('%Y-%m-%d %H:%M')}"
        session_data = {
            "session_id": session_id,
            "title": title,
            # One instant for both, so a fresh session has identical stamps.
            "created_at": now.isoformat(),
            "updated_at": now.isoformat(),
            "messages": [],
            "strategies": {},
            "backtest_results": {},
        }
        self._save_session(session_id, session_data)
        return session_id

    def add_message(self, session_id: str, content: str, is_user: bool = True,
                    strategy_info: Optional[Dict] = None,
                    backtest_results: Optional[List[Dict]] = None):
        """Append a message to a session and save the session.

        Args:
            session_id: ID returned by ``create_session``.
            content: Message text (plain text; it is escaped when rendered).
            is_user: True for a user message, False for an assistant message.
            strategy_info: Optional strategy dict. It must contain a ``"name"``
                key, under which it is also stored in the session's
                ``strategies`` map (a later strategy with the same name
                replaces the earlier entry).
            backtest_results: Optional list of ``{"name": ..., "value": ...}``
                dicts, also recorded in the session's ``backtest_results`` map.

        Raises:
            FileNotFoundError: If the session does not exist.
            KeyError: If ``strategy_info`` has no ``"name"`` key.
        """
        session_data = self._load_session(session_id)
        now = datetime.now()

        session_data["messages"].append({
            "id": str(uuid.uuid4()),
            "content": content,
            "is_user": is_user,
            "timestamp": now.strftime("%Y-%m-%d %H:%M:%S"),
            "strategy_info": strategy_info,
            "backtest_results": backtest_results,
        })
        session_data["updated_at"] = now.isoformat()

        if strategy_info:
            session_data["strategies"][strategy_info["name"]] = strategy_info

        if backtest_results:
            # Sequential IDs; this class never removes results, so they stay unique.
            result_id = f"result_{len(session_data['backtest_results'])}"
            session_data["backtest_results"][result_id] = {
                "timestamp": now.isoformat(),
                "results": backtest_results,
                "strategy": strategy_info["name"] if strategy_info else "Unknown",
            }

        self._save_session(session_id, session_data)

    def generate_html_report(self, session_id: str) -> str:
        """Render a session to an HTML file stored next to its JSON file.

        The template placeholders ``{{session_title}}``, ``{{created_at}}``,
        ``{{message_count}}``, ``{{strategy_count}}``, ``{{generated_at}}`` and
        ``{{#messages}}{{/messages}}`` are substituted in a single pass, so
        text inserted for one placeholder is never re-scanned for another.

        Returns:
            Path of the written ``<session_id>_chat_report.html`` file.

        Raises:
            FileNotFoundError: If the session or the template file is missing.
        """
        session_data = self._load_session(session_id)

        with open(self.template_path, "r", encoding="utf-8") as f:
            template = f.read()

        messages_html = "".join(
            self._render_message(msg) for msg in session_data["messages"]
        )
        created_at = datetime.fromisoformat(session_data["created_at"])
        replacements = {
            "{{session_title}}": _escape(session_data["title"]),
            "{{created_at}}": created_at.strftime("%Y-%m-%d %H:%M"),
            "{{message_count}}": str(len(session_data["messages"])),
            "{{strategy_count}}": str(len(session_data["strategies"])),
            "{{generated_at}}": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "{{#messages}}{{/messages}}": messages_html,
        }
        placeholder_re = re.compile("|".join(map(re.escape, replacements)))
        html_content = placeholder_re.sub(
            lambda match: replacements[match.group(0)], template
        )

        html_path = os.path.join(self.sessions_dir, f"{session_id}_chat_report.html")
        with open(html_path, "w", encoding="utf-8") as f:
            f.write(html_content)
        return html_path

    # NOTE(review): confirm the CSS class names used by the _render_* helpers
    # below match the classes styled in chat_template.html.

    def _render_message(self, msg: Dict[str, Any]) -> str:
        """Render one stored message (with optional strategy/results) as HTML."""
        is_user = msg["is_user"]
        message_class = "user-message" if is_user else "assistant-message"
        avatar_class = "user-avatar" if is_user else "assistant-avatar"
        avatar_text = "U" if is_user else "A"
        role_text = "用户" if is_user else "助手"

        parts = [
            f'<div class="message {message_class}">',
            f'<div class="avatar {avatar_class}">{avatar_text}</div>',
            '<div class="message-body">',
            f'<div class="message-role">{role_text}</div>',
            f'<div class="message-content">{self._format_content(msg["content"])}</div>',
        ]
        strategy = msg.get("strategy_info")
        if strategy:
            parts.append(self._render_strategy(strategy))
        results = msg.get("backtest_results")
        if results:
            parts.append(self._render_backtest_results(results))
        parts.append(f'<div class="message-time">{_escape(msg["timestamp"])}</div>')
        parts.append("</div>")  # closes .message-body
        parts.append("</div>")  # closes .message
        return "\n".join(parts) + "\n"

    def _render_strategy(self, strategy: Dict[str, Any]) -> str:
        """Render a strategy's name, description and parameters as HTML."""
        params = json.dumps(strategy.get("parameters", {}), ensure_ascii=False, indent=2)
        return (
            '<div class="strategy-info">\n'
            f'<div><strong>策略信息:</strong> {_escape(strategy["name"])}</div>\n'
            f'<div><strong>描述:</strong> {_escape(strategy.get("description", "无描述"))}</div>\n'
            # <pre> keeps the indent=2 layout of the JSON parameters.
            f'<div><strong>参数:</strong><pre>{_escape(params)}</pre></div>\n'
            "</div>"
        )

    def _render_backtest_results(self, results: List[Dict[str, Any]]) -> str:
        """Render a list of ``{"name", "value"}`` result dicts as HTML."""
        items = "\n".join(
            f'<div class="result-item"><span class="result-name">{_escape(r["name"])}</span>: '
            f'<span class="result-value">{_escape(r["value"])}</span></div>'
            for r in results
        )
        return (
            '<div class="backtest-results">\n'
            '<div class="backtest-title">回测结果</div>\n'
            f"{items}\n"
            "</div>"
        )

    def _format_content(self, content: str) -> str:
        """Convert plain message text to an HTML fragment.

        The text is HTML-escaped first, so message content can never inject
        markup. Consecutive lines indented by a tab or at least four spaces
        become one ``<pre><code>`` block (one indent level removed, deeper
        indentation kept); all other lines are joined with ``<br>``.
        """
        segments: List[str] = []
        code_lines: List[str] = []

        def flush_code() -> None:
            # Emit the pending run of indented lines as a single code block.
            if code_lines:
                segments.append("<pre><code>" + "\n".join(code_lines) + "</code></pre>")
                code_lines.clear()

        for line in html.escape(content).replace("\r\n", "\n").split("\n"):
            if line.startswith("\t"):
                code_lines.append(line[1:])
            elif line.startswith("    "):
                code_lines.append(line[4:])
            else:
                flush_code()
                segments.append(line)
        flush_code()
        return "<br>".join(segments)

    def _session_file(self, session_id: str) -> str:
        """Return the path of the JSON file backing ``session_id``."""
        return os.path.join(self.sessions_dir, f"{session_id}.json")

    def _load_session(self, session_id: str) -> Dict:
        """Load session data from its JSON file.

        Raises:
            FileNotFoundError: If no file exists for ``session_id``.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        session_file = self._session_file(session_id)
        if not os.path.exists(session_file):
            raise FileNotFoundError(f"Session {session_id} not found")
        with open(session_file, "r", encoding="utf-8") as f:
            return json.load(f)

    def _save_session(self, session_id: str, session_data: Dict):
        """Write session data to its JSON file.

        The data is written to a temporary file that then replaces the real
        file, so an interrupted write cannot leave truncated JSON behind.
        """
        session_file = self._session_file(session_id)
        tmp_file = session_file + ".tmp"
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(session_data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_file, session_file)

    def list_sessions(self) -> List[Dict]:
        """Summarize every readable session, most recently updated first.

        Files that cannot be read or parsed, or that lack the expected keys,
        are skipped instead of aborting the whole listing.
        """
        sessions = []
        for filename in os.listdir(self.sessions_dir):
            if not filename.endswith(".json"):
                continue
            session_id = filename[:-len(".json")]
            try:
                session_data = self._load_session(session_id)
                summary = {
                    "session_id": session_id,
                    "title": session_data["title"],
                    "created_at": session_data["created_at"],
                    "updated_at": session_data["updated_at"],
                    "message_count": len(session_data["messages"]),
                    "strategy_count": len(session_data["strategies"]),
                }
            except (OSError, ValueError, KeyError, TypeError):
                # Unreadable, corrupt (JSONDecodeError is a ValueError) or
                # foreign JSON file: skip it.
                continue
            sessions.append(summary)

        # ISO-8601 strings sort chronologically; most recent first.
        sessions.sort(key=lambda s: s["updated_at"], reverse=True)
        return sessions


# Example usage and test
if __name__ == "__main__":
    chat_manager = ChatManager()

    # Create a test session
    session_id = chat_manager.create_session("测试回测策略")

    # Add some sample messages
    chat_manager.add_message(
        session_id,
        "你好,我想测试一个移动平均策略",
        is_user=True,
    )

    chat_manager.add_message(
        session_id,
        "好的,我来帮你实现一个简单的移动平均策略。",
        is_user=False,
        strategy_info={
            "name": "简单移动平均策略",
            "description": "基于短期和长期移动平均线的交叉信号",
            "parameters": {
                "short_window": 20,
                "long_window": 50,
                "symbol": "AAPL",
            },
        },
    )

    chat_manager.add_message(
        session_id,
        "策略回测完成,以下是结果:",
        is_user=False,
        backtest_results=[
            {"name": "总收益率", "value": "15.6%"},
            {"name": "年化收益率", "value": "12.3%"},
            {"name": "最大回撤", "value": "-8.2%"},
            {"name": "夏普比率", "value": "1.45"},
        ],
    )

    # Generate HTML report
    html_path = chat_manager.generate_html_report(session_id)
    print(f"HTML report generated: {html_path}")

    # List sessions
    sessions = chat_manager.list_sessions()
    print(f"Found {len(sessions)} sessions")
    for session in sessions:
        print(f"- {session['title']} ({session['message_count']} messages)")