import asyncio
import json
import logging
from datetime import datetime
from typing import Dict, Any, Optional
from contextlib import asynccontextmanager

import httpx
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from anthropic import Anthropic
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger

from config import config

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Lower-case substrings of an error message that mark a rate-limit / usage-limit
# failure. Shared by request failover and the Claude Pro health check so the two
# classifications cannot drift apart.
FAILOVER_INDICATORS = (
    "rate_limit",
    "usage limit",
    "quota exceeded",
    "429",
    "too many requests",
    "limit reached",
)

# Optional Messages API fields copied from the incoming request body when the
# client supplies them; previously e.g. a "system" prompt was silently dropped.
PASSTHROUGH_PARAMS = (
    "system",
    "temperature",
    "top_p",
    "top_k",
    "stop_sequences",
    "metadata",
    "tools",
    "tool_choice",
)

# Sentinel handed to next() so an exhausted synchronous SDK stream can be
# detected without StopIteration crossing the thread boundary.
_STREAM_END = object()


def _is_limit_error(error: Exception) -> bool:
    """Return True if the text of *error* contains any FAILOVER_INDICATORS entry."""
    error_str = str(error).lower()
    return any(indicator in error_str for indicator in FAILOVER_INDICATORS)


def _isoformat_or_none(value: Optional[datetime]) -> Optional[str]:
    """Return *value* as an ISO-8601 string, or None when it is unset."""
    return value.isoformat() if value else None


