|
|
|
|
|
|
|
|
|
from uuid import UUID |
|
from io import StringIO |
|
from typing import Callable |
|
from antlr4.Token import Token |
|
from antlr4.atn.ATN import ATN |
|
from antlr4.atn.ATNType import ATNType |
|
from antlr4.atn.ATNState import * |
|
from antlr4.atn.Transition import * |
|
from antlr4.atn.LexerAction import * |
|
from antlr4.atn.ATNDeserializationOptions import ATNDeserializationOptions |
|
|
|
|
|
# UUID identifying the base serialized ATN format.
BASE_SERIALIZED_UUID = UUID("AADB8D7E-AEEF-4415-AD2B-8204D6CF042E")

# UUID identifying serialized ATNs that carry a second block of interval sets
# whose bounds are read with readInt32 (see ATNDeserializer.deserialize).
ADDED_UNICODE_SMP = UUID("59627784-3BE5-417A-B9EB-8131A7286089")

# Every UUID this deserializer accepts. The order is significant:
# ATNDeserializer.isFeatureSupported compares positions in this list, so a
# feature is considered available to every UUID listed at or after it.
SUPPORTED_UUIDS = [ BASE_SERIALIZED_UUID, ADDED_UNICODE_SMP ]

# The only serialized version number accepted by ATNDeserializer.checkVersion.
SERIALIZED_VERSION = 3

