// Zero-shot text classification with a DeBERTa-v3 NLI model (Unity Sentis sample).
using System;
using System.Collections.Generic;
using System.Linq;
using Unity.Sentis;
using UnityEngine;
/// <summary>
/// Zero-shot text classification using a DeBERTa-v3 NLI model via Unity Sentis.
/// Each class label is turned into a hypothesis ("This example is about {label}")
/// and the model scores how strongly the input text entails each hypothesis.
/// </summary>
public sealed class DebertaV3 : MonoBehaviour
{
    public ModelAsset model;
    // Newline-separated SentencePiece vocabulary file matching the model.
    public TextAsset vocabulary;
    // When true (or with a single class), each hypothesis is scored independently
    // (softmax over entailment vs. contradiction); otherwise the entailment scores
    // are normalized across all candidate classes so they sum to 1.
    public bool multipleTrueClasses;
    public string text = "Angela Merkel is a politician in Germany and leader of the CDU";
    // "{}" is substituted with each class label to build the hypothesis sentence.
    public string hypothesisTemplate = "This example is about {}";
    public string[] classes = { "politics", "economy", "entertainment", "environment" };

    IWorker engine;
    string[] vocabularyTokens;

    // Special token ids used by the DeBERTa tokenizer.
    const int padToken = 0;
    const int startToken = 1;
    const int separatorToken = 2;
    // Offset between a line index in the vocabulary file and the model token id.
    const int vocabToTokenOffset = 260;

    void Start()
    {
        if (classes.Length == 0)
        {
            Debug.LogError("There need to be more than 0 classes");
            return;
        }
        vocabularyTokens = vocabulary.text.Replace("\r", "").Split("\n");
        Model baseModel = ModelLoader.Load(model);
        // Append a scoring head to the base model so the worker directly outputs
        // one entailment score per example.
        Model modelWithScoring = Functional.Compile(
            input =>
            {
                // The logits represent the model's predictions for entailment and non-entailment for each example in the batch.
                // They are of shape [batch size, 2] i.e. with two values per example.
                // To obtain a single score per example, a softmax function is applied
                FunctionalTensor logits = baseModel.Forward(input)[0];
                if (multipleTrueClasses || classes.Length == 1)
                {
                    // Softmax over the entailment vs. contradiction dimension for each label independently
                    logits = Functional.Softmax(logits);
                }
                else
                {
                    // Softmax over all candidate labels
                    logits = Functional.Softmax(logits, 0);
                }
                // The scores are stored along the first column
                return new []{logits[.., 0]};
            },
            InputDef.FromModel(baseModel)
        );
        engine = WorkerFactory.CreateWorker(BackendType.GPUCompute, modelWithScoring);
        string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
        Batch batch = GetTokenizedBatch(text, hypotheses);
        float[] scores = GetBatchScores(batch);
        for (int i = 0; i < scores.Length; i++)
        {
            Debug.Log($"[{classes[i]}] Entailment Score: {scores[i]}");
        }
    }

    /// <summary>
    /// Runs the compiled model on a tokenized batch and returns one entailment
    /// score per example (one per class). Blocks until the GPU result is ready.
    /// </summary>
    float[] GetBatchScores(Batch batch)
    {
        using var inputIds = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedTokens);
        using var attentionMask = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedMasks);
        Dictionary<string, Tensor> inputs = new()
        {
            {"input_0", inputIds},
            {"input_1", attentionMask}
        };
        engine.Execute(inputs);
        TensorFloat scores = (TensorFloat)engine.PeekOutput("output_0");
        scores.CompleteOperationsAndDownload();
        return scores.ToReadOnlyArray();
    }

    /// <summary>
    /// Tokenizes the prompt and all hypotheses into a single fixed-width batch.
    /// Shorter hypotheses are right-padded so all rows share the same length.
    /// </summary>
    Batch GetTokenizedBatch(string prompt, string[] hypotheses)
    {
        Batch batch = new Batch();
        List<int> promptTokens = Tokenize(prompt);
        promptTokens.Insert(0, startToken);
        List<int>[] tokenizedHypotheses = hypotheses.Select(Tokenize).ToArray();
        int maxTokenLength = tokenizedHypotheses.Max(x => x.Count);
        // Each example in the batch follows this format:
        // Start Prompt Separator Hypothesis Separator Padding
        int[] batchedTokens = tokenizedHypotheses.SelectMany(hypothesis => promptTokens
            .Append(separatorToken)
            .Concat(hypothesis)
            .Append(separatorToken)
            .Concat(Enumerable.Repeat(padToken, maxTokenLength - hypothesis.Count)))
            .ToArray();
        // The attention masks have the same length as the tokens.
        // Each attention mask contains repeating 1s for each token, except for padding tokens.
        int[] batchedMasks = tokenizedHypotheses.SelectMany(hypothesis => Enumerable.Repeat(1, promptTokens.Count + 1)
            .Concat(Enumerable.Repeat(1, hypothesis.Count + 1))
            .Concat(Enumerable.Repeat(0, maxTokenLength - hypothesis.Count)))
            .ToArray();
        batch.BatchCount = hypotheses.Length;
        batch.BatchLength = batchedTokens.Length / hypotheses.Length;
        batch.BatchedTokens = batchedTokens;
        batch.BatchedMasks = batchedMasks;
        return batch;
    }

    /// <summary>
    /// Greedy longest-match-first SentencePiece-style tokenization.
    /// Words are split on whitespace; the first sub-word of each word is looked up
    /// with the "▁" word-start marker prepended. Unknown suffixes are skipped.
    /// </summary>
    List<int> Tokenize(string input)
    {
        string[] words = input.Split(null);
        List<int> ids = new();
        foreach (string word in words)
        {
            int start = 0;
            // Try the longest remaining prefix first, shrinking until a vocabulary hit.
            // BUGFIX: the loop previously ran i down to 0 regardless of start, so an
            // unmatched suffix after a matched prefix produced a negative-length
            // Substring (exception), or an infinite loop if the vocabulary file had
            // a trailing blank line (empty candidate matched without advancing).
            for (int i = word.Length; i >= start; i--)
            {
                // An empty candidate (i == start with start > 0) can only match a stray
                // blank vocabulary entry; stop here instead of looking it up.
                if (start > 0 && i == start) break;
                string subWord = start == 0 ? "▁" + word.Substring(start, i) : word.Substring(start, i - start);
                int index = Array.IndexOf(vocabularyTokens, subWord);
                if (index >= 0)
                {
                    ids.Add(index + vocabToTokenOffset);
                    if (i == word.Length) break;
                    // A match that consumes nothing (e.g. a lone "▁" for a non-empty
                    // word) cannot make progress — stop rather than loop forever.
                    if (i == start) break;
                    start = i;
                    i = word.Length + 1; // restart the shrinking search after the consumed prefix
                }
            }
        }
        return ids;
    }

    void OnDestroy() => engine?.Dispose();

    // Flattened, fixed-width token/mask buffers for a batch of examples.
    struct Batch
    {
        public int BatchCount;   // number of examples (one per hypothesis)
        public int BatchLength;  // tokens per example after padding
        public int[] BatchedTokens;
        public int[] BatchedMasks;
    }
}