MekkCyber committed on
Commit
47d6fc0
·
1 Parent(s): d619e33

add great readme

Browse files
Files changed (1) hide show
  1. app.py +56 -8
app.py CHANGED
@@ -1,15 +1,13 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModel, BitsAndBytesConfig
4
  import tempfile
5
  from huggingface_hub import HfApi
6
  from huggingface_hub import list_models
7
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
8
  from bitsandbytes.nn import Linear4bit
9
- from packaging import version
10
  import os
11
- from tqdm import tqdm
12
-
13
 
14
  def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
15
  # ^ expect a gr.OAuthProfile object as input to get the user's profile
@@ -42,11 +40,52 @@ def check_model_exists(
42
  def create_model_card(
43
  model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4
44
  ):
45
- model_card = f"""---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  base_model:
47
- - {model_name}
48
- ---
49
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # {model_name} (Quantized)
51
 
52
  ## Description
@@ -63,6 +102,13 @@ It's quantized using the BitsAndBytes library to 4-bit using the [bnb-my-repo](h
63
 
64
  """
65
 
 
 
 
 
 
 
 
66
  return model_card
67
 
68
 
@@ -138,6 +184,8 @@ def save_model(
138
 
139
  with tempfile.TemporaryDirectory() as tmpdirname:
140
  # Save model
 
 
141
  model.save_pretrained(
142
  tmpdirname, safe_serialization=True, use_auth_token=auth_token.token
143
  )
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModel, BitsAndBytesConfig, AutoTokenizer
4
  import tempfile
5
  from huggingface_hub import HfApi
6
  from huggingface_hub import list_models
7
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
8
  from bitsandbytes.nn import Linear4bit
 
9
  import os
10
+ from huggingface_hub import snapshot_download
 
11
 
12
  def hello(profile: gr.OAuthProfile | None, oauth_token: gr.OAuthToken | None) -> str:
13
  # ^ expect a gr.OAuthProfile object as input to get the user's profile
 
40
  def create_model_card(
41
  model_name, quant_type_4, double_quant_4, compute_type_4, quant_storage_4
42
  ):
43
+ # Try to download the original README
44
+ original_readme = ""
45
+ original_yaml_header = ""
46
+ try:
47
+ # Download the README.md file from the original model
48
+ model_path = snapshot_download(repo_id=model_name, allow_patterns=["README.md"], repo_type="model")
49
+ readme_path = os.path.join(model_path, "README.md")
50
+
51
+ if os.path.exists(readme_path):
52
+ with open(readme_path, 'r', encoding='utf-8') as f:
53
+ content = f.read()
54
+
55
+ if content.startswith('---'):
56
+ parts = content.split('---', 2)
57
+ if len(parts) >= 3:
58
+ original_yaml_header = parts[1]
59
+ original_readme = '---'.join(parts[2:])
60
+ else:
61
+ original_readme = content
62
+ else:
63
+ original_readme = content
64
+ except Exception as e:
65
+ print(f"Error reading original README: {str(e)}")
66
+ original_readme = ""
67
+
68
+ # Create new YAML header with base_model field
69
+ yaml_header = f"""---
70
  base_model:
71
+ - {model_name}"""
72
+
73
+ # Add any original YAML fields except base_model
74
+ if original_yaml_header:
75
+ skip_next_line = False
76
+ for line in original_yaml_header.strip().split('\n'):
77
+ if skip_next_line:
78
+ skip_next_line = False
79
+ continue
80
+ if line.strip().startswith('base_model:'):
81
+ skip_next_line = True
82
+ continue
83
+ yaml_header += f"\n{line}"
84
+ # Complete the YAML header
85
+ yaml_header += "\n---"
86
+
87
+ # Create the quantization info section
88
+ quant_info = f"""
89
  # {model_name} (Quantized)
90
 
91
  ## Description
 
102
 
103
  """
104
 
105
+ # Combine everything
106
+ model_card = yaml_header + quant_info
107
+
108
+ # Append original README content if available
109
+ if original_readme and not original_readme.isspace():
110
+ model_card += "\n\n# Original Model Information\n" + original_readme
111
+
112
  return model_card
113
 
114
 
 
184
 
185
  with tempfile.TemporaryDirectory() as tmpdirname:
186
  # Save model
187
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=auth_token.token)
188
+ tokenizer.save_pretrained(tmpdirname, safe_serialization=True, use_auth_token=auth_token.token)
189
  model.save_pretrained(
190
  tmpdirname, safe_serialization=True, use_auth_token=auth_token.token
191
  )