# File size: 395 Bytes
# Commit: e75cc88
import spacy

def set_threshold(model_path, threshold, pipe_name="textcat_multilabel"):
    """Load a trained spaCy pipeline and set its text-classifier threshold.

    In spaCy v3 the text categorizers store their decision cutoff in the
    component's ``cfg["threshold"]`` entry, which the component's scorer
    reads when evaluating. Assigning a plain ``.threshold`` attribute has no
    effect on spaCy. ``doc.cats`` still contains the raw per-label scores, so
    apply the cutoff yourself when turning scores into labels at inference.

    Args:
        model_path: Path or package name accepted by ``spacy.load``.
        threshold: Cutoff in the closed interval [0, 1] above which a label
            counts as positive.
        pipe_name: Name of the text-classification component to configure.
            Defaults to ``"textcat_multilabel"`` (the original hard-coded
            value).

    Returns:
        The loaded ``nlp`` pipeline with the threshold applied.

    Raises:
        ValueError: If ``threshold`` is outside [0, 1] or is NaN.
        KeyError: If the pipeline has no component named ``pipe_name``
            (raised by ``Language.get_pipe``).
    """
    # Validate before loading: spacy.load can be slow, so fail fast on bad
    # input. The chained comparison is False for NaN, so NaN is rejected too.
    if not 0.0 <= threshold <= 1.0:
        raise ValueError(
            f"threshold must be within [0, 1], got {threshold!r}"
        )

    # Load the trained model
    nlp = spacy.load(model_path)
    pipe = nlp.get_pipe(pipe_name)

    # This is the value spaCy actually consults for the threshold.
    pipe.cfg["threshold"] = threshold
    # Mirrored for backward compatibility with callers that read
    # `pipe.threshold` after this function returns.
    pipe.threshold = threshold

    return nlp

# Example usage: load the pipeline saved under ./my_trained_model and give its
# text classifier a 0.21 cutoff via set_threshold.
if __name__ == "__main__":
    nlp = set_threshold(model_path="./my_trained_model", threshold=0.21)