import spaces
import os
import sys
import subprocess

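# Install unsloth at runtime: unsloth-zoo from PyPI, then unsloth itself from
# GitHub with --no-deps so it does not pull in its own dependency pins.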
def install_packages():
    subprocess.check_call([sys.executable, "-m", "pip", "install", "unsloth-zoo"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "git+https://github.com/unslothai/unsloth.git"])

try:
    install_packages()
except Exception as e:
    print(f"Failed to install packages: {e}")

# [Fix 1] Set environment variables before the model/framework imports below
os.environ['NVIDIA_VISIBLE_DEVICES'] = ''

import warnings
import torch
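# Let TorchDynamo fall back to eager execution instead of raising on compilation errors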
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.verbose = False

# [Fix 2] Move imports to module level
from unsloth import FastVisionModel
from transformers import AutoModelForVision2Seq
from transformers import TextStreamer
import gradio as gr
from huggingface_hub import login
from PIL import Image

warnings.filterwarnings('ignore')

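# Log in with the Space's token so gated/private model weights can be downloaded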
if 'HUGGING_FACE_HUB_TOKEN' in os.environ:
    print("กำลังเข้าสู่ระบบ Hugging Face Hub...")
    login(token=os.environ['HUGGING_FACE_HUB_TOKEN'])
else:
    print("คำเตือน: ไม่พบ HUGGING_FACE_HUB_TOKEN")

# [Fix 3] Add @spaces.GPU decorator
@spaces.GPU
def model_context():
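    # Hold the model and tokenizer in closure variables: init_models() fills them,
    # and decorator() injects them into handler functions as the first two arguments.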
    _tokenizer = None
    _model = None

    def init_models():
        nonlocal _tokenizer, _model
        try:
            print("กำลังโหลด tokenizer...")
            # [แก้ 4] ลบ imports ออกจาก function
            base_model, _tokenizer = FastVisionModel.from_pretrained(
                "unsloth/Llama-3.2-11B-Vision-Instruct",
                use_gradient_checkpointing = "unsloth"
            )
            print("โหลด tokenizer สำเร็จ")

            print("กำลังโหลดโมเดล fine-tuned...")
            # [แก้ 5] ลบ import ออกจาก function
            _model = AutoModelForVision2Seq.from_pretrained(
                "Aekanun/Llama-3.2-11B-Vision-Instruct-XRay",
                load_in_4bit=True,
                torch_dtype=torch.float16
            ).to('cuda')
            FastVisionModel.for_inference(_model)
            print("โหลดโมเดลสำเร็จ!")
            return True
        except Exception as e:
            print(f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}")
            return False

    def decorator(func):
        def wrapper(*args, **kwargs):
            return func(_model, _tokenizer, *args, **kwargs)
        return wrapper

    return init_models, decorator

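# Create the loader plus the decorator that hands the loaded model/tokenizer to handlers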
init_models, model_decorator = model_context()

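# Gradio handler: the injected model/tokenizer arrive via @model_decorator,
# and @spaces.GPU(duration=30) requests a GPU for up to 30 seconds per call.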
@model_decorator
@spaces.GPU(duration=30)
def process_image(_model, _tokenizer, image):
    if image is None:
        return "กรุณาอัพโหลดรูปภาพ"
    
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

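        # Build a single-turn chat: an image placeholder followed by the text instruction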
        instruction = "You are an expert radiographer. Describe accurately what you see in this image."
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction}
            ]}
        ]

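        # Render the chat template to a prompt string, then let the processor encode
        # the image and prompt together into model-ready tensors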
        input_text = _tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        inputs = _tokenizer(
            image,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to("cuda")

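        # Stream tokens to stdout during generation; the decoded text is returned below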
        text_streamer = TextStreamer(_tokenizer, skip_prompt=True)
        outputs = _model.generate(
            **inputs,
            streamer=text_streamer,
            max_new_tokens=128,
            use_cache=True,
            temperature=1.5,
            min_p=0.1
        )
        
        return _tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        
    except Exception as e:
        return f"เกิดข้อผิดพลาด: {str(e)}"

print("กำลังเริ่มต้นแอปพลิเคชัน...")
if init_models():
    demo = gr.Interface(
        fn=process_image,
        inputs=gr.Image(type="pil"),
        outputs=gr.Textbox(),
        title="Medical Vision Analysis"
    )
    
    if __name__ == "__main__":
        demo.launch()
else:
    print("ไม่สามารถเริ่มต้นแอปพลิเคชันได้")