# The UUID reported as "expected" when an unsupported UUID is rejected.
SERIALIZED_UUID = ADDED_UNICODE_SMP
|
|
|
class ATNDeserializer (object):
    """Rebuild an ATN object graph from its serialized string form.

    ``deserialize`` is the entry point. ``reset`` turns the input string into
    a list of integers (``self.data``); every ``read*`` method then consumes
    one section of that list, advancing the cursor ``self.pos``.
    """
    __slots__ = ('deserializationOptions', 'data', 'pos', 'uuid')

    def __init__(self, options: "ATNDeserializationOptions" = None):
        """Store *options*; ``None`` selects ``ATNDeserializationOptions.defaultOptions``."""
        if options is None:
            options = ATNDeserializationOptions.defaultOptions
        self.deserializationOptions = options

    def isFeatureSupported(self, feature: UUID, actualUuid: UUID):
        """Tell whether a serialized ATN tagged *actualUuid* includes *feature*.

        A feature is supported when *actualUuid* appears in ``SUPPORTED_UUIDS``
        at or after the position of *feature*.

        Returns ``False`` when either UUID is not listed at all.
        """
        # list.index() raises ValueError for a missing element (it never
        # returns a negative index), so membership must be tested explicitly.
        if feature not in SUPPORTED_UUIDS or actualUuid not in SUPPORTED_UUIDS:
            return False
        return SUPPORTED_UUIDS.index(actualUuid) >= SUPPORTED_UUIDS.index(feature)

    def deserialize(self, data: str):
        """Decode *data* and return the resulting ATN.

        Sections are read in a fixed order: version, UUID, ATN header, states,
        rules, modes, interval sets, edges, decisions and lexer actions.

        Raises:
            Exception: on an unsupported version or UUID, an invalid state,
                transition or lexer action type, or a failed verification.
        """
        self.reset(data)
        self.checkVersion()
        self.checkUUID()
        atn = self.readATN()
        self.readStates(atn)
        self.readRules(atn)
        self.readModes(atn)
        sets = []
        # The first block of sets stores each bound as a single value.
        self.readSets(atn, sets, self.readInt)
        # Newer streams append a second block whose bounds span two values.
        if self.isFeatureSupported(ADDED_UNICODE_SMP, self.uuid):
            self.readSets(atn, sets, self.readInt32)
        self.readEdges(atn, sets)
        self.readDecisions(atn)
        self.readLexerActions(atn)
        self.markPrecedenceDecisions(atn)
        self.verifyATN(atn)
        if self.deserializationOptions.generateRuleBypassTransitions \
                and atn.grammarType == ATNType.PARSER:
            self.generateRuleBypassTransitions(atn)
            # The graph was modified, so check it again.
            self.verifyATN(atn)
        return atn

    def reset(self, data: str):
        """Load *data* into ``self.data`` as integers and rewind ``self.pos``.

        Every character except the first is decoded by subtracting 2 from its
        code point; the first character is stored unchanged (``checkVersion``
        reads it as the version number).
        """
        def adjust(c):
            v = ord(c)
            # Undo the +2 offset. Code points 0 and 1 wrap around the 16-bit
            # range to 0xFFFE and 0xFFFF; 0xFFFF is the sentinel tested by
            # readStates, readRules and readLexerActions. The former offset
            # of 65533 was off by one and made that sentinel unreachable.
            return v - 2 if v > 1 else v + 65534
        self.data = [ord(c) if i == 0 else adjust(c) for i, c in enumerate(data)]
        self.pos = 0

    def checkVersion(self):
        """Read the version number and reject anything but ``SERIALIZED_VERSION``."""
        version = self.readInt()
        if version != SERIALIZED_VERSION:
            raise Exception("Could not deserialize ATN with version " + str(version) + " (expected " + str(SERIALIZED_VERSION) + ").")

    def checkUUID(self):
        """Read the UUID, reject it unless listed in ``SUPPORTED_UUIDS``, store it in ``self.uuid``."""
        uuid = self.readUUID()
        if uuid not in SUPPORTED_UUIDS:
            raise Exception("Could not deserialize ATN with UUID: " + str(uuid) + \
                            " (expected " + str(SERIALIZED_UUID) + " or a legacy UUID).", uuid, SERIALIZED_UUID)
        self.uuid = uuid

    def readATN(self):
        """Read the grammar type ordinal and the maximum token type; return a new ATN."""
        idx = self.readInt()
        grammarType = ATNType.fromOrdinal(idx)
        maxTokenType = self.readInt()
        return ATN(grammarType, maxTokenType)

    def readStates(self, atn: "ATN"):
        """Read all states into *atn*, then the non-greedy and precedence markers."""
        # Cross references are stored as state numbers that may point at
        # states not created yet, so they are resolved after the main loop.
        loopBackStateNumbers = []
        endStateNumbers = []
        nstates = self.readInt()
        for _ in range(nstates):
            stype = self.readInt()
            # An invalid type only occupies a slot in the state list.
            if stype == ATNState.INVALID_TYPE:
                atn.addState(None)
                continue
            ruleIndex = self.readInt()
            if ruleIndex == 0xFFFF:
                ruleIndex = -1
            s = self.stateFactory(stype, ruleIndex)
            if stype == ATNState.LOOP_END:
                loopBackStateNumber = self.readInt()
                loopBackStateNumbers.append((s, loopBackStateNumber))
            elif isinstance(s, BlockStartState):
                endStateNumber = self.readInt()
                endStateNumbers.append((s, endStateNumber))
            atn.addState(s)

        for state, number in loopBackStateNumbers:
            state.loopBackState = atn.states[number]

        for state, number in endStateNumbers:
            state.endState = atn.states[number]

        numNonGreedyStates = self.readInt()
        for _ in range(numNonGreedyStates):
            stateNumber = self.readInt()
            atn.states[stateNumber].nonGreedy = True

        numPrecedenceStates = self.readInt()
        for _ in range(numPrecedenceStates):
            stateNumber = self.readInt()
            atn.states[stateNumber].isPrecedenceRule = True

    def readRules(self, atn: "ATN"):
        """Fill ``ruleToStartState``, ``ruleToStopState`` and, for lexers, ``ruleToTokenType``."""
        nrules = self.readInt()
        isLexer = atn.grammarType == ATNType.LEXER
        if isLexer:
            atn.ruleToTokenType = [0] * nrules

        atn.ruleToStartState = [0] * nrules
        for i in range(nrules):
            s = self.readInt()
            atn.ruleToStartState[i] = atn.states[s]
            if isLexer:
                tokenType = self.readInt()
                if tokenType == 0xFFFF:
                    tokenType = Token.EOF
                atn.ruleToTokenType[i] = tokenType

        # Stop states are not listed in the rule section; they are located by
        # scanning the states and matched to their rule through ruleIndex.
        atn.ruleToStopState = [0] * nrules
        for state in atn.states:
            if not isinstance(state, RuleStopState):
                continue
            atn.ruleToStopState[state.ruleIndex] = state
            atn.ruleToStartState[state.ruleIndex].stopState = state

    def readModes(self, atn: "ATN"):
        """Append the start state of every mode to ``atn.modeToStartState``."""
        nmodes = self.readInt()
        for _ in range(nmodes):
            s = self.readInt()
            atn.modeToStartState.append(atn.states[s])

    def readSets(self, atn: "ATN", sets: list, readUnicode: Callable[[], int]):
        """Read one block of interval sets and append each set to *sets*.

        *readUnicode* is the reader used for every interval bound, which lets
        the caller choose between ``readInt`` and ``readInt32``.
        """
        m = self.readInt()
        for _ in range(m):
            iset = IntervalSet()
            sets.append(iset)
            n = self.readInt()
            containsEof = self.readInt()
            if containsEof != 0:
                iset.addOne(-1)
            for _ in range(n):
                i1 = readUnicode()
                i2 = readUnicode()
                # The stored upper bound is inclusive; range() is not.
                iset.addRange(range(i1, i2 + 1))

    def readEdges(self, atn: "ATN", sets: list):
        """Read all transitions, then derive the links the stream leaves implicit."""
        nedges = self.readInt()
        for _ in range(nedges):
            src = self.readInt()
            trg = self.readInt()
            ttype = self.readInt()
            arg1 = self.readInt()
            arg2 = self.readInt()
            arg3 = self.readInt()
            trans = self.edgeFactory(atn, ttype, src, trg, arg1, arg2, arg3, sets)
            atn.states[src].addTransition(trans)

        # For every rule invocation, give the invoked rule's stop state an
        # epsilon edge back to the state that follows the invocation.
        for state in atn.states:
            for t in state.transitions:
                if not isinstance(t, RuleTransition):
                    continue
                outermostPrecedenceReturn = -1
                if atn.ruleToStartState[t.target.ruleIndex].isPrecedenceRule:
                    if t.precedence == 0:
                        outermostPrecedenceReturn = t.target.ruleIndex
                trans = EpsilonTransition(t.followState, outermostPrecedenceReturn)
                atn.ruleToStopState[t.target.ruleIndex].addTransition(trans)

        for state in atn.states:
            if isinstance(state, BlockStartState):
                # A block start must know its end state ...
                if state.endState is None:
                    raise Exception("IllegalState")
                # ... and a block end state may close only one block.
                if state.endState.startState is not None:
                    raise Exception("IllegalState")
                state.endState.startState = state

            if isinstance(state, PlusLoopbackState):
                for t in state.transitions:
                    if isinstance(t.target, PlusBlockStartState):
                        t.target.loopBackState = state
            elif isinstance(state, StarLoopbackState):
                for t in state.transitions:
                    if isinstance(t.target, StarLoopEntryState):
                        t.target.loopBackState = state

    def readDecisions(self, atn: "ATN"):
        """Register every decision state and number the decisions in stream order."""
        ndecisions = self.readInt()
        for i in range(ndecisions):
            s = self.readInt()
            decState = atn.states[s]
            atn.decisionToState.append(decState)
            decState.decision = i

    def readLexerActions(self, atn: "ATN"):
        """Read the lexer action table; does nothing unless *atn* is a lexer ATN."""
        if atn.grammarType != ATNType.LEXER:
            return
        count = self.readInt()
        atn.lexerActions = [ None ] * count
        for i in range(count):
            actionType = self.readInt()
            data1 = self.readInt()
            if data1 == 0xFFFF:
                data1 = -1
            data2 = self.readInt()
            if data2 == 0xFFFF:
                data2 = -1
            atn.lexerActions[i] = self.lexerActionFactory(actionType, data1, data2)

    def generateRuleBypassTransitions(self, atn: "ATN"):
        """Assign one token type per rule above ``maxTokenType`` and add a bypass to each rule."""
        count = len(atn.ruleToStartState)
        atn.ruleToTokenType = [atn.maxTokenType + i + 1 for i in range(count)]
        for i in range(count):
            self.generateRuleBypassTransition(atn, i)

    def generateRuleBypassTransition(self, atn: "ATN", idx: int):
        """Wrap rule *idx* in a block that can also be matched by one token.

        The token is the type ``atn.ruleToTokenType[idx]``.

        Raises:
            Exception: when rule *idx* is a precedence rule and no state
                accepted by ``stateIsEndStateFor`` can be found.
        """
        bypassStart = BasicBlockStartState()
        bypassStart.ruleIndex = idx
        atn.addState(bypassStart)

        bypassStop = BlockEndState()
        bypassStop.ruleIndex = idx
        atn.addState(bypassStop)

        bypassStart.endState = bypassStop
        atn.defineDecisionState(bypassStart)

        bypassStop.startState = bypassStart

        excludeTransition = None

        if atn.ruleToStartState[idx].isPrecedenceRule:
            # For a precedence rule the wrapped section ends at the
            # loop entry state picked out by stateIsEndStateFor.
            endState = None
            for state in atn.states:
                if self.stateIsEndStateFor(state, idx):
                    endState = state
                    excludeTransition = state.loopBackState.transitions[0]
                    break

            if excludeTransition is None:
                raise Exception("Couldn't identify final state of the precedence rule prefix section.")

        else:
            endState = atn.ruleToStopState[idx]

        # Redirect every edge that entered the end state to the bypass stop
        # state, except the excluded loop-back edge.
        for state in atn.states:
            for transition in state.transitions:
                if transition == excludeTransition:
                    continue
                if transition.target == endState:
                    transition.target = bypassStop

        # Move every edge leaving the rule start state onto the bypass start
        # state, taking them from the end of the list. The previous loop
        # never updated its counter and raised IndexError on its second pass.
        ruleToStartState = atn.ruleToStartState[idx]
        while ruleToStartState.transitions:
            bypassStart.addTransition(ruleToStartState.transitions.pop())

        # Link the new states into the rule.
        atn.ruleToStartState[idx].addTransition(EpsilonTransition(bypassStart))
        bypassStop.addTransition(EpsilonTransition(endState))

        # Alternative path: match the rule's own token type in one step.
        matchState = BasicState()
        atn.addState(matchState)
        matchState.addTransition(AtomTransition(bypassStop, atn.ruleToTokenType[idx]))
        bypassStart.addTransition(EpsilonTransition(matchState))

    def stateIsEndStateFor(self, state: "ATNState", idx: int):
        """Return *state* if it closes the prefix section of rule *idx*, else ``None``.

        The state qualifies when it is a ``StarLoopEntryState`` of rule *idx*
        whose last transition targets a ``LoopEndState`` that has only epsilon
        transitions and whose first transition targets a ``RuleStopState``.
        """
        if state.ruleIndex != idx:
            return None
        if not isinstance(state, StarLoopEntryState):
            return None

        maybeLoopEndState = state.transitions[-1].target
        if not isinstance(maybeLoopEndState, LoopEndState):
            return None

        if maybeLoopEndState.epsilonOnlyTransitions and \
                isinstance(maybeLoopEndState.transitions[0].target, RuleStopState):
            return state
        return None

    def markPrecedenceDecisions(self, atn: "ATN"):
        """Set ``isPrecedenceDecision`` on the qualifying loop entry states.

        A ``StarLoopEntryState`` qualifies when its rule is a precedence rule
        and it satisfies the structural test used by ``stateIsEndStateFor``.
        """
        for state in atn.states:
            if not isinstance(state, StarLoopEntryState):
                continue
            if not atn.ruleToStartState[state.ruleIndex].isPrecedenceRule:
                continue
            maybeLoopEndState = state.transitions[-1].target
            if not isinstance(maybeLoopEndState, LoopEndState):
                continue
            if maybeLoopEndState.epsilonOnlyTransitions and \
                    isinstance(maybeLoopEndState.transitions[0].target, RuleStopState):
                state.isPrecedenceDecision = True

    def verifyATN(self, atn: "ATN"):
        """Check structural invariants of *atn*; no-op unless the ``verifyATN`` option is set.

        Raises:
            Exception: ("IllegalState") on the first violated invariant.
        """
        if not self.deserializationOptions.verifyATN:
            return

        for state in atn.states:
            # Slots of invalid states hold None.
            if state is None:
                continue

            self.checkCondition(state.epsilonOnlyTransitions or len(state.transitions) <= 1)

            if isinstance(state, PlusBlockStartState):
                self.checkCondition(state.loopBackState is not None)

            if isinstance(state, StarLoopEntryState):
                self.checkCondition(state.loopBackState is not None)
                self.checkCondition(len(state.transitions) == 2)

                # The order of the two exits must agree with the greediness.
                if isinstance(state.transitions[0].target, StarBlockStartState):
                    self.checkCondition(isinstance(state.transitions[1].target, LoopEndState))
                    self.checkCondition(not state.nonGreedy)
                elif isinstance(state.transitions[0].target, LoopEndState):
                    self.checkCondition(isinstance(state.transitions[1].target, StarBlockStartState))
                    self.checkCondition(state.nonGreedy)
                else:
                    raise Exception("IllegalState")

            if isinstance(state, StarLoopbackState):
                self.checkCondition(len(state.transitions) == 1)
                self.checkCondition(isinstance(state.transitions[0].target, StarLoopEntryState))

            if isinstance(state, LoopEndState):
                self.checkCondition(state.loopBackState is not None)

            if isinstance(state, RuleStartState):
                self.checkCondition(state.stopState is not None)

            if isinstance(state, BlockStartState):
                self.checkCondition(state.endState is not None)

            if isinstance(state, BlockEndState):
                self.checkCondition(state.startState is not None)

            if isinstance(state, DecisionState):
                self.checkCondition(len(state.transitions) <= 1 or state.decision >= 0)
            else:
                self.checkCondition(len(state.transitions) <= 1 or isinstance(state, RuleStopState))

    def checkCondition(self, condition: bool, message=None):
        """Raise ``Exception(message)`` (default ``"IllegalState"``) unless *condition* is true."""
        if not condition:
            if message is None:
                message = "IllegalState"
            raise Exception(message)

    def readInt(self):
        """Return the value at the cursor and advance the cursor by one."""
        i = self.data[self.pos]
        self.pos += 1
        return i

    def readInt32(self):
        """Read two values and combine them, the first one being the low 16 bits."""
        low = self.readInt()
        high = self.readInt()
        return low | (high << 16)

    def readLong(self):
        """Read two 32-bit values and combine them, the first one being the low half."""
        low = self.readInt32()
        high = self.readInt32()
        return (low & 0x00000000FFFFFFFF) | (high << 32)

    def readUUID(self):
        """Read two 64-bit values (low half first) and return them as a ``UUID``."""
        low = self.readLong()
        high = self.readLong()
        allBits = (low & 0xFFFFFFFFFFFFFFFF) | (high << 64)
        return UUID(int=allBits)

    # Transition constructors, indexed by serialized transition type.
    # Slot 0 is None: it does not denote a transition, so edgeFactory rejects
    # it with its usual error (the former one-argument placeholder lambda
    # failed with a TypeError when called with the eight edge arguments).
    edgeFactories = [ None,
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : EpsilonTransition(target),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        RangeTransition(target, Token.EOF, arg2) if arg3 != 0 else RangeTransition(target, arg1, arg2),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        RuleTransition(atn.states[arg1], arg2, arg3, target),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        PredicateTransition(target, arg1, arg2, arg3 != 0),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        AtomTransition(target, Token.EOF) if arg3 != 0 else AtomTransition(target, arg1),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        ActionTransition(target, arg1, arg2, arg3 != 0),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        SetTransition(target, sets[arg1]),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        NotSetTransition(target, sets[arg1]),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        WildcardTransition(target),
                      lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
                        PrecedencePredicateTransition(target, arg1)
                      ]

    def edgeFactory(self, atn: "ATN", type: int, src: int, trg: int, arg1: int, arg2: int, arg3: int, sets: list):
        """Build the transition of serialized type *type* that targets state number *trg*.

        Raises:
            Exception: when *type* has no entry in ``edgeFactories``.
        """
        target = atn.states[trg]
        # Bounds are checked on both sides: a negative type must not index
        # from the end of the table, and type == len(table) is out of range.
        factory = self.edgeFactories[type] if 0 <= type < len(self.edgeFactories) else None
        if factory is None:
            raise Exception("The specified transition type: " + str(type) + " is not valid.")
        return factory(atn, src, trg, arg1, arg2, arg3, sets, target)

    # State constructors, indexed by serialized state type.
    # Slot 0 produces None, which stateFactory returns unchanged.
    stateFactories = [ lambda : None,
                       lambda : BasicState(),
                       lambda : RuleStartState(),
                       lambda : BasicBlockStartState(),
                       lambda : PlusBlockStartState(),
                       lambda : StarBlockStartState(),
                       lambda : TokensStartState(),
                       lambda : RuleStopState(),
                       lambda : BlockEndState(),
                       lambda : StarLoopbackState(),
                       lambda : StarLoopEntryState(),
                       lambda : PlusLoopbackState(),
                       lambda : LoopEndState()
                       ]

    def stateFactory(self, type: int, ruleIndex: int):
        """Create the state of serialized type *type* and assign its *ruleIndex*.

        Returns ``None`` for type 0.

        Raises:
            Exception: when *type* has no entry in ``stateFactories``.
        """
        factory = self.stateFactories[type] if 0 <= type < len(self.stateFactories) else None
        if factory is None:
            raise Exception("The specified state type " + str(type) + " is not valid.")
        s = factory()
        if s is not None:
            s.ruleIndex = ruleIndex
        return s

    # Lexer action type codes; each equals the index of the matching
    # constructor in actionFactories below.
    CHANNEL = 0
    CUSTOM = 1
    MODE = 2
    MORE = 3
    POP_MODE = 4
    PUSH_MODE = 5
    SKIP = 6
    TYPE = 7

    actionFactories = [ lambda data1, data2: LexerChannelAction(data1),
                        lambda data1, data2: LexerCustomAction(data1, data2),
                        lambda data1, data2: LexerModeAction(data1),
                        lambda data1, data2: LexerMoreAction.INSTANCE,
                        lambda data1, data2: LexerPopModeAction.INSTANCE,
                        lambda data1, data2: LexerPushModeAction(data1),
                        lambda data1, data2: LexerSkipAction.INSTANCE,
                        lambda data1, data2: LexerTypeAction(data1)
                        ]

    def lexerActionFactory(self, type: int, data1: int, data2: int):
        """Create the lexer action of serialized type *type* from its two data values.

        Raises:
            Exception: when *type* has no entry in ``actionFactories``.
        """
        factory = self.actionFactories[type] if 0 <= type < len(self.actionFactories) else None
        if factory is None:
            raise Exception("The specified lexer action type " + str(type) + " is not valid.")
        return factory(data1, data2)
|
|