Spaces:
Runtime error
Runtime error
File size: 4,594 Bytes
acc4ffe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
"""Test OpenAI API wrapper."""
from pathlib import Path
from typing import Generator
import pytest
from langchain.callbacks.base import CallbackManager
from langchain.llms.loading import load_llm
from langchain.llms.openai import OpenAI
from langchain.schema import LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_openai_call() -> None:
    """Check that a plain completion request returns a string."""
    completion = OpenAI(max_tokens=10)("Say foo:")
    assert isinstance(completion, str)
def test_openai_extra_kwargs() -> None:
    """Check how unrecognised constructor kwargs land in ``model_kwargs``."""
    # An unknown keyword (``foo``) is collected into model_kwargs, while a
    # regular field such as max_tokens stays a normal attribute.
    with_extra = OpenAI(foo=3, max_tokens=10)
    assert with_extra.max_tokens == 10
    assert with_extra.model_kwargs == {"foo": 3}

    # Unknown keywords are merged into an explicitly supplied model_kwargs.
    merged = OpenAI(foo=3, model_kwargs={"bar": 2})
    assert merged.model_kwargs == {"foo": 3, "bar": 2}

    # Supplying the same key both ways must be rejected.
    with pytest.raises(ValueError):
        OpenAI(foo=3, model_kwargs={"foo": 2})
def test_openai_stop_valid() -> None:
    """Check that a constructor-level stop matches a call-level stop."""
    prompt = "write an ordered list of five items"
    # Stop sequence configured on the model itself.
    from_constructor = OpenAI(stop="3", temperature=0)(prompt)
    # Same stop sequence supplied per call instead.
    from_call = OpenAI(temperature=0)(prompt, stop=["3"])
    # Both runs use temperature=0 and stop on "3", so the two completions
    # are expected to be identical.
    assert from_constructor == from_call
def test_openai_stop_error() -> None:
    """Check that a stop given at both construction and call time raises."""
    prompt = "write an ordered list of five items"
    model = OpenAI(stop="3", temperature=0)
    with pytest.raises(ValueError):
        model(prompt, stop=["\n"])
def test_saving_loading_llm(tmp_path: Path) -> None:
    """Round-trip an OpenAI LLM through a YAML file and compare the result."""
    original = OpenAI(max_tokens=10)
    yaml_file = tmp_path / "openai.yaml"
    original.save(file_path=yaml_file)
    # The reloaded object must compare equal to the one that was saved.
    assert load_llm(yaml_file) == original
def test_openai_streaming() -> None:
    """Test streaming tokens from OpenAI.

    ``stream`` must return a generator, every chunk it yields must carry its
    text as a string, and at least one chunk must be yielded so that the
    per-chunk check cannot pass vacuously on an empty stream.
    """
    llm = OpenAI(max_tokens=10)
    generator = llm.stream("I'm Pickle Rick")
    assert isinstance(generator, Generator)

    chunk_count = 0
    for chunk_count, token in enumerate(generator, start=1):
        assert isinstance(token["choices"][0]["text"], str)
    # An empty stream skips the loop body entirely; fail explicitly instead.
    assert chunk_count > 0
def test_openai_streaming_error() -> None:
    """Check that stream() raises for a model built with best_of=2."""
    model = OpenAI(best_of=2)
    with pytest.raises(ValueError):
        model.stream("I'm Pickle Rick")
def test_openai_streaming_best_of_error() -> None:
    """Check that streaming=True combined with best_of=2 fails validation."""
    invalid_config = {"best_of": 2, "streaming": True}
    with pytest.raises(ValueError):
        OpenAI(**invalid_config)
def test_openai_streaming_n_error() -> None:
    """Check that streaming=True combined with n=2 fails validation."""
    invalid_config = {"n": 2, "streaming": True}
    with pytest.raises(ValueError):
        OpenAI(**invalid_config)
def test_openai_streaming_multiple_prompts_error() -> None:
    """Check that a streaming model refuses to generate for several prompts."""
    with pytest.raises(ValueError):
        # Construction stays inside the context manager, as in the original,
        # so a ValueError from either step satisfies the expectation.
        streaming_llm = OpenAI(streaming=True)
        streaming_llm.generate(["I'm Pickle Rick", "I'm Pickle Rick"])
def test_openai_streaming_call() -> None:
    """Check that calling a streaming-enabled model still returns a string."""
    streaming_llm = OpenAI(max_tokens=10, streaming=True)
    assert isinstance(streaming_llm("Say foo:"), str)
def test_openai_streaming_callback() -> None:
    """Check that a streaming call reports its tokens to the callback handler."""
    handler = FakeCallbackHandler()
    llm_options = {
        "max_tokens": 10,
        "streaming": True,
        "temperature": 0,
        "callback_manager": CallbackManager([handler]),
        "verbose": True,
    }
    OpenAI(**llm_options)("Write me a sentence with 100 words.")
    # With max_tokens=10 the handler is expected to have recorded exactly
    # 10 stream events in its llm_streams counter.
    assert handler.llm_streams == 10
@pytest.mark.asyncio
async def test_openai_async_generate() -> None:
    """Check that agenerate() resolves to an LLMResult."""
    result = await OpenAI(max_tokens=10).agenerate(["Hello, how are you?"])
    assert isinstance(result, LLMResult)
@pytest.mark.asyncio
async def test_openai_async_streaming_callback() -> None:
    """Check that async streaming generation reports tokens to the handler."""
    handler = FakeCallbackHandler()
    manager = CallbackManager([handler])
    streaming_llm = OpenAI(
        max_tokens=10,
        streaming=True,
        temperature=0,
        callback_manager=manager,
        verbose=True,
    )
    prompts = ["Write me a sentence with 100 words."]
    generation = await streaming_llm.agenerate(prompts)
    # With max_tokens=10 the handler is expected to have recorded exactly
    # 10 stream events in its llm_streams counter.
    assert handler.llm_streams == 10
    assert isinstance(generation, LLMResult)
|