Source code for SOAPify.transitions.tracker

"""
This submodule contains the functions to build and work with the 'state trackers':
a state tracker is a list of 4-element arrays whose components
represent an 'event':
[previous state ID, current state ID, next state ID, the duration of the event]

In the module there is some logic to help the user understand whether the event
is a '*first one*' or a '*last one*' (aka the first/last seen in the simulation for
each atom): a *first one* will have `previous state ID = current state ID` and 
a *last one* will have `next state ID = current state ID`
"""
import numpy

from ..classify import SOAPclassification


#: index, within an event array, of the ID of the state preceding the event
TRACK_PREVSTATE = 0
#: index, within an event array, of the ID of the state the event takes place in
TRACK_CURSTATE = 1
#: index, within an event array, of the ID of the state following the event
TRACK_ENDSTATE = 2
#: index, within an event array, of the duration of the event, in frames
TRACK_EVENTTIME = 3


class StateTracker:
    """A container for the state trackers.

    Holds one slot per atom in ``stateHistory`` together with the ``window``
    and ``stride`` values the container was created with. Indexing, item
    assignment, ``len`` and iteration are all forwarded to ``stateHistory``.
    """

    stateHistory: list
    window_: int
    stride_: int

    def __init__(self, nat: int, window: int, stride: int) -> None:
        """Create a tracker with ``nat`` empty slots.

        Args:
            nat (int): the number of slots to allocate; each one starts as
                ``None``
            window (int): the window value to store, exposed read-only
                through :attr:`window`
            stride (int): the stride value to store, exposed read-only
                through :attr:`stride`
        """
        self.window_ = window
        self.stride_ = stride
        # one placeholder per atom, to be replaced with that atom's events
        self.stateHistory = [None for _ in range(nat)]

    @property
    def window(self) -> int:
        """The window value given at construction.

        Returns:
            int: the stored window
        """
        return self.window_

    @property
    def stride(self) -> int:
        """The stride value given at construction.

        Returns:
            int: the stored stride
        """
        return self.stride_

    def __len__(self) -> int:
        """The number of slots in the tracker.

        Returns:
            int: the length of ``stateHistory``
        """
        return len(self.stateHistory)

    def __getitem__(self, key):
        """Get the content of a slot.

        Args:
            key: the address, within ``stateHistory``, of the requested atom

        Returns:
            whatever is stored in ``stateHistory`` at ``key``
        """
        return self.stateHistory[key]

    def __setitem__(self, key, data):
        """Store ``data`` in a slot.

        Args:
            key: the address, within ``stateHistory``, of the requested atom
            data: the object to store at ``key``
        """
        self.stateHistory[key] = data

    def __iter__(self):
        """Iterate through the stored lists of events, slot by slot."""
        return iter(self.stateHistory)
def _createEvent(
    prevState: int, curState: int, endState: int, eventTime: int = 0
) -> numpy.ndarray:
    """Pack the description of an event into a 4-element integer array.

    The components are stored in the order
    ``[prevState, curState, endState, eventTime]``.

    Args:
        prevState (int): the id of the previous state
        curState (int): the id of the current state
        endState (int): the id of the next state
        eventTime (int, optional): the duration (in frames) of this event.
            Defaults to 0.

    Returns:
        numpy.ndarray: the event, as an array with ``dtype=int``
    """
    components = (prevState, curState, endState, eventTime)
    return numpy.array(components, dtype=int)
def trackStates(
    classification: SOAPclassification, window: int = 1, stride: "int|None" = None
) -> StateTracker:
    """Creates an ordered list of events for each atom in the classified trajectory

    Each tracker is composed of events, each event is an array of four
    components: `[start state, state, end state, time]`:

    - if `PREVSTATE == CURSTATE` it is the first event for the atom
    - if `ENDSTATE == CURSTATE` it is the last event for the atom

    For every atom the trajectory is sampled every `window` frames, once for
    each starting offset in `range(0, window, stride)`. The events found with
    the different offsets are appended, offset after offset, to the same
    per-atom list. Durations are accumulated in steps of `window` frames.

    Args:
        classification (SOAPclassification): the classified trajectory
        window (int): the distance in frames between two compared frames.
            Defaults to 1.
        stride (int|None): the distance in frames between the starting
            offsets of the windows. Defaults to None, which means `window`.

    Raises:
        ValueError: if `window` or `stride` is smaller than 1
        ValueError: if `stride` is bigger than `window`
        ValueError: if `window` or `stride` is bigger than the number of
            frames

    Returns:
        StateTracker: ordered list of events for each atom in the classified
        trajectory
    """
    if stride is None:
        stride = window
    # a zero step would fail inside range() with an unrelated message and a
    # negative one would silently produce events with negative durations
    if window < 1 or stride < 1:
        raise ValueError("window and stride must be positive integers")
    if stride > window:
        raise ValueError("the window must be bigger than the stride")
    nofFrames = classification.references.shape[0]
    if window > nofFrames or stride > nofFrames:
        raise ValueError("stride and window must be smaller than simulation lenght")
    nofAtoms = classification.references.shape[1]
    stateHistory = StateTracker(nofAtoms, window, stride)
    # TODO: this can be made concurrent per atom
    for atomID in range(nofAtoms):
        statesPerAtom = []
        atomTraj = classification.references[:, atomID]
        for iframe in range(0, window, stride):
            # the array is [start state, state, end state,time]
            # if PREVSTATE == CURSTATE the event is the first event for the atom
            # if ENDSTATE == CURSTATE the event is the last event for the atom
            stateTracker = _createEvent(
                prevState=atomTraj[iframe],
                curState=atomTraj[iframe],
                endState=atomTraj[iframe],
                eventTime=0,
            )
            for frame in range(window + iframe, nofFrames, window):
                stateTracker[TRACK_EVENTTIME] += window
                if atomTraj[frame] != stateTracker[TRACK_CURSTATE]:
                    # the state changed: close the running event and open a
                    # new one that remembers the state it comes from
                    stateTracker[TRACK_ENDSTATE] = atomTraj[frame]
                    statesPerAtom.append(stateTracker)
                    stateTracker = _createEvent(
                        prevState=stateTracker[TRACK_CURSTATE],
                        curState=atomTraj[frame],
                        endState=atomTraj[frame],
                    )
            # append the last event
            stateTracker[TRACK_EVENTTIME] += window
            statesPerAtom.append(stateTracker)
        stateHistory[atomID] = statesPerAtom
    return stateHistory