class ClaudeRouter:
    """Send Claude requests to the current provider, failing over on limit errors.

    Providers are tried in the insertion order of ``self.providers``. When the
    scheduler is running, a periodic health check moves traffic back to
    ``claude_pro`` once a test request against it succeeds.
    """

    def __init__(self):
        self.current_provider = "claude_pro"
        self.failover_count = 0
        self.last_failover = None  # datetime of the latest failover, or None
        self.last_health_check = None  # datetime of the latest health check, or None
        self.health_check_failures = 0  # failed health checks since the last success
        self.scheduler = None  # AsyncIOScheduler while start_scheduler() is active
        self.providers = {
            "claude_pro": {
                "api_key": config.claude_pro_api_key,
                "base_url": config.claude_pro_base_url,
                "active": True,
            },
            "claude_api": {
                "api_key": config.claude_api_key,
                "base_url": config.claude_api_base_url,
                "active": True,
            },
        }

    async def get_anthropic_client(self, provider: str) -> Anthropic:
        """Build a synchronous Anthropic client for *provider*.

        Raises:
            ValueError: if *provider* is not a key of ``self.providers``.
        """
        if provider not in self.providers:
            raise ValueError(f"Unknown provider: {provider}")

        provider_config = self.providers[provider]
        return Anthropic(
            api_key=provider_config["api_key"],
            base_url=provider_config["base_url"],
        )

    async def should_failover(self, error: Exception) -> bool:
        """Return True if *error* looks like a rate/usage limit worth failing over for."""
        return _is_limit_error(error)

    async def failover_to_next_provider(self) -> bool:
        """Switch to the next active provider after the current one, wrapping around.

        Returns:
            True if the provider was switched, False if no other provider is active.
        """
        providers_list = list(self.providers)
        current_index = providers_list.index(self.current_provider)

        for offset in range(1, len(providers_list)):
            next_provider = providers_list[(current_index + offset) % len(providers_list)]
            if self.providers[next_provider]["active"]:
                logger.info("Failing over from %s to %s", self.current_provider, next_provider)
                self.current_provider = next_provider
                self.failover_count += 1
                self.last_failover = datetime.now()
                return True

        logger.error("No active providers available for failover")
        return False

    async def make_request(self, request_data: Dict[str, Any]) -> Any:
        """Forward a Messages request, failing over between providers on limit errors.

        Args:
            request_data: Parsed request body. ``messages``, ``model``,
                ``max_tokens`` and ``stream`` are read with defaults; any
                PASSTHROUGH_PARAMS present are forwarded unchanged.

        Returns:
            The SDK response object (an iterable stream when ``stream`` is true).

        Raises:
            HTTPException: 500 immediately on a non-limit error, or after every
                failover attempt has hit a limit error.
        """
        messages = request_data.get("messages", [])
        model = request_data.get("model", "claude-3-sonnet-20240229")
        max_tokens = request_data.get("max_tokens", 4096)
        stream = request_data.get("stream", False)
        extra_params = {
            key: request_data[key] for key in PASSTHROUGH_PARAMS if key in request_data
        }

        max_attempts = len(self.providers)
        last_error = None
        for attempt in range(max_attempts):
            # Capture this attempt's provider: concurrent requests may change
            # self.current_provider while the blocking call runs in a thread.
            provider = self.current_provider
            try:
                client = await self.get_anthropic_client(provider)
                logger.info("Making request with provider: %s", provider)

                if hasattr(client, "messages"):
                    return await asyncio.to_thread(
                        client.messages.create,
                        model=model,
                        max_tokens=max_tokens,
                        messages=messages,
                        stream=stream,
                        **extra_params,
                    )
                # For older anthropic versions (no Messages API).
                return await asyncio.to_thread(
                    client.completions.create,
                    model=model,
                    max_tokens_to_sample=max_tokens,
                    prompt=f"Human: {messages[0]['content']}\n\nAssistant:",
                    stream=stream,
                )
            except Exception as e:
                last_error = e
                logger.error("Request failed with %s: %s", provider, e)

                if not await self.should_failover(e):
                    # Not a limit error: retrying on this or another provider is
                    # not expected to help, so report it right away.
                    raise HTTPException(
                        status_code=500,
                        detail=f"Request failed with {provider}: {e}",
                    ) from e
                if attempt == max_attempts - 1:
                    break
                if self.current_provider != provider:
                    # Another request already failed over while this one was in
                    # flight; retry on its choice instead of rotating back.
                    continue
                if not await self.failover_to_next_provider():
                    break

        if last_error is None:
            raise HTTPException(status_code=500, detail="No providers available")
        raise HTTPException(
            status_code=500,
            detail=f"All providers failed. Last error: {last_error}",
        )

    async def health_check_claude_pro(self):
        """Probe Claude Pro with a tiny request and switch back to it on success.

        Skipped while ``claude_pro`` is already current. Failures are logged and
        counted in ``health_check_failures``; no exception is propagated.
        """
        if self.current_provider == "claude_pro":
            logger.debug("Skipping health check - already using Claude Pro")
            return

        logger.info("Running Claude Pro health check...")
        self.last_health_check = datetime.now()

        try:
            client = await self.get_anthropic_client("claude_pro")

            # Send a minimal test message; only success vs. failure matters.
            if hasattr(client, "messages"):
                await asyncio.to_thread(
                    client.messages.create,
                    model=config.health_check_model,
                    max_tokens=10,
                    messages=[{"role": "user", "content": config.health_check_message}],
                )
            else:
                # For older anthropic versions
                await asyncio.to_thread(
                    client.completions.create,
                    model=config.health_check_model,
                    max_tokens_to_sample=10,
                    prompt=f"Human: {config.health_check_message}\n\nAssistant:",
                )
        except Exception as e:
            self.health_check_failures += 1
            if _is_limit_error(e):
                logger.info("Claude Pro still rate limited: %s", e)
            else:
                logger.warning(
                    "Claude Pro health check failed (attempt %d): %s",
                    self.health_check_failures,
                    e,
                )
        else:
            # Success: switch back to Claude Pro.
            old_provider = self.current_provider
            self.current_provider = "claude_pro"
            self.health_check_failures = 0
            logger.info(
                "Claude Pro health check successful! Switched from %s to claude_pro",
                old_provider,
            )

    def start_scheduler(self):
        """Start the cron-driven health check scheduler unless disabled in config."""
        if not config.health_check_enabled:
            logger.info("Health check disabled in config")
            return

        self.scheduler = AsyncIOScheduler()

        # Schedule health check using cron expression
        self.scheduler.add_job(
            self.health_check_claude_pro,
            trigger=CronTrigger.from_crontab(config.health_check_cron),
            id="claude_pro_health_check",
            name="Claude Pro Health Check",
            misfire_grace_time=60,
        )

        self.scheduler.start()
        logger.info("Health check scheduler started with cron: %s", config.health_check_cron)

    def stop_scheduler(self):
        """Stop the health check scheduler if it was started."""
        if self.scheduler:
            self.scheduler.shutdown()
            # Drop the reference so a repeated call does not shut it down twice.
            self.scheduler = None
            logger.info("Health check scheduler stopped")


