pvanand commited on
Commit
97c889b
·
verified ·
1 Parent(s): 2605231

Update aws_aiclient.py

Browse files
Files changed (1) hide show
  1. aws_aiclient.py +12 -9
aws_aiclient.py CHANGED
@@ -1,5 +1,4 @@
1
  # aws_aiclient.py
2
-
3
  import os
4
  import time
5
  import json
@@ -38,14 +37,18 @@ text_models = {
38
 
39
  class AIClient:
40
  def __init__(self):
41
- self.client = ChatBedrockConverse(
42
- model='meta.llama3-70b-instruct-v1:0', # default model
 
 
 
 
 
 
43
  region_name="ap-south-1",
44
  aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
45
  aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY")
46
  )
47
- self.observability_manager = LLMObservabilityManager()
48
- self.models = text_models
49
 
50
  async def generate_response(
51
  self,
@@ -64,13 +67,13 @@ class AIClient:
64
  status = "success"
65
 
66
  try:
67
- # Update the client's model if different from current
68
- if model != self.client.model:
69
- self.client.model = model
 
70
 
71
  # Stream the response
72
  async for chunk in self.client.astream(messages):
73
- print(chunk)
74
  if chunk.content and chunk.content[0].get("text"):
75
  content = chunk.content[0].get("text")
76
  yield content
 
1
  # aws_aiclient.py
 
2
  import os
3
  import time
4
  import json
 
37
 
38
  class AIClient:
39
  def __init__(self):
40
+ self._model = 'meta.llama3-70b-instruct-v1:0' # default model
41
+ self.client = self._create_client()
42
+ self.observability_manager = LLMObservabilityManager()
43
+ self.models = text_models
44
+
45
+ def _create_client(self, model: str = None) -> ChatBedrockConverse:
46
+ return ChatBedrockConverse(
47
+ model=model or self._model,
48
  region_name="ap-south-1",
49
  aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
50
  aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY")
51
  )
 
 
52
 
53
  async def generate_response(
54
  self,
 
67
  status = "success"
68
 
69
  try:
70
+ # Update the client if model is different
71
+ if model != self._model:
72
+ self._model = model
73
+ self.client = self._create_client()
74
 
75
  # Stream the response
76
  async for chunk in self.client.astream(messages):
 
77
  if chunk.content and chunk.content[0].get("text"):
78
  content = chunk.content[0].get("text")
79
  yield content