def removeAtomIdentityFromEventTracker(statesTracker: StateTracker) -> StateTracker:
    """Merge all the per-atom lists of events into a single list.

    The information about which atom did a given event is discarded. A
    tracker that has at most one slot is returned as it is, without copying.

    Args:
        statesTracker (StateTracker): a tracker whose events are organized
            by atoms

    Returns:
        StateTracker: a tracker with a single slot containing all the
        events, with the same window and stride of the input
    """
    if len(statesTracker) <= 1:
        # nothing to merge
        return statesTracker
    merged = StateTracker(
        1, window=statesTracker.window, stride=statesTracker.stride
    )
    allEvents = []
    for atomEvents in statesTracker.stateHistory:
        allEvents.extend(atomEvents)
    merged.stateHistory[0] = allEvents
    return merged
def getResidenceTimesFromStateTracker(
    statesTracker: StateTracker, legend: list
) -> "list[numpy.ndarray]":
    """Calculates the residence times from the events.

    Given a state tracker and the list of the states returns the list of
    residence times per state.

    The duration of an event that is the first or the last one for its atom
    (its previous or its next state is equal to its current state) is stored
    with a negative sign; every other duration is stored as it is.

    Args:
        statesTracker (StateTracker): a tracker with the events organized by
            atoms, or with the atom identities already removed
        legend (list): the list of states

    Returns:
        list[numpy.ndarray]: an ordered list of the sorted residence times
        for each state
    """
    merged = removeAtomIdentityFromEventTracker(statesTracker)
    nofStates = len(legend)
    timesPerState = [[] for _ in range(nofStates)]
    # using merged[0] because we have removed the atom identities
    for event in merged[0]:
        state = event[TRACK_CURSTATE]
        duration = event[TRACK_EVENTTIME]
        isFirst = event[TRACK_PREVSTATE] == state
        isLast = event[TRACK_ENDSTATE] == state
        if isFirst or isLast:
            # flag the events that touch the borders of the trajectory
            duration = -duration
        timesPerState[state].append(duration)
    return [numpy.sort(numpy.array(times)) for times in timesPerState]
def transitionMatrixFromStateTracker(
    statesTracker: StateTracker, legend: list
) -> numpy.ndarray:
    """Generates the unnormalized matrix of the transitions

    see :func:`calculateTransitionMatrix` for a detailed description of an
    unnormalized transition matrix

    Each event adds `duration // window - 1` to the diagonal element of its
    current state and, when its previous state differs from the current one,
    adds 1 to the element `[previous state, current state]`.

    Args:
        statesTracker (StateTracker): a StateTracker
        legend (list): the list of the name of the states

    Returns:
        numpy.ndarray[float]: the unnormalized matrix of the transitions
    """
    merged = removeAtomIdentityFromEventTracker(statesTracker)
    nofStates = len(legend)
    counts = numpy.zeros((nofStates, nofStates), dtype=numpy.float64)
    window = statesTracker.window
    # using merged[0] because we have removed the atom identities
    for event in merged[0]:
        origin = event[TRACK_PREVSTATE]
        state = event[TRACK_CURSTATE]
        # the frames spent without leaving the state
        counts[state, state] += event[TRACK_EVENTTIME] // window - 1
        # the transition matrix is generated with:
        # classFrom = data.references[frameID - stride][atomID]
        # classTo = data.references[frameID][atomID]
        if origin != state:
            counts[origin, state] += 1
    return counts