Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import * as x_ from '../etc/_Tools' | |
import * as _ from 'lodash' | |
import * as tp from '../etc/types' | |
import * as R from 'ramda' | |
/**
 * Fallback token list used by the `TokenDisplay` constructor when no tokens
 * are supplied: a single `[SEP]` token with empty top-k word and probability
 * lists.
 *
 * NOTE: this array is handed out by reference as a default argument, so every
 * `TokenDisplay` built without explicit tokens shares it. Treat it as
 * read-only.
 */
const emptyFullResponse: tp.FullSingleTokenInfo[] = [{
    text: '[SEP]',
    topk_words: [],
    topk_probs: []
}]
/**
 * The original tokens of a sentence, and the indexes of the tokens that need
 * to be masked.
 */
export class TokenDisplay {
    /** Per-token information, one entry per token. */
    tokenData: tp.FullSingleTokenInfo[]
    /** Indexes into `tokenData` that are currently masked. */
    maskInds: number[]

    /**
     * @param tokens Token information to display; defaults to the module-level
     *   `emptyFullResponse` list.
     * @param maskInds Indexes that start out masked; defaults to none.
     */
    constructor(tokens: tp.FullSingleTokenInfo[] = emptyFullResponse, maskInds: number[] = []) {
        this.tokenData = tokens;
        // `mask`/`unmask` mutate the list in place, so keep a private copy
        // instead of aliasing (and silently modifying) the caller's array.
        this.maskInds = maskInds.slice();
    }

    /**
     * Add `val` to the mask index list if it is not already present.
     * Insertion is delegated to `x_.orderedInsert_` (presumably keeps the list
     * ordered from smallest to largest — confirm against that helper).
     * A value that is already masked is only logged.
     */
    mask(val: number): void {
        if (_.indexOf(this.maskInds, val) === -1) {
            x_.orderedInsert_(this.maskInds, val)
            return
        }
        console.log(`${val} already in maskInds!`);
        console.log(this.maskInds);
    }

    /** Mask `val` if it is currently unmasked, otherwise unmask it. */
    toggle(val: number): void {
        if (_.indexOf(this.maskInds, val) === -1) {
            console.log(`Masking ${val}`);
            this.mask(val)
        }
        else {
            console.log(`Unmasking ${val}`);
            this.unmask(val)
        }
    }

    /** Remove every occurrence of `val` from the mask index list, in place. */
    unmask(val: number): void {
        _.pull(this.maskInds, val);
    }

    /** Clear all masked indexes by replacing the list with a new empty one. */
    resetMask(): void {
        this.maskInds = [];
    }

    /** Number of tokens held. */
    length(): number {
        return this.tokenData.length;
    }

    /**
     * Build a new `TokenDisplay` holding this display's tokens followed by
     * `other`'s. `other`'s mask indexes are shifted by this display's length
     * so they keep pointing at the same tokens. Neither input is modified.
     */
    concat(other: TokenDisplay): TokenDisplay {
        const offset = this.length();
        const newTokens = this.tokenData.concat(other.tokenData);
        const newMask = this.maskInds.concat(other.maskInds.map(x => x + offset));
        return new TokenDisplay(newTokens, newMask);
    }
}
/**
 * Holds the `TokenDisplay` for sentence `a` and keeps it in sync with
 * attention responses.
 */
export class TokenWrapper {
    // Assigned through `updateFromResponse` in the constructor, which the
    // compiler cannot see, hence the definite-assignment assertion.
    a!: TokenDisplay

    /** @param r Response whose `aa.left` tokens initialise the display. */
    constructor(r: tp.AttentionResponse) {
        this.updateFromResponse(r);
    }

    /** Replace the display with the `aa.left` tokens of `r` and an empty mask. */
    updateFromResponse(r: tp.AttentionResponse): void {
        this.updateFromComponents(r.aa.left, [])
    }

    /** Replace the display with the given tokens and mask indexes. */
    updateFromComponents(a: tp.FullSingleTokenInfo[], maskA: number[]): void {
        this.a = new TokenDisplay(a, maskA)
    }

    /**
     * Copy the `contexts`, `embeddings`, `topk_probs` and `topk_words` fields
     * from each token in `r.aa.left` onto the token at the same position in
     * the current display. Other fields of the existing tokens are left
     * alone, and the mask is untouched. Tokens are paired positionally with
     * `R.zip`.
     */
    updateTokens(r: tp.AttentionResponse): void {
        const desiredKeys = ['contexts', 'embeddings', 'topk_probs', 'topk_words']
        const newTokens = r.aa.left.map(v => R.pick(desiredKeys, v))
        R.zip(this.a.tokenData, newTokens).forEach(([current, incoming]) => {
            // `R.pick` only yields keys present on the source token, so only
            // those fields are overwritten.
            Object.assign(current, incoming)
        })
    }

    /**
     * Mask the appropriate sentence at the index indicated
     *
     * @param sID Name of the property holding the sentence's display.
     * @param idx Token index to mask.
     * @throws Error if this wrapper has no display under `sID`.
     */
    mask(sID: tp.TokenOptions, idx: number): void {
        const display = this[sID]
        if (!(display instanceof TokenDisplay)) {
            throw new Error(`No token display for sentence "${sID}"`)
        }
        display.mask(idx)
    }
}
/**
 * Pick the sentence letter that a side of an attention type refers to.
 *
 * @param side Which side is being asked about.
 * @param atype Attention type; either `"all"` or a string whose first
 *   character names the left sentence and whose second names the other side.
 * @returns `"all"` when `atype` is `"all"`; otherwise `atype[0]` for the
 *   `"left"` side and `atype[1]` for any other side.
 */
export function sideToLetter(side: tp.SideOptions, atype: tp.SentenceOptions): string {
    if (atype === "all") {
        return "all"
    }
    return side === "left" ? atype[0] : atype[1]
}