import json

import requests
from llama_cpp import Llama

from llm_config import get_llm_config
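
# The wrapper below reads the following keys from the dict returned by
# get_llm_config(). This is only a sketch of the expected shape, inferred from
# the lookups in the class; the real llm_config module may expose more options:
#
#     {
#         'llm_type': 'llama_cpp',               # or 'ollama'
#         'model_path': '/path/to/model.gguf',   # llama_cpp only; None triggers a Hub download
#         'n_ctx': 2048,
#         'n_gpu_layers': 0,
#         'n_threads': 8,
#         'base_url': 'http://localhost:11434',  # ollama only
#         'model_name': 'your_model_name',       # ollama only
#         'max_tokens': 1024,
#         'temperature': 0.7,
#         'top_p': 0.9,
#         'stop': [],
#     }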

class LLMWrapper:
    """Expose a single generate() API over either a local llama.cpp model or an
    Ollama server, selected by the 'llm_type' config key."""

    def __init__(self):
        self.llm_config = get_llm_config()
        self.llm_type = self.llm_config.get('llm_type', 'llama_cpp')
        if self.llm_type == 'llama_cpp':
            self.llm = self._initialize_llama_cpp()
        elif self.llm_type == 'ollama':
            # Ollama is reached over HTTP, so only the endpoint and model name are needed.
            self.base_url = self.llm_config.get('base_url', 'http://localhost:11434')
            self.model_name = self.llm_config.get('model_name', 'your_model_name')
        else:
            raise ValueError(f"Unsupported LLM type: {self.llm_type}")

    def _initialize_llama_cpp(self):
        if self.llm_config.get('model_path') is None:
            # No local model configured: download a default GGUF from the Hugging Face Hub
            # (requires the huggingface_hub package to be installed).
            return Llama.from_pretrained(
                repo_id="Tien203/llama.cpp",
                filename="Llama-2-7b-hf-q4_0.gguf",
            )
        else:
            # Load the GGUF file from disk with the configured context size and threading.
            return Llama(
                model_path=self.llm_config.get('model_path'),
                n_ctx=self.llm_config.get('n_ctx', 2048),
                n_gpu_layers=self.llm_config.get('n_gpu_layers', 0),
                n_threads=self.llm_config.get('n_threads', 8),
                verbose=False,
            )

    def generate(self, prompt, **kwargs):
        """Generate a completion for `prompt`; kwargs override the config defaults."""
        if self.llm_type == 'llama_cpp':
            llama_kwargs = self._prepare_llama_kwargs(kwargs)
            response = self.llm(prompt, **llama_kwargs)
            # llama_cpp returns an OpenAI-style completion dict.
            return response['choices'][0]['text'].strip()
        elif self.llm_type == 'ollama':
            return self._ollama_generate(prompt, **kwargs)
        else:
            raise ValueError(f"Unsupported LLM type: {self.llm_type}")

    def _ollama_generate(self, prompt, **kwargs):
        url = f"{self.base_url}/api/generate"
        data = {
            'model': self.model_name,
            'prompt': prompt,
            'stream': True,  # ask for the streaming, newline-delimited JSON response
            'options': {
                'temperature': kwargs.get('temperature', self.llm_config.get('temperature', 0.7)),
                'top_p': kwargs.get('top_p', self.llm_config.get('top_p', 0.9)),
                'stop': kwargs.get('stop', self.llm_config.get('stop', [])),
                'num_predict': kwargs.get('max_tokens', self.llm_config.get('max_tokens', 1024)),
            },
        }
        response = requests.post(url, json=data, stream=True)
        if response.status_code != 200:
            raise RuntimeError(f"Ollama API request failed with status {response.status_code}: {response.text}")
        # Each non-empty line is one JSON chunk; concatenate the 'response' fragments.
        text = ''.join(json.loads(line).get('response', '') for line in response.iter_lines() if line)
        return text.strip()

    def _prepare_llama_kwargs(self, kwargs):
        # Per-call kwargs take precedence over the config-file defaults.
        llama_kwargs = {
            'max_tokens': kwargs.get('max_tokens', self.llm_config.get('max_tokens', 1024)),
            'temperature': kwargs.get('temperature', self.llm_config.get('temperature', 0.7)),
            'top_p': kwargs.get('top_p', self.llm_config.get('top_p', 0.9)),
            'stop': kwargs.get('stop', self.llm_config.get('stop', [])),
            'echo': False,  # do not include the prompt in the returned text
        }
        return llama_kwargs
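

# Minimal usage sketch (assumes llm_config.get_llm_config() returns a config shaped
# like the outline near the top of this file and that the selected backend is
# available locally); the prompt string is purely illustrative:
if __name__ == "__main__":
    llm = LLMWrapper()
    print(llm.generate("Briefly explain what a context window is.", max_tokens=128))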