"""
Functions to download and load the Steinmetz dataset.
"""
import numpy as np
import os, requests
from tqdm.notebook import tqdm
def download_data(data_path: str):
    """Download the Steinmetz dataset.

    The dataset is fetched as three ``.npz`` files, saved as
    ``steinmetz_part{i}.npz`` (``i`` = 0, 1, 2) inside ``data_path``.
    Parts that already exist on disk are skipped. A part that fails to
    download is reported on stdout and the remaining parts are still
    attempted; no exception is raised for connection errors, timeouts or
    non-OK HTTP status codes.

    Parameters
    ----------
    data_path: str
        Directory to save the dataset in. It is created if it does not
        exist.
    """
    print("Downloading data.")
    url = [
        "https://osf.io/agvxh/download",
        "https://osf.io/uv3mw/download",
        "https://osf.io/ehmw2/download",
    ]

    # Create the target directory before downloading anything, so a missing
    # directory is not discovered only after the payload has been fetched.
    # An empty string means "current directory" and needs no creation.
    if data_path:
        os.makedirs(data_path, exist_ok=True)

    for i in tqdm(range(len(url))):
        # os.path.join works with or without a trailing separator in
        # data_path, unlike plain string concatenation.
        fname = os.path.join(data_path, f"steinmetz_part{i}.npz")
        if os.path.isfile(fname):
            continue

        try:
            # Without a timeout a stalled connection would block forever.
            r = requests.get(url[i], timeout=60)
        except (requests.ConnectionError, requests.Timeout):
            print("!!! Failed to download data !!!")
            continue

        if r.status_code != requests.codes.ok:
            print("!!! Failed to download data !!!")
            continue

        # Write to a temporary name and rename afterwards: an interrupted
        # write must not leave a truncated file under the final name, since
        # the isfile() check above would then skip it on every later call.
        tmp_name = fname + ".part"
        with open(tmp_name, "wb") as fid:
            fid.write(r.content)
        os.replace(tmp_name, fname)
def _list_part_files(data_path):
    """Return paths of the consecutive ``steinmetz_part{i}.npz`` files found in ``data_path``, starting at part 0."""
    part_files = []
    while True:
        fname = os.path.join(data_path, f"steinmetz_part{len(part_files)}.npz")
        if not os.path.isfile(fname):
            return part_files
        part_files.append(fname)


def load_data(data_path: str) -> np.ndarray:
    """Load dataset from the given path.

    If the dataset does not exist, it is auto downloaded. Only files named
    ``steinmetz_part{i}.npz`` are considered; any other file in the
    directory is ignored.

    Parameters
    ----------
    data_path: str
        Directory where the dataset is saved. It is created if it does not
        exist.

    Returns
    -------
    np.ndarray
        Numpy array holding the concatenated ``"dat"`` entries of all part
        files, with the time corrections applied and array-valued fields
        squeezed. An empty array is returned if no part file is available
        (e.g. the download failed).
    """
    # download data if not already present.
    part_files = _list_part_files(data_path)
    if not part_files:
        # An empty string means "current directory" and needs no creation.
        if data_path:
            os.makedirs(data_path, exist_ok=True)
        download_data(data_path)
        part_files = _list_part_files(data_path)

    # load dataset
    print("Loading data.")
    # Seed with an empty array so that the result is still an (empty)
    # ndarray when there is nothing to load.
    parts = [np.array([])]
    for fname in tqdm(part_files):
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
        # load files from a trusted source.
        with np.load(fname, allow_pickle=True) as npz:
            parts.append(npz["dat"])
    all_data = np.hstack(parts)

    for entry in all_data:
        # Apply corrections to the following time based variables: each is
        # shifted by stim_onset (presumably converting stimulus-relative
        # times to the same reference as stim_onset -- TODO confirm).
        entry["gocue"] += entry["stim_onset"]
        entry["response_time"] += entry["stim_onset"]
        entry["feedback_time"] += entry["stim_onset"]

        # squeeze all extra dimensions
        for key, value in entry.items():
            if isinstance(value, np.ndarray):
                entry[key] = value.squeeze()

    return all_data