diff --git a/bcb/sgs/__init__.py b/bcb/sgs/__init__.py index 0e7837d..44d81f6 100644 --- a/bcb/sgs/__init__.py +++ b/bcb/sgs/__init__.py @@ -26,6 +26,7 @@ get_client, raise_for_request_error, raise_for_status, + with_retry, ) from bcb.exceptions import SGSError from bcb.utils import Date, DateInput @@ -229,6 +230,16 @@ def _raise_sgs_response_error(res: httpx.Response, code: int) -> None: ) +@with_retry +def _get_sgs_response(url: str, payload: Dict[str, str]) -> httpx.Response: + return get_client().get(url, params=payload) + + +@with_retry +async def _async_get_sgs_response(url: str, payload: Dict[str, str]) -> httpx.Response: + return await get_async_client().get(url, params=payload) + + def _format_df(df: pd.DataFrame, code: SGSCode, freq: Optional[str]) -> pd.DataFrame: cns = {"data": "Date", "valor": code.name, "datafim": "enddate"} df = df.rename(columns=cns) @@ -395,7 +406,7 @@ def get_json( f"Fetching SGS time series code={code_obj.value} from {url.split('/dados')[0]}" ) try: - res = get_client().get(url, params=payload) + res = _get_sgs_response(url, payload) except httpx.HTTPError as ex: raise_for_request_error( ex, context=f"SGS time series code={code_obj.value}", error_cls=SGSError @@ -449,7 +460,7 @@ async def async_get_json( f"from {url.split('/dados')[0]}" ) try: - res = await get_async_client().get(url, params=payload) + res = await _async_get_sgs_response(url, payload) except httpx.HTTPError as ex: raise_for_request_error( ex, context=f"SGS time series code={code_obj.value}", error_cls=SGSError diff --git a/tests/test_async.py b/tests/test_async.py index 77be26b..17c2287 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -38,6 +38,16 @@ SGS_CODE_URL = re.compile(r".*bcdata\.sgs\..*") +async def _async_no_retry_sleep(_): + return None + + +def _disable_async_sgs_retry_sleep(monkeypatch): + monkeypatch.setattr( + sgs._async_get_sgs_response.retry, "sleep", _async_no_retry_sleep + ) + + def add_currency_base_mocks(httpx_mock): httpx_mock.add_response( url=PTAX_ID_LIST_URL, @@ -134,6 +144,23 @@ async def test_async_get_json_rate_limit_raises(httpx_mock): await sgs.async_get_json(1) +async def test_async_get_json_retries_timeout_then_succeeds(httpx_mock, monkeypatch): + _disable_async_sgs_retry_sleep(monkeypatch) + httpx_mock.add_exception( + httpx.TimeoutException("request timed out"), + url=SGS_CODE_URL, + ) + httpx_mock.add_response( + url=SGS_CODE_URL, + text=SGS_JSON_5, + status_code=200, + ) + + result = await sgs.async_get_json(1) + + assert result == SGS_JSON_5 + + async def test_async_get_empty_sgs_code_list_raises(): with pytest.raises(ValueError, match="At least one SGS code"): await sgs.async_get([]) diff --git a/tests/test_sgs_negative.py b/tests/test_sgs_negative.py index 339a114..86b4cdd 100644 --- a/tests/test_sgs_negative.py +++ b/tests/test_sgs_negative.py @@ -11,10 +11,15 @@ from bcb import sgs from bcb.exceptions import SGSError +from tests.conftest import SGS_JSON_5 SGS_CODE_URL = re.compile(r".*bcdata\.sgs\..*") +def _disable_sgs_retry_sleep(monkeypatch): + monkeypatch.setattr(sgs._get_sgs_response.retry, "sleep", lambda _: None) + + # --------------------------------------------------------------------------- # 404 and 429 API response errors # --------------------------------------------------------------------------- @@ -52,25 +57,57 @@ def test_get_json_500_raises(httpx_mock): sgs.get_json(1) -def test_get_json_connection_error_raises(httpx_mock): - httpx_mock.add_exception( - httpx.ConnectError("network down"), - url=SGS_CODE_URL, - ) +def test_get_json_connection_error_raises(httpx_mock, monkeypatch): + _disable_sgs_retry_sleep(monkeypatch) + for _ in range(4): + httpx_mock.add_exception( + httpx.ConnectError("network down"), + url=SGS_CODE_URL, + ) with pytest.raises(SGSError, match="SGS time series"): sgs.get_json(1) -def test_get_json_timeout_error_raises(httpx_mock): +def test_get_json_timeout_error_raises(httpx_mock, monkeypatch): + _disable_sgs_retry_sleep(monkeypatch) + for _ in range(4): + httpx_mock.add_exception( + httpx.TimeoutException("request timed out"), + url=SGS_CODE_URL, + ) + + with pytest.raises(SGSError, match="SGS time series"): + sgs.get_json(1) + + +def test_get_json_retries_timeout_then_succeeds(httpx_mock, monkeypatch): + _disable_sgs_retry_sleep(monkeypatch) httpx_mock.add_exception( httpx.TimeoutException("request timed out"), url=SGS_CODE_URL, ) + httpx_mock.add_response( + url=SGS_CODE_URL, + text=SGS_JSON_5, + status_code=200, + ) - with pytest.raises(SGSError, match="SGS time series"): + assert sgs.get_json(1) == SGS_JSON_5 + + +def test_get_json_does_not_retry_http_status_errors(httpx_mock, monkeypatch): + _disable_sgs_retry_sleep(monkeypatch) + httpx_mock.add_response( + url=SGS_CODE_URL, + status_code=500, + ) + + with pytest.raises(SGSError): sgs.get_json(1) + assert len(httpx_mock.get_requests()) == 1 + # --------------------------------------------------------------------------- # Malformed data (JSON)