Eric Botti committed on
Commit 72e4e68 · 1 Parent(s): ea658a2

added ollama engine

Files changed (1)
  1. src/player.py (+23, -17)
src/player.py CHANGED
@@ -9,21 +9,27 @@ from kani.engines.openai import OpenAIEngine
 
 from game_utils import log
 
-# Using TGI Inference Endpoints from Hugging Face
 # api_type = "tgi"
-api_type = "openai"
-
-if api_type == "tgi":
-    model_name = "tgi"
-    client = openai.Client(
-        base_url=os.environ['HF_ENDPOINT_URL'] + "/v1/",
-        api_key=os.environ['HF_API_TOKEN']
-    )
-else:
-    model_name = "gpt-3.5-turbo"
-    client = openai.Client()
-
-openai_engine = OpenAIEngine(model="gpt-3.5-turbo")
+# api_type = "openai"
+api_type = "ollama"
+
+match api_type:
+    case "tgi":
+        # Using TGI Inference Endpoints from Hugging Face
+        default_engine = OpenAIEngine(  # type: ignore
+            api_base=os.environ['HF_ENDPOINT_URL'] + "/v1/",
+            api_key=os.environ['HF_API_TOKEN']
+        )
+    case "openai":
+        # Using OpenAI GPT-3.5 Turbo
+        default_engine = OpenAIEngine(model="gpt-3.5-turbo")  # type: ignore
+    case "ollama":
+        # Using Ollama
+        default_engine = OpenAIEngine(
+            api_base="http://localhost:11434/v1",
+            api_key="ollama",
+            model="mistral"
+        )
 
 
 class Player:
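For context, the new "ollama" case points kani's OpenAIEngine at Ollama's OpenAI-compatible endpoint (http://localhost:11434/v1); Ollama does not validate the API key, so "ollama" is just a placeholder. A minimal smoke test for that configuration might look like the sketch below. It is not part of this commit, and it assumes a local Ollama server with the mistral model already pulled.

import asyncio

from kani import Kani
from kani.engines.openai import OpenAIEngine

# Same engine configuration as the "ollama" case in the hunk above.
engine = OpenAIEngine(
    api_base="http://localhost:11434/v1",  # Ollama's OpenAI-compatible API
    api_key="ollama",                      # placeholder; Ollama ignores the key
    model="mistral",
)

async def main():
    ai = Kani(engine)
    print(await ai.chat_round_str("Say hello in one short sentence."))
    await engine.close()

asyncio.run(main())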
 
@@ -32,7 +38,7 @@ class Player:
         self.id = id
         self.controller = controller_type
         if controller_type == "ai":
-            self.kani = LogMessagesKani(openai_engine, log_filepath=log_filepath)
+            self.kani = LogMessagesKani(default_engine, log_filepath=log_filepath)
 
         self.role = role
         self.messages = []
@@ -52,10 +58,10 @@ class Player:
         """Makes the player respond to a prompt. Returns the response."""
         if self.controller == "human":
             # We're pretending the human is an ai for logging purposes... I don't love this but it's fine for now
-            log(ChatMessage.user(prompt), self.log_filepath)
+            log(ChatMessage.user(prompt).model_dump_json(), self.log_filepath)
             print(prompt)
             output = input()
-            log(ChatMessage.ai(output), self.log_filepath)
+            log(ChatMessage.assistant(output).model_dump_json(), self.log_filepath)
 
             return output
 
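The logging change in the last hunk works because kani's ChatMessage is a pydantic model: calling .model_dump_json() before log() writes a plain JSON string rather than a message object, and ChatMessage.assistant is kani's constructor for assistant-role messages (replacing the ChatMessage.ai call used before). A rough sketch of what gets logged, with the exact fields abbreviated:

from kani import ChatMessage

msg = ChatMessage.user("hello")
print(msg.model_dump_json())
# -> a JSON string along the lines of {"role": "user", "content": "hello", ...}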