# Initialize router
router = ClaudeRouter()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the health check scheduler on startup and stop it on shutdown."""
    logger.info("Claude Router starting up...")
    logger.info("Current provider: %s", router.current_provider)

    # Start health check scheduler
    router.start_scheduler()
    try:
        yield
    finally:
        # Stop the scheduler even if the application exits with an error.
        router.stop_scheduler()
        logger.info("Claude Router shutting down...")


app = FastAPI(
    title="Claude Router",
    description="Smart router for Claude API with automatic failover",
    version="1.0.0",
    lifespan=lifespan,
)


async def _read_json_body(request: Request) -> Any:
    """Parse the request body as JSON, mapping malformed input to HTTP 400."""
    try:
        return await request.json()
    except ValueError as e:  # json.JSONDecodeError and UnicodeDecodeError are ValueErrors
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}") from e


async def _generate_stream(request_data: Dict[str, Any]):
    """Yield SSE ``data:`` lines for a streaming request (see create_message)."""
    response = None
    try:
        response = await router.make_request(request_data)
        chunks = iter(response)
        while True:
            # The SDK stream is synchronous; pull each chunk in a worker thread
            # so waiting on the network does not block the event loop.
            chunk = await asyncio.to_thread(next, chunks, _STREAM_END)
            if chunk is _STREAM_END:
                break
            yield f"data: {json.dumps(chunk.model_dump())}\n\n"
        yield "data: [DONE]\n\n"
    except Exception as e:
        error_data = {"error": str(e)}
        yield f"data: {json.dumps(error_data)}\n\n"
    finally:
        # Release the upstream stream, e.g. when the client disconnects early.
        close = getattr(response, "close", None)
        if callable(close):
            close()


@app.get("/health")
async def health_check():
    """Health check endpoint: router state without any provider credentials."""
    return {
        "status": "healthy",
        "current_provider": router.current_provider,
        "failover_count": router.failover_count,
        "last_failover": _isoformat_or_none(router.last_failover),
        "providers": {
            name: {"active": provider_config["active"]}
            for name, provider_config in router.providers.items()
        },
        "last_health_check": _isoformat_or_none(router.last_health_check),
        "health_check_failures": router.health_check_failures,
    }


@app.post("/v1/messages")
async def create_message(request: Request):
    """Handle Claude API message creation with failover.

    With ``"stream": true`` the reply is Server-Sent Events: one ``data: <json>``
    line per SDK chunk, then ``data: [DONE]``; an error raised once streaming
    has started is sent as a ``data: {"error": ...}`` event.

    Raises:
        HTTPException: 400 if the body is not a JSON object; 500 if the
            upstream request fails (non-streaming only).
    """
    request_data = await _read_json_body(request)
    if not isinstance(request_data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")

    if request_data.get("stream", False):
        return StreamingResponse(
            _generate_stream(request_data),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )

    try:
        response = await router.make_request(request_data)
        return response.model_dump()
    except HTTPException:
        # make_request already chose the status and detail; do not re-wrap it.
        raise
    except Exception as e:
        logger.error("Request processing failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/v1/switch-provider")
async def switch_provider(request: Request):
    """Manually switch to a specific provider.

    Accepts either a bare JSON string (``"claude_api"``) or an object of the
    form ``{"provider": "claude_api"}``.

    Raises:
        HTTPException: 400 for invalid JSON, an unknown provider, or an
            inactive provider.
    """
    payload = await _read_json_body(request)
    provider = payload.get("provider") if isinstance(payload, dict) else payload

    # Check the type first: an unhashable value (dict/list) would make the
    # membership test raise TypeError and surface as a 500.
    if not isinstance(provider, str) or provider not in router.providers:
        raise HTTPException(status_code=400, detail=f"Unknown provider: {provider}")

    if not router.providers[provider]["active"]:
        raise HTTPException(status_code=400, detail=f"Provider {provider} is not active")

    old_provider = router.current_provider
    router.current_provider = provider
    logger.info("Manually switched from %s to %s", old_provider, provider)

    return {
        "message": f"Switched from {old_provider} to {provider}",
        "current_provider": router.current_provider,
    }


@app.get("/v1/status")
async def get_status():
    """Get current router status; provider API keys are redacted."""
    return {
        "current_provider": router.current_provider,
        "failover_count": router.failover_count,
        "last_failover": _isoformat_or_none(router.last_failover),
        "last_health_check": _isoformat_or_none(router.last_health_check),
        "health_check_failures": router.health_check_failures,
        # Never expose credentials through this unauthenticated endpoint.
        "providers": {
            name: {key: value for key, value in provider_config.items() if key != "api_key"}
            for name, provider_config in router.providers.items()
        },
    }


@app.post("/v1/health-check")
async def manual_health_check():
    """Manually trigger Claude Pro health check"""
    try:
        await router.health_check_claude_pro()
        return {
            "message": "Health check completed",
            "current_provider": router.current_provider,
            "last_health_check": _isoformat_or_none(router.last_health_check),
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Health check failed: {str(e)}") from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=config.host, port=config.port)