File size: 1,802 Bytes
cbb7bf7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56


def test_locally(load_config, setup_logger, LLMApi):
    """Run local tests for development and debugging.

    Loads the configuration and builds an API object. It then downloads and
    initializes the model named in ``config["model"]["defaults"]["model_name"]``
    and exercises, in order:
    embedding generation, regular (blocking) generation and streaming
    generation, logging each step.

    Args:
        load_config: Zero-argument callable returning the configuration
            mapping.
        setup_logger: Callable invoked as ``setup_logger(config, "test")``;
            must return a logger object providing ``info`` and ``warning``.
        LLMApi: Class (or factory) called with the configuration. The
            resulting object must provide ``download_model``,
            ``initialize_model``, ``generate_embedding``,
            ``generate_response`` and ``generate_stream``.
    """
    config = load_config()
    logger = setup_logger(config, "test")
    logger.info("Starting local tests")

    api = LLMApi(config)
    model_name = config["model"]["defaults"]["model_name"]

    logger.info(f"Testing with model: {model_name}")

    # Test download
    logger.info("Testing model download...")
    api.download_model(model_name)
    logger.info("Download complete")

    # Test initialization
    logger.info("Initializing model...")
    api.initialize_model(model_name)
    logger.info("Model initialized")

    # Test embedding. The sample text is Norwegian ("This is a test of
    # embedding generation from a technical text about HSE routines in the
    # workplace.") and is kept as-is because it is the actual test input.
    test_text = "Dette er en test av embeddings generering fra en teknisk tekst om HMS rutiner på arbeidsplassen."
    logger.info("Testing embedding generation...")
    embedding = api.generate_embedding(test_text)
    logger.info(f"Generated embedding of length: {len(embedding)}")
    logger.info(f"First few values: {embedding[:5]}")

    # Prompts shared by both generation modes.
    test_prompts = [
        "Tell me what happens in a nuclear reactor.",
    ]
    # One system message for both modes, so their outputs are comparable.
    system_message = "You are a helpful assistant."

    # Test regular generation
    logger.info("Testing regular generation:")
    for prompt in test_prompts:
        logger.info(f"Prompt: {prompt}")
        response = api.generate_response(
            prompt=prompt,
            system_message=system_message
        )
        logger.info(f"Response: {response}")

    # Test streaming generation. Cover every prompt, like the regular
    # generation loop above, instead of only the first one.
    logger.info("Testing streaming generation:")
    for prompt in test_prompts:
        logger.info(f"Prompt: {prompt}")
        chunks = []
        for chunk in api.generate_stream(
                prompt=prompt,
                system_message=system_message
        ):
            # Echo each chunk as it arrives so streaming is visible live.
            print(chunk, end="", flush=True)
            chunks.append(chunk)
        print("\n")

        if not chunks:
            # An empty stream would otherwise pass silently.
            logger.warning("Streaming generation produced no output")
        # Record the assembled stream in the log, as the regular path does.
        # str() mirrors what print() displayed for each chunk.
        streamed = "".join(map(str, chunks))
        logger.info(f"Streamed response: {streamed}")

    logger.info("Local tests completed")