import spaces
import os
import sys
import subprocess

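# Install Unsloth at startup: unsloth-zoo from PyPI, then the Unsloth library itself from GitHub without its dependencies.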
def install_packages():
    subprocess.check_call([sys.executable, "-m", "pip", "install", "unsloth-zoo"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "--no-deps", "git+https://github.com/unslothai/unsloth.git"])

try:
    install_packages()
except Exception as e:
    print(f"Failed to install packages: {e}")

import warnings
import torch

from transformers import TextStreamer
import gradio as gr
from huggingface_hub import login
from PIL import Image

warnings.filterwarnings('ignore')

model = None
tokenizer = None

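# Log in to the Hugging Face Hub when a token is provided (e.g. for gated model weights).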
if 'HUGGING_FACE_HUB_TOKEN' in os.environ:
    print("Logging in to the Hugging Face Hub...")
    login(token=os.environ['HUGGING_FACE_HUB_TOKEN'])
else:
    print("Warning: HUGGING_FACE_HUB_TOKEN not found")

###@spaces.GPU
def load_model():
    global model
    print("กำลังโหลดโมเดล...")
    try:
        from transformers import AutoModelForVision2Seq
        print("กำลังโหลดโมเดล fine-tuned...")
        
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = AutoModelForVision2Seq.from_pretrained(
            "0llheaven/Llama-3.2-11B-Vision-Radiology-mini",
            load_in_4bit = True,
            device_map=device,
            torch_dtype = torch.float16
        )

        print("โหลดโมเดลสำเร็จ!")
        return True
        
    except Exception as e:
        print(f"เกิดข้อผิดพลาดในการโหลดโมเดล: {str(e)}")
        import traceback
        traceback.print_exc()
        return False

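# ZeroGPU: reserve a GPU for up to 120 seconds for each inference call.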
@spaces.GPU(duration=120)
def process_image(image):
    global model

    ### Load the tokenizer from the base model
    from unsloth import FastVisionModel

    FastVisionModel.for_inference(model)  ### try to work around a torch issue

    print("Loading tokenizer...")
    base_model, tokenizer = FastVisionModel.from_pretrained(
        "unsloth/Llama-3.2-11B-Vision-Instruct",
        use_gradient_checkpointing = "unsloth",
        ### device_map="auto"  ### add this here if needed
    )
    
    print("\nใน process_image():")
    print("Type of model:", type(model))
    print("A. Type of tokenizer:", type(tokenizer))
    if tokenizer is not None:
        print("B. Available methods:", dir(tokenizer))
    
    if image is None:
        return "กรุณาอัพโหลดรูปภาพ"
    
    try:
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

        print("0. Image info:", type(image), image.size)  # เพิ่ม debug ข้อมูลรูปภาพ
        instruction = "You are an expert radiographer. Describe accurately what you see in this image."
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": instruction}
            ]}
        ]

        print("1. Messages:", messages)  

        print("2. Tokenizer type:", type(tokenizer))
        input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        print("3. Chat template success:", input_text[:100])
        inputs = tokenizer(
            image,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to("cuda")
        print("3. Tokenizer inputs:", inputs.keys())  # Debug 3

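        # Stream generated tokens to the console while decoding the full output for the UI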
        text_streamer = TextStreamer(tokenizer, skip_prompt=True)
        outputs = model.generate(
            **inputs, 
            streamer=text_streamer,
            max_new_tokens=256,
            use_cache=True,
            temperature=1.5,
            min_p=0.1
        )
        
        return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        
    except Exception as e:
        return f"เกิดข้อผิดพลาด: {str(e)}"

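# Build the Gradio UI only if the model loaded successfully.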
if load_model():
    demo = gr.Interface(
        fn=process_image,
        inputs=gr.Image(type="pil", label="Upload Image"),
        outputs=gr.Textbox(label="Generated Caption"),
        title="Medical Vision Analysis"
    )
    
    if __name__ == "__main__":
        demo.launch()