from vllm import LLM
from vllm.sampling_params import SamplingParams
import base64


def encode_image(image_path: str) -> str:
    """Read an image file and return its contents as a base64-encoded string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


class Pixtral:
    def __init__(self, max_model_len=4096, max_tokens=2048,
                 gpu_memory_utilization=0.65, temperature=0.35):
        self.model_name = "mistralai/Pixtral-12B-2409"

        self.sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temperature)

        # Pixtral is distributed in Mistral's own format (weights, config, and
        # tekken tokenizer), so vLLM must load all three in "mistral" mode.
        self.llm = LLM(
            model=self.model_name,
            tokenizer_mode="mistral",
            gpu_memory_utilization=gpu_memory_utilization,
            load_format="mistral",
            config_format="mistral",
            max_model_len=max_model_len,
        )

    def generate_message_from_image(self, prompt: str, image_path: str) -> str:
        """Run a multimodal chat turn: a text prompt plus one local image."""
        base64_image = encode_image(image_path)

        # vLLM's chat API takes OpenAI-style messages; the image travels
        # inline as a base64 data URL.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
                ],
            },
        ]

        outputs = self.llm.chat(messages, sampling_params=self.sampling_params)
        return outputs[0].outputs[0].text

    def generate_message(self, prompt: str) -> str:
        """Run a text-only chat turn."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                ],
            },
        ]

        outputs = self.llm.chat(messages, sampling_params=self.sampling_params)
        return outputs[0].outputs[0].text
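

# A minimal usage sketch, assuming a CUDA GPU with enough memory for the 12B
# model; "example.jpg" is a hypothetical local image path.
if __name__ == "__main__":
    pixtral = Pixtral()
    print(pixtral.generate_message("Summarize what the Pixtral model does."))
    print(pixtral.generate_message_from_image("Describe this image.", "example.jpg"))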