File size: 1,777 Bytes
1307964
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
const express = require('express');
const { jsonParser } = require('../express-common');

/**
 * Classification results remembered between requests, keyed by the text that was classified.
 * @type {Map<string, object>}
 */
const cacheObject = new Map();

/** Task name handed to the transformers module's `getPipeline`. */
const TASK = 'text-classification';

/** Router that carries this module's classification endpoints. */
const router = express.Router();

/**
 * POST /labels — responds with `{ labels }`, the keys of the classification
 * pipeline's `model.config.label2id` mapping. Any failure is logged and answered with 500.
 */
router.post('/labels', jsonParser, async (req, res) => {
    try {
        const { default: transformers } = await import('../transformers.mjs');
        const { model } = await transformers.getPipeline(TASK);
        const labels = Object.keys(model.config.label2id);
        return res.json({ labels });
    } catch (error) {
        console.error(error);
        return res.sendStatus(500);
    }
});

/**
 * Maximum number of entries kept in `cacheObject`. Keys are arbitrary client-supplied
 * texts, so an uncapped cache would keep growing for as long as the process runs.
 */
const CLASSIFICATION_CACHE_LIMIT = 1000;

/**
 * Stores a classification result, evicting the oldest entry once the cache is full.
 * @param {string} text Text that was classified (the cache key)
 * @param {object[]} result Classification result to remember
 */
function cacheClassification(text, result) {
    if (!cacheObject.has(text) && cacheObject.size >= CLASSIFICATION_CACHE_LIMIT) {
        // Map iterates keys in insertion order, so the first key is the oldest entry.
        const oldestKey = cacheObject.keys().next().value;
        cacheObject.delete(oldestKey);
    }
    cacheObject.set(text, result);
}

/**
 * Get classification result for a given text, serving repeated texts from the cache.
 * @param {string} text Text to classify
 * @returns {Promise<object[]>} Classification entries sorted by descending `score`
 */
async function getClassification(text) {
    if (cacheObject.has(text)) {
        return cacheObject.get(text);
    }

    const { default: transformers } = await import('../transformers.mjs');
    const pipe = await transformers.getPipeline(TASK);
    // NOTE(review): newer transformers.js releases name this option `top_k`; confirm against the bundled version.
    const result = await pipe(text, { topk: 5 });
    result.sort((a, b) => b.score - a.score);
    cacheClassification(text, result);
    return result;
}

/**
 * POST / — classifies `req.body.text` and responds with `{ classification }`.
 * Responds 400 when `text` is missing or not a string, and 500 when classification fails.
 */
router.post('/', jsonParser, async (req, res) => {
    try {
        const text = req.body?.text;

        // Reject malformed input up front rather than handing a non-string to the pipeline.
        if (typeof text !== 'string') {
            return res.sendStatus(400);
        }

        console.log('Classify input:', text);
        const result = await getClassification(text);
        console.log('Classify output:', result);

        return res.json({ classification: result });
    } catch (error) {
        console.error(error);
        return res.sendStatus(500);
    }
});

module.exports = { router };