Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 1,866 Bytes
63858e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import * as d3 from 'd3'
import 'd3-array'
import * as au from '../etc/arrayUtils'
import * as tf from '@tensorflow/tfjs'
import { TypedArray } from '@tensorflow/tfjs-core/dist/types';
/**
 * One weighted connection between two positions, as emitted by `toEdges`.
 */
export interface Edge {
  /** Source index (the row of the weight matrix the edge came from). */
  i: number;
  /** Target index (the column of the weight matrix the edge points to). */
  j: number;
  /** Value (weight) carried by the connection. */
  v: number;
}
/**
 * Convert a weight matrix into the flat edge list consumed by the SVG connections.
 *
 * For every row `i` of `data`:
 * - The row total is computed with `d3.sum`, and the row is reordered with
 *   `au.sortWithIndices`. Presumably this puts the largest weights first; confirm in arrayUtils.
 * - Entries are walked in that order. An entry becomes an edge only while the running
 *   total of already-emitted weights is still below `cutoffAmt * total`.
 * - Because the check happens before the weight is added, the entry that pushes the
 *   running total past the cutoff is still emitted.
 *
 * @param data - Weight matrix; an entry in row `i`, column `k` yields an edge `{ i, j: k, v }`.
 * @param cutoffAmt - Fraction of each row's total weight to keep. Defaults to 1.
 * @returns All emitted edges, grouped row by row in ascending `i`.
 */
export function toEdges (data:number[][], cutoffAmt=1) : Edge[] {
  const edges: Edge[] = [];

  data.forEach((row, source) => {
    // Row total is taken before calling sortWithIndices, matching the established call order.
    const budget = cutoffAmt * d3.sum(row);
    const ranked: au.SortArray = au.sortWithIndices(row);
    let spent = 0;

    ranked.arr.forEach((weight, rank) => {
      // Skip entries once the running total is no longer below the budget.
      // Written as !(a < b) so a NaN total or budget skips, exactly like a plain `<` test.
      if (!(spent < budget)) {
        return;
      }
      // sortIndices maps a position in the reordered row back to the original column.
      edges.push({ i: source, j: ranked.sortIndices[rank], v: weight });
      spent += weight;
    });
  });

  return edges;
}
/**
 * Class for implementing operations on AttentionGraph implementation.
 * Closely tied to [[AttentionConnector]]
 *
 * Holds a weight matrix both as nested arrays (`data`) and as a TensorFlow.js
 * tensor (`tensData`) built once in the constructor. The tensor occupies backend
 * memory until {@link EdgeData.dispose} is called.
 */
export class EdgeData {
  /**
   * Tensor built from `data` in the constructor.
   * It is never rebuilt, so later edits to `data` are not reflected here.
   */
  readonly tensData:tf.Tensor;

  constructor (public data:number[][]){
    this.tensData = tf.tensor(data);
  }

  /**
   * Minimum of the matrix.
   *
   * @param axis - Axis to reduce along; when omitted, reduces over all elements.
   * @returns The reduced values copied to a plain typed array.
   */
  min(axis?:number):TypedArray {
    return EdgeData.readAndDispose(this.tensData.min(axis));
  }

  /**
   * Maximum of the matrix.
   *
   * @param axis - Axis to reduce along; when omitted, reduces over all elements.
   * @returns The reduced values copied to a plain typed array.
   */
  max(axis?:number):TypedArray{
    return EdgeData.readAndDispose(this.tensData.max(axis));
  }

  /**
   * Pair the per-position minima and maxima along `axis`.
   *
   * @param axis - Axis to reduce along; when omitted, yields a single `[min, max]` pair.
   * @returns An array of `[min, max]` pairs, zipped with `d3.zip`.
   */
  extent(axis?:number):number[][] {
    return d3.zip(this.min(axis), this.max(axis))
  }

  /**
   * Format the data to send to SVG chart.
   *
   * @param accumulateThresh - A float between 0 and 1, indicating the amount of weight to display. Defaults to 0.7.
   * @returns Edges produced by `toEdges` with `accumulateThresh` as the per-row cutoff fraction.
   */
  format (accumulateThresh=0.7):Edge[] {
    return toEdges(this.data, accumulateThresh);
  }

  /**
   * Release the backend memory held by `tensData`.
   *
   * After this call, `min`, `max` and `extent` will fail, because TF.js rejects
   * operations on disposed tensors. `format` still works since it reads only `data`.
   */
  dispose(): void {
    this.tensData.dispose();
  }

  /**
   * Copy a temporary tensor's values to the CPU, then free the tensor.
   * Every reduction (`min`/`max`) allocates a new tensor, and TF.js does not
   * garbage-collect tensors. Without this, each call would leak backend memory.
   */
  private static readAndDispose(t: tf.Tensor): TypedArray {
    try {
      return t.dataSync();
    } finally {
      t.dispose();
    }
  }
}