Commit 2da32f66 authored by Aaron Spring's avatar Aaron Spring 🚼
Browse files

write rps loss function

parent 7ddae19b
Pipeline #208313 passed with stage
in 9 minutes and 11 seconds
%% Cell type:markdown id: tags:
# Train ML model to correct predictions of week 3-4 & 5-6
This notebook create a Machine Learning `ML_model` to predict weeks 3-4 & 5-6 based on `S2S` weeks 3-4 & 5-6 forecasts and is compared to `CPC` observations for the [`s2s-ai-challenge`](https://s2s-ai-challenge.github.io/).
%% Cell type:markdown id: tags:
# Synopsis
%% Cell type:markdown id: tags:
## Method: `ML-based mean bias reduction`
- calculate the ML-based bias from 2000-2019 deterministic ensemble mean forecast
- remove that the ML-based bias from 2020 forecast deterministic ensemble mean forecast
%% Cell type:markdown id: tags:
## Data used
type: renku datasets
Training-input for Machine Learning model:
- hindcasts of models:
- ECMWF: `ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr`
Forecast-input for Machine Learning model:
- real-time 2020 forecasts of models:
- ECMWF: `ecmwf_forecast-input_2020_biweekly_deterministic.zarr`
Compare Machine Learning model forecast against against ground truth:
- `CPC` observations:
- `hindcast-like-observations_biweekly_deterministic.zarr`
- `forecast-like-observations_2020_biweekly_deterministic.zarr`
%% Cell type:markdown id: tags:
## Resources used
for training, details in reproducibility
- platform: renku
- memory: 8 GB
- processors: 2 CPU
- storage required: 10 GB
%% Cell type:markdown id: tags:
## Safeguards
All points have to be [x] checked. If not, your submission is invalid.
Changes to the code after submissions are not possible, as the `commit` before the `tag` will be reviewed.
(Only in exceptions and if previous effort in reproducibility can be found, it may be allowed to improve readability and reproducibility after November 1st 2021.)
%% Cell type:markdown id: tags:
### Safeguards to prevent [overfitting](https://en.wikipedia.org/wiki/Overfitting?wprov=sfti1)
If the organizers suspect overfitting, your contribution can be disqualified.
- [x] We did not use 2020 observations in training (explicit overfitting and cheating)
- [x] We did not repeatedly verify my model on 2020 observations and incrementally improved my RPSS (implicit overfitting)
- [x] We provide RPSS scores for the training period with script `print_RPS_per_year`, see in section 6.3 `predict`.
- [x] We tried our best to prevent [data leakage](https://en.wikipedia.org/wiki/Leakage_(machine_learning)?wprov=sfti1).
- [x] We honor the `train-validate-test` [split principle](https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets). This means that the hindcast data is split into `train` and `validate`, whereas `test` is withheld.
- [x] We did not use `test` explicitly in training or implicitly in incrementally adjusting parameters.
- [x] We considered [cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics)).
%% Cell type:markdown id: tags:
### Safeguards for Reproducibility
Notebook/code must be independently reproducible from scratch by the organizers (after the competition), if not possible: no prize
- [x] All training data is publicly available (no pre-trained private neural networks, as they are not reproducible for us)
- [x] Code is well documented, readable and reproducible.
- [x] Code to reproduce training and predictions is preferred to run within a day on the described architecture. If the training takes longer than a day, please justify why this is needed. Please do not submit training piplelines, which take weeks to train.
%% Cell type:markdown id: tags:
# Todos to improve template
This is just a demo.
- [ ] use multiple predictor variables and two predicted variables
- [ ] for both `lead_time`s in one go
- [ ] consider seasonality, for now all `forecast_time` months are mixed
- [ ] make probabilistic predictions with `category` dim, for now works deterministic
%% Cell type:markdown id: tags:
# Imports
%% Cell type:code id: tags:
``` python
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow import keras
import matplotlib.pyplot as plt
import xarray as xr
xr.set_options(display_style='text')
import numpy as np
from dask.utils import format_bytes
import xskillscore as xs
%load_ext tensorboard
```
%%%% Output: stream
/opt/conda/lib/python3.8/site-packages/xarray/backends/cfgrib_.py:27: UserWarning: Failed to load cfgrib - most likely there is a problem accessing the ecCodes library. Try `import cfgrib` to get the full error message
warnings.warn(
%% Cell type:markdown id: tags:
# Get training data
preprocessing of input data may be done in separate notebook/script
%% Cell type:markdown id: tags:
## Hindcast
get weekly initialized hindcasts
%% Cell type:code id: tags:
``` python
v='t2m'
```
%% Cell type:code id: tags:
``` python
# preprocessed as renku dataset
!renku storage pull ../data/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr
```
%%%% Output: stream
Warning: Run CLI commands only from project's root directory.

%% Cell type:code id: tags:
``` python
hind_2000_2019 = xr.open_zarr("../data/ecmwf_hindcast-input_2000-2019_biweekly_deterministic.zarr", consolidated=True)
```
%%%% Output: stream
/opt/conda/lib/python3.8/site-packages/xarray/backends/plugins.py:61: RuntimeWarning: Engine 'cfgrib' loading failed:
/opt/conda/lib/python3.8/site-packages/gribapi/_bindings.cpython-38-x86_64-linux-gnu.so: undefined symbol: codes_bufr_key_is_header
warnings.warn(f"Engine {name!r} loading failed:\n{ex}", RuntimeWarning)
%% Cell type:markdown id: tags:
## Forecast
%% Cell type:code id: tags:
``` python
# preprocessed as renku dataset
!renku storage pull ../data/ecmwf_forecast-input_2020_biweekly_deterministic.zarr
```
%%%% Output: stream
Warning: Run CLI commands only from project's root directory.

%% Cell type:code id: tags:
``` python
fct_2020 = xr.open_zarr("../data/ecmwf_forecast-input_2020_biweekly_deterministic.zarr", consolidated=True)
```
%% Cell type:markdown id: tags:
## Observations
corresponding to hindcasts
%% Cell type:code id: tags:
``` python
# preprocessed as renku dataset
#!renku storage pull ../data/hindcast-like-observations_2000-2019_biweekly_deterministic.zarr
```
%% Cell type:code id: tags:
``` python
#obs_2000_2019 = xr.open_zarr("../data/hindcast-like-observations_2000-2019_biweekly_deterministic.zarr", consolidated=True)#[v]
```
%% Cell type:code id: tags:
``` python
# preprocessed as renku dataset
#!renku storage pull ../data/forecast-like-observations_2020_biweekly_deterministic.zarr
```
%% Cell type:code id: tags:
``` python
#obs_2020 = xr.open_zarr("../data/forecast-like-observations_2020_biweekly_deterministic.zarr", consolidated=True)#[v]
```
%% Cell type:code id: tags:
``` python
```
categorized in terciles corresponding to hindcasts/forecasts
%% Cell type:code id: tags:
``` python
!renku storage pull ../data/hindcast-like-observations_2000-2019_biweekly_terciled.zarr
```
%%%% Output: stream
Warning: Run CLI commands only from project's root directory.

%% Cell type:code id: tags:
``` python
obs_2000_2019_p = xr.open_dataset(f'../data/hindcast-like-observations_2000-2019_biweekly_terciled.zarr', engine='zarr')
obs_2000_2019_p.sizes
```
%%%% Output: execute_result
Frozen(SortedKeysDict({'category': 3, 'forecast_time': 1060, 'latitude': 121, 'lead_time': 2, 'longitude': 240}))
%% Cell type:code id: tags:
``` python
obs_2000_2019_p = obs_2000_2019_p.assign_coords(category=[0,1,2])
obs_2000_2019_p = (obs_2000_2019_p * obs_2000_2019_p.category).sum('category', skipna=False).compute()
obs_2000_2019_p2 = obs_2000_2019_p.assign_coords(category=[0,1,2])
obs_2000_2019_p2 = (obs_2000_2019_p2 * obs_2000_2019_p2.category).sum('category', skipna=False).compute()
#obs_2000_2019_p.isel(forecast_time=[2,4]).t2m.plot(col='lead_time', row='forecast_time')
#obs_2000_2019_p2.isel(forecast_time=[2,4]).t2m.plot(col='lead_time', row='forecast_time')
```
%% Cell type:code id: tags:
``` python
obs_2000_2019_p[v].isel(lead_time=1,forecast_time=2).plot()
obs_2020_p = xr.open_dataset(f'../data/forecast-like-observations_2020_biweekly_terciled.nc')
obs_2020_p2 = obs_2020_p.assign_coords(category=[0,1,2])
obs_2020_p2 = (obs_2020_p2 * obs_2020_p2.category).sum('category', skipna=False).compute()
```
%%%% Output: execute_result
<matplotlib.collections.QuadMesh at 0x7f6937fce400>
%%%% Output: display_data
![]()
%% Cell type:markdown id: tags:
# ML model
%% Cell type:markdown id: tags:
based on [Weatherbench](https://github.com/pangeo-data/WeatherBench/blob/master/quickstart.ipynb)
%% Cell type:code id: tags:
``` python
# run once only and dont commit
!git clone https://github.com/pangeo-data/WeatherBench/
```
%%%% Output: stream
fatal: destination path 'WeatherBench' already exists and is not an empty directory.
%% Cell type:code id: tags:
``` python
import sys
sys.path.insert(1, 'WeatherBench')
from WeatherBench.src.train_nn import DataGenerator, PeriodicConv2D, create_predictions
import tensorflow.keras as keras
```
%% Cell type:code id: tags:
``` python
bs=32
import numpy as np
class DataGenerator(keras.utils.Sequence):
def __init__(self, fct, verif, lead_time, batch_size=bs, shuffle=True, load=True,
mean=None, std=None):
def __init__(self, fct, verif, lead_time, batch_size=bs, shuffle=True, load=True):
"""
Data generator for WeatherBench data.
Template from https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
Args:
fct: forecasts from S2S models: xr.DataArray (xr.Dataset doesnt work properly)
verif: observations with same dimensionality (xr.Dataset doesnt work properly)
lead_time: Lead_time as in model
batch_size: Batch size
shuffle: bool. If True, data is shuffled.
load: bool. If True, datadet is loaded into RAM.
mean: If None, compute mean from data.
std: If None, compute standard deviation from data.
Todo:
- use number in a better way, now uses only ensemble mean forecast
- dont use .sel(lead_time=lead_time) to train over all lead_time at once
- be sensitive with forecast_time, pool a few around the weekofyear given
- use more variables as predictors
- predict more variables
"""
if isinstance(fct, xr.Dataset):
print('convert fct to array')
fct = fct.to_array().transpose(...,'variable')
self.fct_dataset=True
else:
self.fct_dataset=False
if isinstance(verif, xr.Dataset):
print('convert verif to array')
verif = verif.to_array().transpose(...,'variable')
self.verif_dataset=True
else:
self.verif_dataset=False
#self.fct = fct
self.batch_size = batch_size
self.shuffle = shuffle
self.lead_time = lead_time
self.fct_data = fct.transpose('forecast_time', ...).sel(lead_time=lead_time)
self.fct_mean = self.fct_data.mean('forecast_time').compute() if mean is None else mean
self.fct_std = self.fct_data.std('forecast_time').compute() if std is None else std
self.fct_mean = self.fct_data.mean('forecast_time').compute()
self.fct_std = self.fct_data.std('forecast_time').compute()
self.verif_data = verif.transpose('forecast_time', ...).sel(lead_time=lead_time)
self.verif_mean = self.verif_data.mean('forecast_time').compute() if mean is None else mean
self.verif_std = self.verif_data.std('forecast_time').compute() if std is None else std
#self.verif_mean = self.verif_data.mean('forecast_time').compute() if mean is None else mean
#self.verif_std = self.verif_data.std('forecast_time').compute() if std is None else std
# Normalize
self.fct_data = (self.fct_data - self.fct_mean) / self.fct_std
# self.verif_data = (self.verif_data - self.verif_mean) / self.verif_std
#self.verif_data = (self.verif_data - self.verif_mean) / self.verif_std
#self.verif_data = self.verif_data.astype('int32')
self.n_samples = self.fct_data.forecast_time.size
self.forecast_time = self.fct_data.forecast_time
self.on_epoch_end()
# For some weird reason calling .load() earlier messes up the mean and std computations
if load:
# print('Loading data into RAM')
self.fct_data.load()
def __len__(self):
'Denotes the number of batches per epoch'
return int(np.ceil(self.n_samples / self.batch_size))
def __getitem__(self, i):
'Generate one batch of data'
idxs = self.idxs[i * self.batch_size:(i + 1) * self.batch_size]
# got all nan if nans not masked
X = self.fct_data.isel(forecast_time=idxs).fillna(0.).values
y = self.verif_data.isel(forecast_time=idxs).fillna(0.).values
X = self.fct_data.isel(forecast_time=idxs).values
y = self.verif_data.isel(forecast_time=idxs).values
return X, y
def on_epoch_end(self):
'Updates indexes after each epoch'
self.idxs = np.arange(self.n_samples)
if self.shuffle == True:
np.random.shuffle(self.idxs)
```
%% Cell type:code id: tags:
``` python
# 2 bi-weekly `lead_time`: week 3-4
lead = hind_2000_2019.isel(lead_time=0).lead_time
lead
```
%%%% Output: execute_result
<xarray.DataArray 'lead_time' ()>
array(1209600000000000, dtype='timedelta64[ns]')
Coordinates:
lead_time timedelta64[ns] 14 days
Attributes:
comment: lead_time describes bi-weekly aggregates. The pd.Timedelta corr...
%% Cell type:code id: tags:
``` python
# mask, needed?
hind_2000_2019 = hind_2000_2019.where(obs_2000_2019_p.isel(forecast_time=0, lead_time=0,drop=True).notnull())
```
%% Cell type:code id: tags:
``` python
mask = obs_2000_2019_p.isel(forecast_time=0, lead_time=0,drop=True).notnull()
mask = obs_2000_2019_p2.isel(forecast_time=0, lead_time=0,drop=True).notnull()
```
%% Cell type:markdown id: tags:
## data prep: train, valid, test
[Use the hindcast period to split train and valid.](https://en.wikipedia.org/wiki/Training,_validation,_and_test_sets) Do not use the 2020 data for testing!
%% Cell type:code id: tags:
``` python
# time is the forecast_time
time_train_start,time_train_end='2000','2015' # train
time_valid_start,time_valid_end='2016','2019' # valid
time_train_start,time_train_end='2000','2017' # train
time_valid_start,time_valid_end='2018','2019' # valid
time_test = '2020' # test
```
%% Cell type:code id: tags:
``` python
# Jan only
season='DJF'
tattr='month'
tattr_label=1
attr='seaon'
attr_label='DJF'
bs=16
hind_2000_2019 = hind_2000_2019.sel(forecast_time=hind_2000_2019.forecast_time.dt.season==season)#.sel(forecast_time=hind_2000_2019.forecast_time.dt.day==2)
obs_2000_2019_p = obs_2000_2019_p.sel(forecast_time=obs_2000_2019_p.forecast_time.dt.season==season)#.sel(obs_2000_2019_p=hind_2000_2019.forecast_time.dt.day==2)
hind_2000_2019 = hind_2000_2019.sel(forecast_time=getattr(hind_2000_2019.forecast_time.dt, tattr)==tattr_label)
obs_2000_2019_p2 = obs_2000_2019_p2.sel(forecast_time=getattr(obs_2000_2019_p2.forecast_time.dt, tattr)==tattr_label)
```
%% Cell type:code id: tags:
``` python
dg_train = DataGenerator(
hind_2000_2019.mean('realization').sel(forecast_time=slice(time_train_start,time_train_end))[v],
obs_2000_2019_p.sel(forecast_time=slice(time_train_start,time_train_end))[v],
obs_2000_2019_p2.sel(forecast_time=slice(time_train_start,time_train_end))[v],
lead_time=lead, batch_size=bs, shuffle=True, load=True)
```
%%%% Output: stream
/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide
x = np.divide(x1, x2, out)
/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide
x = np.divide(x1, x2, out)
/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide
x = np.divide(x1, x2, out)
/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide
x = np.divide(x1, x2, out)
%% Cell type:code id: tags:
``` python
dg_valid = DataGenerator(
hind_2000_2019.mean('realization').sel(forecast_time=slice(time_valid_start,time_valid_end))[v],
obs_2000_2019_p.sel(forecast_time=slice(time_valid_start,time_valid_end))[v],
obs_2000_2019_p2.sel(forecast_time=slice(time_valid_start,time_valid_end))[v],
lead_time=lead, batch_size=bs, shuffle=False, load=True)
```
%%%% Output: stream
/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide
x = np.divide(x1, x2, out)
/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide
x = np.divide(x1, x2, out)
/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide
x = np.divide(x1, x2, out)
/opt/conda/lib/python3.8/site-packages/dask/array/numpy_compat.py:40: RuntimeWarning: invalid value encountered in true_divide
x = np.divide(x1, x2, out)
%% Cell type:code id: tags:
``` python
# do not use, delete?
#dg_test = DataGenerator(
# fct_2020.mean('realization').sel(forecast_time=time_test)[v],
# obs_2020.sel(forecast_time=time_test)[v],
# lead_time=lead, batch_size=bs, load=True, mean=dg_train.fct_mean, std=dg_train.fct_std, shuffle=False)
dg_test = DataGenerator(
fct_2020.mean('realization').sel(forecast_time=time_test)[v],
obs_2020_p2.sel(forecast_time=time_test)[v],
lead_time=lead, batch_size=bs, shuffle=False, load=True)
```
%% Cell type:code id: tags:
``` python
X, y = dg_train[0]
X.shape, y.shape
```
%%%% Output: execute_result
((16, 121, 240), (16, 121, 240))
%% Cell type:code id: tags:
``` python
X, y = dg_valid[0]
X.shape, y.shape
```
%%%% Output: execute_result
((16, 121, 240), (16, 121, 240))
((10, 121, 240), (10, 121, 240))
%% Cell type:code id: tags:
``` python
# short look into training data
# any problem from normalizing?
i=4
xr.DataArray(np.vstack([X[i],y[i]])).plot(yincrease=False, robust=True)
#xr.DataArray(np.vstack([X[i],y[i]])).plot(yincrease=False, robust=True)
```
%%%% Output: execute_result
<matplotlib.collections.QuadMesh at 0x7f69544ffbb0>
%%%% Output: display_data
![]()
%% Cell type:markdown id: tags:
## `fit`
%% Cell type:code id: tags:
``` python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, Flatten, Dense, Dropout
```
%% Cell type:code id: tags:
``` python
y.shape
```
%%%% Output: execute_result
(16, 121, 240)
%% Cell type:code id: tags:
``` python
dr=.01
cnn = keras.models.Sequential([
PeriodicConv2D(filters=32, kernel_size=5, conv_kwargs={'activation':'elu'},