Source code for app.explore.summary

"""
Functions to summarise session and trial data.
"""

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# Runs once at import time: switches seaborn/matplotlib to the "darkgrid"
# style, so plots drawn after this module is imported use that theme.
sns.set_style("darkgrid")


def trial_event_flow(all_data: np.ndarray, session_id: int, trial_id: int):
    """Print the events of a single trial in chronological order.

    Prints a short header (session, trial, contrasts, response and feedback)
    and then a two-column "time - action" table covering start, stimulus
    onset, go cue, response, feedback and end of the trial.

    Parameters
    ----------
    all_data: np.ndarray
        Data from all sessions, indexed by session id. Each entry is a
        dict-like record providing ``spks``, ``contrast_left``,
        ``contrast_right``, ``response``, ``feedback_type``, ``stim_onset``,
        ``gocue``, ``response_time`` and ``feedback_time``.
    session_id: int
        Integer that denotes a particular session.
    trial_id: int
        Integer that denotes a trial within a session; used directly as an
        index into the per-trial arrays.
    """
    record = all_data[session_id]

    def _row(when, what):
        # Fixed-width "time - action" line shared by the table header and body.
        return f"{when:<8} - {what:<10}"

    def _report_lines():
        # Yielded lazily so every value is read right before it is printed.
        yield f"Session number = {session_id}"
        yield f"Trial number = {trial_id}/{record['spks'].shape[1]}"
        yield f"Contrast left: {record['contrast_left'][trial_id]}"
        yield f"Contrast Right: {record['contrast_right'][trial_id]}"
        yield f"Response: {record['response'][trial_id]}"
        yield f"Feedback: {record['feedback_type'][trial_id]}\n"
        yield "\n" + _row("Time", "Action")
        yield "-" * 36
        yield _row("0", "start")
        # stim_onset is read as a single value, not indexed per trial.
        yield _row(record["stim_onset"], "stim_onset (always fixed)")
        # Per-trial event times: one-element entries, rounded to 3 decimals.
        for event in ("gocue", "response_time", "feedback_time"):
            yield _row(round(record[event][trial_id].item(), 3), event)
        yield _row("NA", "end")

    for line in _report_lines():
        print(line)
def session_stats(all_data: np.ndarray, session_id: int):
    """Print introductory information about a session and the size of its variables.

    Prints the number of sessions, then the sizes of the three axes of the
    session's ``spks`` array (labelled neurons, trials and per-trial time),
    then one line per session variable:

    * numpy arrays -> their shape,
    * lists -> their length,
    * floating-point scalars (Python or numpy) -> their value.

    Any other value type (e.g. strings) is not printed.

    Parameters
    ----------
    all_data: np.ndarray
        Data from all sessions, indexed by session id. Each entry is a
        dict-like record that contains at least a 3-axis ``spks`` array.
    session_id: int
        Integer that denotes a particular session.
    """
    print(f"Number of sessions = {len(all_data)}\n\n")
    print(f"Stats for a session #{session_id}: \n")

    session_data = all_data[session_id]
    spks_shape = session_data["spks"].shape
    print(f"\tNumber of neurons used in this session = {spks_shape[0]}")
    print(f"\tNumber of trials in this session = {spks_shape[1]}")
    # NOTE(review): this is the length of spks axis 2 (presumably the number
    # of time bins per trial), not a duration -- confirm the label.
    print(f"\tTime taken per trial = {spks_shape[2]}\n")
    print("-" * 50)
    print("\nData shapes:\n")

    for key, value in session_data.items():
        # isinstance (not type ==) so subclasses such as np.float64 are reported.
        if isinstance(value, np.ndarray):
            print(f"\t{key} : {value.shape}")
        elif isinstance(value, list):
            print(f"\t{key} : {len(value)}")
        elif isinstance(value, (float, np.floating)):
            print(f"\t{key} : {value}")
def session_accuracy_report(
    all_data: np.ndarray, session_id: int, plot: bool = False
) -> float:
    """Return the response accuracy of a mouse in a single session.

    The expected choice for every trial is derived from the stimulus contrasts:

    * right if the right contrast is higher,
    * left if the left contrast is higher,
    * center if they are equal.

    The expected choice is compared with the recorded response. Optionally
    prints a classification report and plots the confusion matrix.

    Parameters
    ----------
    all_data: np.ndarray
        Data from all sessions, indexed by session id. Each entry is a
        dict-like record with 1-d ``contrast_left``, ``contrast_right`` and
        ``response`` entries (one value per trial).
    session_id: int
        Integer that denotes a particular session.
    plot: bool, optional
        If True, print a classification report and draw a confusion-matrix
        heatmap (seaborn, current matplotlib axes). Defaults to False.

    Returns
    -------
    float
        Response accuracy as a percentage (0-100).
    """
    session_data = all_data[session_id]
    # Class encoding used for both expected and recorded choices.
    idx2class = {0: "center", 1: "left", 2: "right"}

    contrast_left = np.asarray(session_data["contrast_left"])
    contrast_right = np.asarray(session_data["contrast_right"])
    true_output = np.where(
        contrast_right > contrast_left,
        2,
        np.where(contrast_left > contrast_right, 1, 0),
    )

    # session_data["response"] uses -1 for right, +1 for left and 0 for
    # center; remap -1 -> 2 so it matches the encoding above.
    pred_output = np.asarray(session_data["response"]).astype(int)
    pred_output = np.where(pred_output == -1, 2, pred_output)

    if plot:
        print(classification_report(true_output, pred_output))
        # Pass the full label set so the matrix is always 3x3 and row/column i
        # is always idx2class[i]; without it an absent class shrinks the
        # matrix and the rename below would attach the wrong names.
        matrix = confusion_matrix(true_output, pred_output, labels=sorted(idx2class))
        df = pd.DataFrame(matrix).rename(columns=idx2class, index=idx2class)
        sns.heatmap(df, annot=True)
        plt.xlabel("Predicted")
        plt.ylabel("Actual")
        plt.title("Confusion Matrix")

    return accuracy_score(true_output, pred_output) * 100
def session_accuracy(all_data: np.ndarray, session_id: int):
    """Return the percentage of trials in a session with non-negative feedback.

    A trial counts as correct unless its ``feedback_type`` equals -1. The
    result is the mean of that per-trial indicator, scaled to 0-100.

    Parameters
    ----------
    all_data: np.ndarray
        Data from all sessions, indexed by session id. Each entry is a
        dict-like record with a ``feedback_type`` array.
    session_id: int
        Integer that denotes a particular session.

    Returns
    -------
    float
        Accuracy percentage.
    """
    feedback = all_data[session_id]["feedback_type"]
    # True for every trial whose feedback is anything other than -1.
    rewarded = np.logical_not(feedback == -1)
    return np.mean(rewarded) * 100
def get_mouse_sessions(all_data: np.ndarray, mouse_name: str) -> list:
    """Return the ids of the sessions in which a single mouse participated.

    Parameters
    ----------
    all_data: np.ndarray
        Data from all sessions, where a record's position is its session id
        and each record has a ``"mouse_name"`` entry.
    mouse_name: str
        Name of the mouse.

    Returns
    -------
    list
        Session ids, in ascending order, whose ``"mouse_name"`` equals
        ``mouse_name``.
    """
    return [
        session_id
        for session_id, session in enumerate(all_data)
        if session["mouse_name"] == mouse_name
    ]