113 lines
3.0 KiB
Python
113 lines
3.0 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Dict, List, Optional, Union, Any
|
|
import pandas as pd
|
|
from datetime import datetime
|
|
import time
|
|
import logging
|
|
|
|
class BaseAPIClient(ABC):
    """Abstract base class for all API clients.

    Concrete clients must implement the data-retrieval methods
    (``get_etf_profile``, ``get_etf_holdings``, ``get_historical_data``,
    ``get_dividend_history``, ``get_sector_weightings``) and
    ``_validate_symbol``. This base class supplies the shared state, a
    minimum-interval rate limiter (``_check_rate_limit``) and a helper that
    logs an exception and converts it to a dictionary (``_handle_error``).
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        *,
        rate_limit_delay: float = 1.0,
    ):
        """Initialize base API client.

        Args:
            api_key: Optional API key.
            rate_limit_delay: Minimum number of seconds that
                ``_check_rate_limit`` keeps between two consecutive calls.
                Defaults to 1 second.
        """
        self.api_key = api_key
        # Timestamp of the most recent ``_check_rate_limit`` call; ``None``
        # until the first call.
        self.last_request_time: Optional[datetime] = None
        self.rate_limit_delay = rate_limit_delay
        # Logger named after the concrete subclass so log records can be
        # attributed to the specific client.
        self.logger = logging.getLogger(self.__class__.__name__)

    @abstractmethod
    def get_etf_profile(self, symbol: str) -> Dict:
        """Get ETF profile data.

        Args:
            symbol: ETF ticker symbol

        Returns:
            Dictionary containing ETF profile information
        """

    @abstractmethod
    def get_etf_holdings(self, symbol: str) -> List[Dict]:
        """Get ETF holdings data.

        Args:
            symbol: ETF ticker symbol

        Returns:
            List of dictionaries containing holding information
        """

    @abstractmethod
    def get_historical_data(self, symbol: str, period: str = '1y') -> pd.DataFrame:
        """Get historical price data.

        Args:
            symbol: ETF ticker symbol
            period: Time period (e.g., '1d', '1w', '1m', '1y')

        Returns:
            DataFrame with historical price data
        """

    @abstractmethod
    def get_dividend_history(self, symbol: str) -> pd.DataFrame:
        """Get dividend history.

        Args:
            symbol: ETF ticker symbol

        Returns:
            DataFrame with dividend history
        """

    @abstractmethod
    def get_sector_weightings(self, symbol: str) -> Dict:
        """Get sector weightings.

        Args:
            symbol: ETF ticker symbol

        Returns:
            Dictionary with sector weightings
        """

    def _check_rate_limit(self) -> None:
        """Check and enforce rate limiting.

        Sleeps for whatever part of ``rate_limit_delay`` has not yet elapsed
        since the previous call, then records the current time in
        ``last_request_time``. The first call never sleeps.
        """
        if self.last_request_time is not None:
            elapsed = (datetime.now() - self.last_request_time).total_seconds()
            # The wall clock can step backwards (or the stored timestamp may
            # lie in the future); clamp so the wait never exceeds the
            # configured delay.
            elapsed = max(0.0, elapsed)
            remaining = self.rate_limit_delay - elapsed
            if remaining > 0:
                time.sleep(remaining)
        self.last_request_time = datetime.now()

    @abstractmethod
    def _validate_symbol(self, symbol: str) -> bool:
        """Validate a symbol.

        Args:
            symbol: Symbol to validate

        Returns:
            True if valid, False otherwise
        """

    def _handle_error(self, error: Exception) -> Dict:
        """Handle API errors.

        Logs the error at ERROR level on this client's logger and returns it
        as a dictionary; the exception is not re-raised.

        Args:
            error: Exception that occurred

        Returns:
            Error response dictionary of the form ``{"error": str(error)}``
        """
        # Lazy %-style arguments: the message is only formatted if the
        # record is actually emitted.
        self.logger.error("API error: %s", error)
        return {"error": str(error)}