File size: 3,056 Bytes
0a69927
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from llama_cpp import Llama
import requests
import json
from llm_config import get_llm_config

class LLMWrapper:
    """Uniform text-generation front end over llama.cpp or an Ollama server.

    The backend is chosen by the ``llm_type`` key of the mapping returned by
    ``get_llm_config()`` (``'llama_cpp'`` by default, or ``'ollama'``).
    Sampling settings (``max_tokens``, ``temperature``, ``top_p``, ``stop``)
    are read from the same mapping and can be overridden per call via
    ``generate(prompt, **kwargs)``.
    """

    # Model downloaded from the Hugging Face Hub when no local ``model_path``
    # is configured; overridable with the ``repo_id`` / ``filename`` config keys.
    DEFAULT_REPO_ID = "Tien203/llama.cpp"
    DEFAULT_FILENAME = "Llama-2-7b-hf-q4_0.gguf"

    def __init__(self):
        """Read the config and set up the selected backend.

        Raises:
            ValueError: if ``llm_type`` is neither ``'llama_cpp'`` nor ``'ollama'``.
        """
        # Treat a missing/empty config as "use all defaults" instead of
        # failing on ``None.get``.
        self.llm_config = get_llm_config() or {}
        self.llm_type = self.llm_config.get('llm_type', 'llama_cpp')
        if self.llm_type == 'llama_cpp':
            self.llm = self._initialize_llama_cpp()
        elif self.llm_type == 'ollama':
            # Drop a trailing slash so the endpoint URL never becomes "//api/generate".
            self.base_url = self.llm_config.get('base_url', 'http://localhost:11434').rstrip('/')
            self.model_name = self.llm_config.get('model_name', 'your_model_name')
        else:
            raise ValueError(f"Unsupported LLM type: {self.llm_type}")

    def _llama_init_kwargs(self):
        """Return the ``Llama`` constructor options taken from the config.

        Shared by both loading paths so a Hub download honours the same
        context size, GPU offload, thread count and verbosity as a local file.
        """
        return {
            'n_ctx': self.llm_config.get('n_ctx', 2048),
            'n_gpu_layers': self.llm_config.get('n_gpu_layers', 0),
            'n_threads': self.llm_config.get('n_threads', 8),
            'verbose': False,
        }

    def _initialize_llama_cpp(self):
        """Load and return the llama.cpp model.

        Uses the local GGUF file at ``model_path`` when configured; otherwise
        fetches ``repo_id``/``filename`` (defaulting to the class constants)
        with ``Llama.from_pretrained``.
        """
        init_kwargs = self._llama_init_kwargs()
        model_path = self.llm_config.get('model_path')
        if model_path is None:
            # Pass the same options as the local path; previously this branch
            # fell back to llama-cpp-python's own constructor defaults and
            # ignored n_ctx / n_gpu_layers / n_threads from the config.
            return Llama.from_pretrained(
                repo_id=self.llm_config.get('repo_id', self.DEFAULT_REPO_ID),
                filename=self.llm_config.get('filename', self.DEFAULT_FILENAME),
                **init_kwargs,
            )
        return Llama(model_path=model_path, **init_kwargs)

    def generate(self, prompt, **kwargs):
        """Generate a completion for ``prompt``, stripped of surrounding whitespace.

        Args:
            prompt: Input text sent to the model.
            **kwargs: Optional per-call overrides: ``max_tokens``,
                ``temperature``, ``top_p``, ``stop``. Unset values fall back
                to the config, then to built-in defaults.

        Raises:
            ValueError: if ``llm_type`` is not a supported backend.
            RuntimeError: (Ollama) on a non-200 status or an error chunk.
        """
        if self.llm_type == 'llama_cpp':
            llama_kwargs = self._prepare_llama_kwargs(kwargs)
            response = self.llm(prompt, **llama_kwargs)
            return response['choices'][0]['text'].strip()
        elif self.llm_type == 'ollama':
            return self._ollama_generate(prompt, **kwargs)
        else:
            raise ValueError(f"Unsupported LLM type: {self.llm_type}")

    def _ollama_generate(self, prompt, **kwargs):
        """POST to Ollama's ``/api/generate`` and return the full streamed text.

        Raises:
            RuntimeError: on a non-200 status or an ``error`` chunk in the stream.
        """
        url = f"{self.base_url}/api/generate"
        data = {
            'model': self.model_name,
            'prompt': prompt,
            'options': {
                'temperature': kwargs.get('temperature', self.llm_config.get('temperature', 0.7)),
                'top_p': kwargs.get('top_p', self.llm_config.get('top_p', 0.9)),
                'stop': kwargs.get('stop', self.llm_config.get('stop', [])),
                'num_predict': kwargs.get('max_tokens', self.llm_config.get('max_tokens', 1024)),
            }
        }
        # Optional ``timeout`` config key; None (the default) keeps requests'
        # wait-indefinitely behaviour, which long generations may need.
        timeout = self.llm_config.get('timeout')
        # ``with`` releases the streamed connection even if reading the body fails.
        with requests.post(url, json=data, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                raise RuntimeError(f"Ollama API request failed with status {response.status_code}: {response.text}")
            return self._collect_ollama_text(response.iter_lines())

    @staticmethod
    def _collect_ollama_text(lines):
        """Join the ``response`` fragments of an Ollama newline-delimited JSON stream.

        Args:
            lines: Iterable of raw JSON lines (bytes or str); blank lines are skipped.

        Returns:
            The concatenated text, stripped of surrounding whitespace.

        Raises:
            RuntimeError: if a chunk carries an ``error`` field instead of text
                (previously surfaced as an opaque KeyError).
        """
        parts = []
        for line in lines:
            if not line:
                continue
            chunk = json.loads(line)
            if 'error' in chunk:
                raise RuntimeError(f"Ollama API returned an error: {chunk['error']}")
            parts.append(chunk.get('response', ''))
        return ''.join(parts).strip()

    def _prepare_llama_kwargs(self, kwargs):
        """Build the sampling kwargs for a llama.cpp completion call.

        Each value comes from ``kwargs`` if present, else from the config, else
        from a built-in default. ``echo`` is always False so the prompt is not
        repeated in the returned text.
        """
        llama_kwargs = {
            'max_tokens': kwargs.get('max_tokens', self.llm_config.get('max_tokens', 1024)),
            'temperature': kwargs.get('temperature', self.llm_config.get('temperature', 0.7)),
            'top_p': kwargs.get('top_p', self.llm_config.get('top_p', 0.9)),
            'stop': kwargs.get('stop', self.llm_config.get('stop', [])),
            'echo': False,
        }
        return llama_kwargs