Browse Source

exchange_rate: add some type hints

patch-4
SomberNight 3 years ago
parent
commit
81d0928abd
No known key found for this signature in database GPG Key ID: B33B5F232C6271E9
  1. 26
      electrum/exchange_rate.py

26
electrum/exchange_rate.py

@ -8,7 +8,7 @@ import time
import csv import csv
import decimal import decimal
from decimal import Decimal from decimal import Decimal
from typing import Sequence, Optional from typing import Sequence, Optional, Mapping, Dict, Union
from aiorpcx.curio import timeout_after, TaskTimeout from aiorpcx.curio import timeout_after, TaskTimeout
import aiohttp import aiohttp
@ -41,8 +41,8 @@ class ExchangeBase(Logger):
def __init__(self, on_quotes, on_history): def __init__(self, on_quotes, on_history):
Logger.__init__(self) Logger.__init__(self)
self.history = {} self.history = {} # type: Dict[str, Dict[str, Union[str, float, Decimal]]]
self.quotes = {} self.quotes = {} # type: Dict[str, Union[str, float, Decimal, None]]
self.on_quotes = on_quotes self.on_quotes = on_quotes
self.on_history = on_history self.on_history = on_history
@ -75,7 +75,7 @@ class ExchangeBase(Logger):
def name(self): def name(self):
return self.__class__.__name__ return self.__class__.__name__
async def update_safe(self, ccy): async def update_safe(self, ccy: str) -> None:
try: try:
self.logger.info(f"getting fx quotes for {ccy}") self.logger.info(f"getting fx quotes for {ccy}")
self.quotes = await self.get_rates(ccy) self.quotes = await self.get_rates(ccy)
@ -88,7 +88,7 @@ class ExchangeBase(Logger):
self.quotes = {} self.quotes = {}
self.on_quotes() self.on_quotes()
def read_historical_rates(self, ccy, cache_dir) -> Optional[dict]: def read_historical_rates(self, ccy: str, cache_dir: str) -> Optional[dict]:
filename = os.path.join(cache_dir, self.name() + '_'+ ccy) filename = os.path.join(cache_dir, self.name() + '_'+ ccy)
if not os.path.exists(filename): if not os.path.exists(filename):
return None return None
@ -106,7 +106,7 @@ class ExchangeBase(Logger):
return h return h
@log_exceptions @log_exceptions
async def get_historical_rates_safe(self, ccy, cache_dir): async def get_historical_rates_safe(self, ccy: str, cache_dir: str) -> None:
try: try:
self.logger.info(f"requesting fx history for {ccy}") self.logger.info(f"requesting fx history for {ccy}")
h = await self.request_history(ccy) h = await self.request_history(ccy)
@ -124,7 +124,7 @@ class ExchangeBase(Logger):
self.history[ccy] = h self.history[ccy] = h
self.on_history() self.on_history()
def get_historical_rates(self, ccy, cache_dir): def get_historical_rates(self, ccy: str, cache_dir: str) -> None:
if ccy not in self.history_ccys(): if ccy not in self.history_ccys():
return return
h = self.history.get(ccy) h = self.history.get(ccy)
@ -133,19 +133,19 @@ class ExchangeBase(Logger):
if h is None or h['timestamp'] < time.time() - 24*3600: if h is None or h['timestamp'] < time.time() - 24*3600:
asyncio.get_event_loop().create_task(self.get_historical_rates_safe(ccy, cache_dir)) asyncio.get_event_loop().create_task(self.get_historical_rates_safe(ccy, cache_dir))
def history_ccys(self): def history_ccys(self) -> Sequence[str]:
return [] return []
def historical_rate(self, ccy, d_t): def historical_rate(self, ccy: str, d_t: datetime) -> Union[str, float, Decimal]:
return self.history.get(ccy, {}).get(d_t.strftime('%Y-%m-%d'), 'NaN') return self.history.get(ccy, {}).get(d_t.strftime('%Y-%m-%d'), 'NaN')
async def request_history(self, ccy): async def request_history(self, ccy: str) -> Dict[str, Union[str, float, Decimal]]:
raise NotImplementedError() # implemented by subclasses raise NotImplementedError() # implemented by subclasses
async def get_rates(self, ccy): async def get_rates(self, ccy: str) -> Mapping[str, Union[str, float, Decimal, None]]:
raise NotImplementedError() # implemented by subclasses raise NotImplementedError() # implemented by subclasses
async def get_currencies(self): async def get_currencies(self) -> Sequence[str]:
rates = await self.get_rates('') rates = await self.get_rates('')
return sorted([str(a) for (a, b) in rates.items() if b is not None and len(a)==3]) return sorted([str(a) for (a, b) in rates.items() if b is not None and len(a)==3])
@ -489,7 +489,7 @@ class FxThread(ThreadJob):
self.history_used_spot = False self.history_used_spot = False
self.ccy_combo = None self.ccy_combo = None
self.hist_checkbox = None self.hist_checkbox = None
self.cache_dir = os.path.join(config.path, 'cache') self.cache_dir = os.path.join(config.path, 'cache') # type: str
self._trigger = asyncio.Event() self._trigger = asyncio.Event()
self._trigger.set() self._trigger.set()
self.set_exchange(self.config_exchange()) self.set_exchange(self.config_exchange())

Loading…
Cancel
Save