Luigi committed on
Commit
ef8c7e3
·
1 Parent(s): b1d1505

performance improvement: initialize llm agent only once after startup

Browse files
classifier/classifier.py CHANGED
@@ -1,69 +1,79 @@
1
  import logging
2
- from fastapi import FastAPI
3
- from pydantic import BaseModel
4
- from typing import List
5
- import os
6
- from string import Formatter
7
-
8
  import os
 
 
 
9
 
10
  import outlines
11
  from outlines.models import openai
12
  from outlines.generate import choice
13
 
14
  # Configure logger
15
- tools = logging.getLogger("classifier")
16
- tools.setLevel(logging.DEBUG)
17
- ch = logging.StreamHandler()
18
- ch.setLevel(logging.DEBUG)
19
- formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s")
20
- ch.setFormatter(formatter)
21
- tools.addHandler(ch)
22
-
23
- # Configure logger
24
- logging.basicConfig(
25
- format="%(asctime)s %(levelname)s:%(name)s: %(message)s",
26
- level=logging.DEBUG,
27
- )
28
  logger = logging.getLogger("classifier")
29
 
30
  app = FastAPI()
31
 
32
- # Pydantic model for incoming requests; prompt_template added
33
- class Req(BaseModel):
34
- message: str
 
 
35
  model_name: str
36
  base_url: str
37
  class_set: List[str]
38
- prompt_template: str # template with {message} placeholder
 
 
 
39
 
40
  class Resp(BaseModel):
41
  result: str
42
 
43
- # Helper for safe formatting of {message} only
44
- class SafeFormatDict(dict):
45
- def __missing__(self, key):
46
- return '{' + key + '}'
47
 
48
- @app.post("/classify", response_model=Resp)
49
- def classify(req: Req):
50
- logger.debug(f"Received request args: {req.dict()}")
51
-
52
- prompt = req.prompt_template.replace("{message}", req.message)
53
- logger.debug(f"Rendered prompt: {prompt!r}")
54
 
55
  api_key = os.getenv("TOGETHERAI_API_KEY")
56
  logger.debug(f"Using API_KEY: {'set' if api_key else 'missing'}")
57
- llm = openai(req.model_name, api_key=api_key, base_url=req.base_url)
58
- clf = choice(llm, req.class_set)
59
- logger.debug(f"Choice classifier created with labels: {req.class_set}")
60
 
61
  try:
62
- result = clf(prompt)
63
- # If it's a coroutine, run it; otherwise use result
64
- logger.debug(f"Classifier returned: {result}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  except Exception as e:
66
- result = req.class_set[-1]
67
- logger.error(f"Classification error: {e}. Falling back to: {result}")
68
 
69
  return Resp(result=result)
 
1
  import logging
 
 
 
 
 
 
2
  import os
3
+ from fastapi import FastAPI, HTTPException
4
+ from pydantic import BaseModel
5
+ from typing import List, Optional
6
 
7
  import outlines
8
  from outlines.models import openai
9
  from outlines.generate import choice
10
 
11
  # Configure logger
12
+ logging.basicConfig(level=logging.DEBUG)
 
 
 
 
 
 
 
 
 
 
 
 
13
  logger = logging.getLogger("classifier")
14
 
15
  app = FastAPI()
16
 
17
+ # Global variables for shared config and classifier
18
+ clf = None
19
+ config_set = False
20
+
21
+ class Config(BaseModel):
22
  model_name: str
23
  base_url: str
24
  class_set: List[str]
25
+ prompt_template: str
26
+
27
+ class Req(BaseModel):
28
+ message: str
29
 
30
  class Resp(BaseModel):
31
  result: str
32
 
33
+ @app.post("/config")
34
+ def configure(req: Config):
35
+ """Receive and initialize classifier configuration."""
36
+ global clf, config_set
37
 
38
+ if config_set:
39
+ logger.warning("Classifier already configured. Ignoring new config.")
40
+ return {"status": "already_configured"}
 
 
 
41
 
42
  api_key = os.getenv("TOGETHERAI_API_KEY")
43
  logger.debug(f"Using API_KEY: {'set' if api_key else 'missing'}")
 
 
 
44
 
45
  try:
46
+ llm = openai(req.model_name, api_key=api_key, base_url=req.base_url)
47
+ clf = choice(llm, req.class_set)
48
+ clf.class_set = req.class_set
49
+ clf.prompt_template = req.prompt_template
50
+ config_set = True
51
+ logger.info("Classifier configured successfully.")
52
+ return {"status": "configured"}
53
+ except Exception as e:
54
+ logger.error(f"Failed to configure classifier: {e}")
55
+ raise HTTPException(status_code=500, detail="Classifier configuration failed")
56
+
57
+ @app.post("/classify", response_model=Resp)
58
+ def classify(req: Req):
59
+ global clf
60
+ if clf is None or not config_set:
61
+ raise HTTPException(status_code=503, detail="Classifier not configured yet")
62
+
63
+ # Render the prompt using the template
64
+ try:
65
+ prompt = clf.prompt_template.replace("{message}", req.message)
66
+ logger.debug(f"Rendered prompt: {prompt!r}")
67
+ except Exception as e:
68
+ logger.warning(f"Prompt rendering failed: {e}")
69
+ prompt = req.message
70
+
71
+ # Run classifier
72
+ try:
73
+ result = clf(prompt)
74
+ logger.debug(f"Classification result: {result}")
75
  except Exception as e:
76
+ logger.error(f"Classification error: {e}. Falling back to: {clf.class_set[-1]}")
77
+ result = clf.class_set[-1]
78
 
79
  return Resp(result=result)
custom_components/llm_intent_classifier_client.py CHANGED
@@ -30,6 +30,7 @@ class LlmIntentClassifier(IntentClassifier):
30
  component_config: Optional[Dict[Text, Any]] = None,
31
  ) -> None:
