Spaces:
Sleeping
Sleeping
File size: 4,131 Bytes
490def8 26cdd43 47913ac 26cdd43 9565da9 26cdd43 fab1822 490def8 26cdd43 490def8 47913ac 490def8 26cdd43 490def8 26cdd43 47913ac 26cdd43 490def8 26cdd43 490def8 26cdd43 490def8 26cdd43 490def8 47913ac 26cdd43 47913ac 26cdd43 47913ac 26cdd43 47913ac 26cdd43 47913ac 26cdd43 47913ac 26cdd43 490def8 26cdd43 490def8 26cdd43 490def8 26cdd43 490def8 26cdd43 490def8 26cdd43 490def8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# Import necessary libraries
import gradio as gr
import sys
from huggingface_hub import ModelCard, HfApi
import requests
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from collections import defaultdict
from networkx.drawing.nx_pydot import graphviz_layout
from io import BytesIO
from PIL import Image
# Hugging Face repository ID whose family tree is drawn when the script runs.
MODEL_ID: str = "mlabonne/NeuralBeagle14-7B"
class CachedModelCard(ModelCard):
    """ModelCard whose ``load`` results are memoized per repo ID.

    The cache lives on the class, so it is shared by every caller for the
    lifetime of the process. A failed load is cached as ``None``, so callers
    must be prepared for ``load`` to return ``None``.
    """

    # Maps model_id -> loaded card, or None when loading it failed.
    _cache = {}

    @classmethod
    def load(cls, model_id: str, **kwargs) -> "ModelCard | None":
        """Return the card for *model_id*, fetching it only on first request.

        Args:
            model_id: Hugging Face repo ID, e.g. ``"org/name"``.
            **kwargs: Forwarded to ``ModelCard.load`` on the first call for
                this ``model_id`` only; the cache key is the repo ID alone.

        Returns:
            The loaded card, or ``None`` if loading raised an exception.
        """
        if model_id not in cls._cache:
            try:
                cls._cache[model_id] = super().load(model_id, **kwargs)
            except Exception:
                # Missing repo/README, network failure, bad metadata, ...:
                # remember the failure so the tree walk does not retry it, but
                # let KeyboardInterrupt / SystemExit propagate (a bare
                # `except:` would swallow them).
                cls._cache[model_id] = None
        return cls._cache[model_id]
def get_model_names_from_yaml(url, timeout=10):
    """Return the repo IDs referenced by a mergekit YAML config at *url*.

    The original implementation iterated ``response.content`` (bytes), which
    yields integers, so its ``'/' in str(item)`` filter never matched and it
    always returned an empty list. This version reads the response as text and
    collects every ``model:`` / ``- model:`` / ``base_model:`` value of the
    form ``org/name`` (optionally quoted), de-duplicated in order of first
    appearance.

    Args:
        url: Address that serves the raw YAML text (for huggingface.co this is
            a ``/resolve/<rev>/`` link; ``/blob/`` links return an HTML page).
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        A list of model IDs. It is empty when the request fails, the status is
        not 200, or no matching entries are found.
    """
    import re

    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # The config file is an optional source of lineage; treat network
        # trouble as "no parents found" rather than aborting the walk.
        return []
    if response.status_code != 200:
        return []

    # "model: org/name", "- model: org/name", "base_model: 'org/name'", ...
    pattern = re.compile(
        r"^\s*(?:-\s*)?(?:base_)?model\s*:\s*[\"']?([\w.-]+/[\w.-]+)",
        re.MULTILINE,
    )
    # dict.fromkeys de-duplicates while keeping first-seen order (a merge's
    # base_model usually also appears in its models list).
    return list(dict.fromkeys(m.group(1) for m in pattern.finditer(response.text)))
def get_license_color(model):
    """Return the node color for *model* based on its card's ``license`` field.

    Returns:
        ``'lightgreen'`` if the lower-cased license ID contains one of the
        permissive markers below, ``'lightcoral'`` for any other declared
        license, and ``'lightgray'`` when the card could not be loaded or does
        not declare a license as a non-empty string.
    """
    card = CachedModelCard.load(model)
    if card is None:
        # CachedModelCard.load caches failed loads as None.
        return 'lightgray'

    license_id = card.data.to_dict().get('license')
    if not isinstance(license_id, str) or not license_id:
        # Missing field or an unexpected shape (e.g. a list): unknown license.
        return 'lightgray'

    license_id = license_id.lower()
    # Substring match, so e.g. "bsd-3-clause" and "creativeml-openrail-m" count.
    permissive_licenses = ('mit', 'bsd', 'apache-2.0', 'openrail')
    if any(marker in license_id for marker in permissive_licenses):
        return 'lightgreen'
    return 'lightcoral'
