File size: 3,919 Bytes
92d7f1e
 
1f11e25
2ddd1e5
62e947b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92d7f1e
 
62e947b
 
 
92d7f1e
 
 
 
 
2ddd1e5
62e947b
92d7f1e
 
 
2ddd1e5
 
 
92d7f1e
 
 
 
 
8b5c657
92d7f1e
2ddd1e5
 
 
 
8b5c657
2ddd1e5
 
 
8b5c657
2ddd1e5
 
 
 
 
92d7f1e
 
 
 
8b5c657
92d7f1e
cd81b99
 
 
37a110b
 
 
 
 
 
62e947b
 
37a110b
 
 
 
 
 
 
 
 
92d7f1e
 
 
 
1f11e25
 
 
 
 
 
 
 
 
 
 
 
 
8b5c657
1f11e25
 
 
8b5c657
1f11e25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import difflib
import html
import json
import os

import requests
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Base URL of the remote cache used for model structures (presumably a
# Firebase Realtime Database REST endpoint, given the "<path>.json" URLs the
# helpers below build — TODO confirm). Read from the environment; os.getenv
# yields None when the FIREBASE_URL variable is unset.
FIREBASE_URL = os.getenv("FIREBASE_URL")

def fetch_from_firebase(model_id, *, timeout=10):
    """Fetch a cached model structure from Firebase.

    Args:
        model_id: Path under ``model_structures/`` to read.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The decoded JSON payload, or None when caching is not configured,
        the request fails or returns a non-200 status, or the body is not
        valid JSON. A missing entry presumably comes back as JSON ``null``,
        which also decodes to None. Returning None lets the caller fall back
        to computing the structure locally instead of crashing the app.
    """
    # Without a configured URL the request would target "None/...".
    if not FIREBASE_URL:
        return None
    try:
        response = requests.get(
            f"{FIREBASE_URL}/model_structures/{model_id}.json",
            timeout=timeout,
        )
        if response.status_code != 200:
            return None
        return response.json()
    except (requests.RequestException, ValueError):
        # The cache is optional: network errors and malformed JSON are
        # treated as a cache miss.
        return None

def save_to_firebase(model_id, structure, *, timeout=10):
    """Write *structure* to Firebase under ``model_structures/<model_id>``.

    Caching is best effort: this never raises for network problems.

    Args:
        model_id: Path under ``model_structures/`` to write.
        structure: JSON-serialisable payload to store.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        True if the server answered HTTP 200, False if caching is not
        configured, the request failed, or the server rejected the write.
    """
    if not FIREBASE_URL:
        return False
    try:
        response = requests.put(
            f"{FIREBASE_URL}/model_structures/{model_id}.json",
            data=json.dumps(structure),
            timeout=timeout,
        )
    except requests.RequestException:
        return False
    return response.status_code == 200

def _firebase_key(model_id):
    """Return *model_id* in a form usable as a Firebase database path.

    Firebase keys may not contain '.', which is common in Hub repo names
    (e.g. "Qwen/Qwen2.5-0.5B"); '/' is kept so the ID maps to nested nodes.
    ',' is used as the replacement because Hub repo IDs are not expected to
    contain it, which keeps the mapping collision-free.
    """
    return model_id.replace(".", ",")


def _encode_structure(structure):
    """Convert a ``{param_name: shape}`` dict into a Firebase-safe payload.

    Parameter names such as "model.embed_tokens.weight" contain '.', which
    Firebase rejects in keys, so the mapping is stored as a list of
    ``[name, shape]`` pairs. A list also preserves parameter order, whereas
    object keys come back sorted.
    """
    return [[name, shape] for name, shape in structure.items()]


def _decode_structure(payload):
    """Rebuild a ``{param_name: shape}`` dict from a cached payload.

    Accepts the list-of-pairs format written by ``_encode_structure`` and,
    for older cache entries, a plain mapping. Returns None for anything
    unrecognised so the caller recomputes the structure.
    """
    if isinstance(payload, dict):
        return payload
    if isinstance(payload, list):
        try:
            return {name: shape for name, shape in payload}
        except (TypeError, ValueError):
            return None
    return None


def get_model_structure(model_id):
    """Return a ``{param_name: shape_string}`` mapping for a Hub model.

    The result is read from the Firebase cache when available. Otherwise the
    model is loaded on CPU in bfloat16, the shape of every ``state_dict``
    entry is recorded as ``str(tensor.shape)``, and the result is written
    back to the cache (best effort; the save result is ignored).

    Args:
        model_id: Hugging Face Hub model ID passed to ``from_pretrained``.

    Returns:
        Dict mapping parameter names to shape strings, in state_dict order.
    """
    cache_key = _firebase_key(model_id)
    structure = _decode_structure(fetch_from_firebase(cache_key))
    if structure:
        return structure
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    structure = {k: str(v.shape) for k, v in model.state_dict().items()}
    # Only the shape strings are needed; drop the weights before the
    # (possibly slow) cache write.
    del model
    save_to_firebase(cache_key, _encode_structure(structure))
    return structure

def compare_structures(struct1, struct2):
    """Diff two ``{param_name: shape}`` mappings line by line.

    Each entry is rendered as ``"name: shape"`` in the mapping's iteration
    order, and the two resulting line lists are compared with
    ``difflib.ndiff``.

    Returns:
        The lazy generator produced by ``difflib.ndiff``.
    """
    def _render(structure):
        return [f"{name}: {shape}" for name, shape in structure.items()]

    left, right = _render(struct1), _render(struct2)
    return difflib.ndiff(left, right)

def display_diff(diff):
    """Render ``difflib.ndiff`` output as two side-by-side HTML fragments.

    Args:
        diff: Iterable of ndiff lines, each starting with a two-character
            tag: ``"- "`` (only in the first input), ``"+ "`` (only in the
            second), ``"  "`` (common to both) or ``"? "`` (intraline hints).

    Returns:
        ``(left_html, right_html, diff_found)``: each side's lines joined
        with ``<br>``, with removed/added lines highlighted and a blank
        placeholder on the opposite side, plus whether any ``-``/``+`` line
        was seen.
    """
    left_lines = []
    right_lines = []
    diff_found = False

    for line in diff:
        tag = line[:2]
        # The fragments are meant to be rendered as raw HTML, so the text
        # itself must be escaped to display literally.
        text = html.escape(line[2:])
        if tag == "- ":
            left_lines.append(f'<span style="background-color: #ffdddd;">{text}</span>')
            right_lines.append("")
            diff_found = True
        elif tag == "+ ":
            right_lines.append(f'<span style="background-color: #ddffdd;">{text}</span>')
            left_lines.append("")
            diff_found = True
        elif tag == "  ":
            left_lines.append(text)
            right_lines.append(text)
        # "? " hint lines only mark intraline changes; they are not shown.

    return "<br>".join(left_lines), "<br>".join(right_lines), diff_found

# Set Streamlit page configuration to wide mode
st.set_page_config(layout="wide")

# Apply custom CSS for wider layout
st.markdown(
    """
    <style>
    .reportview-container .main .block-container {
        max-width: 100%;
        padding-left: 10%;
        padding-right: 10%;
    }
    .stMarkdown {
        white-space: pre-wrap;
    }
    </style>
    """,
    unsafe_allow_html=True
)

st.title("Model Structure Comparison Tool")
# Strip surrounding whitespace so IDs pasted with stray spaces still resolve.
model_id1 = st.text_input("Enter the first HuggingFace Model ID").strip()
model_id2 = st.text_input("Enter the second HuggingFace Model ID").strip()

if st.button("Compare Models"):
    if not (model_id1 and model_id2):
        st.error("Please enter both model IDs.")
        st.stop()

    # Loading an uncached model is slow and can fail (e.g. unknown or gated
    # repositories): show progress and report failures in the UI instead of
    # letting the script die with a raw traceback.
    try:
        with st.spinner("Loading model structures..."):
            struct1 = get_model_structure(model_id1)
            struct2 = get_model_structure(model_id2)
    except Exception as e:
        st.error(f"Error loading model structures: {e}")
        st.stop()

    diff = compare_structures(struct1, struct2)
    left_html, right_html, diff_found = display_diff(diff)

    st.write("### Comparison Result")
    if not diff_found:
        st.success("The model structures are identical.")

    # Two equal-width columns, one per model.
    col1, col2 = st.columns(2)

    with col1:
        st.write("### Model 1")
        st.markdown(left_html, unsafe_allow_html=True)

    with col2:
        st.write("### Model 2")
        st.markdown(right_html, unsafe_allow_html=True)

    # Tokenizer verification: report each model's tokenizer vocabulary size.
    with st.spinner('Loading tokenizers...'):
        try:
            tokenizer1 = AutoTokenizer.from_pretrained(model_id1)
            tokenizer2 = AutoTokenizer.from_pretrained(model_id2)
            st.write(f"**{model_id1} Tokenizer Vocab Size**: {tokenizer1.vocab_size}")
            st.write(f"**{model_id2} Tokenizer Vocab Size**: {tokenizer2.vocab_size}")
        except Exception as e:
            st.error(f"Error loading tokenizers: {e}")