32
  super().__init__(component_config or {})
 
33
  self.url: str = self.component_config.get("classifier_url")
34
  self.timeout: float = float(self.component_config.get("timeout"))
35
  self.model_name: Optional[Text] = self.component_config.get("model_name")
@@ -52,6 +53,26 @@ class LlmIntentClassifier(IntentClassifier):
52
  f"Missing configuration for {', '.join(missing)} in LlmIntentClassifier"
53
  )
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def train(
56
  self,
57
  training_data: TrainingData,
@@ -67,13 +88,7 @@ class LlmIntentClassifier(IntentClassifier):
67
  confidence: float = 0.0
68
 
69
  if text:
70
- payload: Dict[str, Any] = {
71
- "message": text,
72
- "model_name": self.model_name,
73
- "base_url": self.base_url,
74
- "class_set": self.class_set,
75
- "prompt_template": self.prompt_template,
76
- }
77
  try:
78
  resp = requests.post(self.url, json=payload, timeout=self.timeout)
79
  resp.raise_for_status()
 
30
  component_config: Optional[Dict[Text, Any]] = None,
31
  ) -> None:
32
  super().__init__(component_config or {})
33
+
34
  self.url: str = self.component_config.get("classifier_url")
35
  self.timeout: float = float(self.component_config.get("timeout"))
36
  self.model_name: Optional[Text] = self.component_config.get("model_name")
 
53
  f"Missing configuration for {', '.join(missing)} in LlmIntentClassifier"
54
  )
55
 
56
+ # Push config to classifier backend
57
+ self._configure_remote_classifier()
58
+
59
+ def _configure_remote_classifier(self) -> None:
60
+ """Send configuration to the classifier backend to initialize the model."""
61
+ payload = {
62
+ "model_name": self.model_name,
63
+ "base_url": self.base_url,
64
+ "class_set": self.class_set,
65
+ "prompt_template": self.prompt_template,
66
+ }
67
+ try:
68
+ config_url = self.url.replace("/classify", "/config")
69
+ logger.debug(f"Sending classifier config to: {config_url}")
70
+ response = requests.post(config_url, json=payload, timeout=self.timeout)
71
+ response.raise_for_status()
72
+ logger.info("Remote classifier initialized successfully.")
73
+ except Exception as e:
74
+ logger.warning(f"Failed to initialize remote classifier: {e}")
75
+
76
  def train(
77
  self,
78
  training_data: TrainingData,
 
88
  confidence: float = 0.0
89
 
90
  if text:
91
+ payload: Dict[str, Any] = {"message": text}
 
 
 
 
 
 
92
  try:
93
  resp = requests.post(self.url, json=payload, timeout=self.timeout)
94
  resp.raise_for_status()