def get_model_names(model, genealogy, found_models=None, visited_models=None):
    """Recursively walk *model*'s ancestry, recording parent -> child links.

    Parents are taken from the first source that yields any:
      1. the card's ``base_model`` metadata (a single ID or a list of IDs);
      2. card tags that look like repo IDs (contain ``/``);
      3. a mergekit config in the repo (``merge.yml``, then
         ``mergekit_config.yml``).

    Args:
        model: Hugging Face repo ID to start from.
        genealogy: Mapping of parent ID -> list of child IDs, mutated in place.
        found_models: Accumulator of models whose card was read; created when
            ``None``.
        visited_models: Models already processed, used to stop on cycles;
            created when ``None``.

    Returns:
        The ``found_models`` set.
    """
    if found_models is None:
        found_models = set()
    if visited_models is None:
        visited_models = set()
    if model in visited_models:
        return found_models
    visited_models.add(model)

    card = CachedModelCard.load(model)
    if card is None:
        # The card could not be loaded, so there is no metadata to trace.
        return found_models
    card_dict = card.data.to_dict()
    # NOTE: the original read card_dict['license'] here without using it,
    # which raised KeyError for unlicensed cards and silently dropped their
    # entire ancestry.

    # 1) Explicit base_model metadata.
    base_model = card_dict.get('base_model')
    if isinstance(base_model, str):
        model_tags = [base_model] if base_model else []
    elif isinstance(base_model, (list, tuple)):
        model_tags = [name for name in base_model if isinstance(name, str) and name]
    else:
        model_tags = []

    # 2) Tags shaped like repo IDs.
    if not model_tags:
        model_tags = [
            tag for tag in card_dict.get('tags') or []
            if isinstance(tag, str) and '/' in tag
        ]

    # 3) Mergekit configs. /resolve/ serves the raw file; /blob/ would return
    #    the HTML file viewer instead of YAML.
    for config_name in ('merge.yml', 'mergekit_config.yml'):
        if model_tags:
            break
        try:
            model_tags = list(get_model_names_from_yaml(
                f"https://huggingface.co/{model}/resolve/main/{config_name}"
            ))
        except requests.RequestException:
            # Optional lineage source; network trouble means "no parents".
            model_tags = []

    found_models.add(model)
    for model_tag in model_tags:
        if model_tag == model:
            # Some cards name themselves as base model; a self-edge would make
            # the genealogy cyclic.
            continue
        # setdefault works whether genealogy is a plain dict or a defaultdict.
        genealogy.setdefault(model_tag, []).append(model)
        get_model_names(model_tag, genealogy, found_models, visited_models)
    return found_models
def _max_width_of_tree(dag):
    """Return the size of the largest topological generation of *dag* (0 if empty)."""
    return max((len(generation) for generation in nx.topological_generations(dag)), default=0)


def create_family_tree(start_model):
    """Draw *start_model*'s ancestry graph, coloring each node by license.

    The figure is shown with ``plt.show()`` and then closed. A PIL image of
    it is returned so the function can also back an image output, e.g. in a
    Gradio interface.

    Args:
        start_model: Hugging Face repo ID whose ancestors are traced.

    Returns:
        A ``PIL.Image.Image`` with the rendered PNG of the figure.
    """
    genealogy = defaultdict(list)
    get_model_names(start_model, genealogy)

    G = nx.DiGraph()
    # Always include the start model so a model with no known parents still
    # renders as a single node instead of an empty plot.
    G.add_node(start_model)
    for parent, children in genealogy.items():
        for child in children:
            if parent != child:  # a self-loop would make the graph cyclic
                G.add_edge(parent, child)

    # Size the figure from the condensation (strongly connected components
    # collapsed), which is always a DAG. For an acyclic G it has the same
    # shape, so this matches the plain longest-path/width computation, but
    # mutually-referencing cards no longer crash dag_longest_path_length.
    layers = nx.condensation(G)
    max_depth = nx.dag_longest_path_length(layers) + 1
    max_width = _max_width_of_tree(layers) + 1
    height = max(8, 1.6 * max_depth)
    width = max(8, 6 * max_width)

    fig, ax = plt.subplots(figsize=(width, height))
    pos = graphviz_layout(G, prog="dot")
    node_colors = [get_license_color(node) for node in G.nodes()]
    # Put org and model name on separate lines to keep node labels narrow.
    labels = {node: node.replace("/", "\n") for node in G.nodes()}
    nx.draw(G, pos, ax=ax, labels=labels, with_labels=True, node_color=node_colors,
            font_size=12, node_size=8_000, edge_color='black')
    legend_elements = [
        Patch(facecolor='lightgreen', label='Permissive'),
        Patch(facecolor='lightcoral', label='Noncommercial'),
        Patch(facecolor='lightgray', label='Unknown')
    ]
    ax.legend(handles=legend_elements, loc='upper left')
    ax.set_title(f"{start_model}'s Family Tree", fontsize=20)

    buffer = BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight")
    buffer.seek(0)
    image = Image.open(buffer)
    image.load()  # decode now; Image.open is lazy and would keep reading the buffer

    plt.show()
    # Release the figure; otherwise every call leaks one in a long-running app.
    plt.close(fig)
    return image
if __name__ == "__main__":
    # Draw the default model's tree only when run as a script, so importing
    # this module (e.g. to reuse its helpers) has no network/plotting side effects.
    create_family_tree(MODEL_ID)
|