diff --git a/.gitattributes b/.gitattributes index 3e15efd719e55ebabe8454de9919507fc0c87be9..02238c50d3123b80f9c8c84856abb7607090a7ad 100644 --- a/.gitattributes +++ b/.gitattributes @@ -29,8 +29,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.tgz filter=lfs diff=lfs merge=lfs -text *.wasm filter=lfs diff=lfs merge=lfs -text *.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text -*zip filter=lfs diff=lfs merge=lfs -text -SuperResolutionAnimeDiffusion.zip filter=lfs diff=lfs merge=lfs -text -random_examples.zip filter=lfs diff=lfs merge=lfs -text +scenery.png filter=lfs diff=lfs merge=lfs -text +1boy.png filter=lfs diff=lfs merge=lfs -text +1girl.png filter=lfs diff=lfs merge=lfs -text +*.pk filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index 98c2d2047b6d29e3897102709df995a10b1bea8b..5768c3bc4732548ac508fda210374013b2b9ffbb 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ integrated_datasets/ *.state_dict *.config *.args +*.zip *.gz *.bin *.result.txt diff --git a/.gradio/certificate.pem b/.gradio/certificate.pem new file mode 100644 index 0000000000000000000000000000000000000000..b85c8037f6b60976b2546fdbae88312c5246d9a3 --- /dev/null +++ b/.gradio/certificate.pem @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw +TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh +cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 +WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu +ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY +MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc +h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ +0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U +A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW 
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH +B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC +B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv +KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn +OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn +jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw +qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI +rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV +HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq +hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL +ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ +3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK +NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 +ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur +TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC +jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc +oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq +4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA +mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d +emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= +-----END CERTIFICATE----- diff --git a/1boy.png b/1boy.png new file mode 100644 index 0000000000000000000000000000000000000000..7c78d946d6f60d7c44f8920d7a04ee8ccb4c608d --- /dev/null +++ b/1boy.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a6489981ac72f7eff042054fbd8acaa718e20bbf26913250f0c325e4d1ca9230 +size 1851307 diff --git a/1girl.png b/1girl.png new file mode 100644 index 0000000000000000000000000000000000000000..dd75666ef466d2d4f4f926e65a5a29fa8caf30ef --- /dev/null +++ b/1girl.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7d13eec13f7f7a98c225c9f2340461ad1bbfbc6ef7b44ecb96eb0ca73d2723d +size 2105804 
diff --git a/README.md b/README.md index bf81b0880c11e11ef954585437da0d1f20055482..4525736f44608983cd0c5b96efc7f34214d369b7 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,4 @@ ---- -title: Anything V3.0 -emoji: π -colorFrom: gray -colorTo: yellow -sdk: gradio -sdk_version: 3.10.1 -app_file: app.py -pinned: false ---- - -# If you have a GPU, try the [Stable Diffusion WebUI](https://github.com/yangheng95/stable-diffusion-webui) +# Super Resolution Anime Diffusion # [Online Web Demo](https://huggingface.co/spaces/yangheng/Super-Resolution-Anime-Diffusion) diff --git a/SuperResolutionAnimeDiffusion.zip b/SuperResolutionAnimeDiffusion.zip deleted file mode 100644 index c03f1b483cbe045b1f5b75730fe25a4f8ad308db..0000000000000000000000000000000000000000 --- a/SuperResolutionAnimeDiffusion.zip +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ef0a17ab5723e783a679cab83a05411ab09bb04886a629a5181bb18df0175c76 -size 256633269 diff --git a/Waifu2x/model_check_points/CRAN_V2/CARN_adam_checkpoint.pt b/Waifu2x/model_check_points/CRAN_V2/CARN_adam_checkpoint.pt new file mode 100644 index 0000000000000000000000000000000000000000..4a33a519ac642d453d1fac50c8226dd720fae2e5 --- /dev/null +++ b/Waifu2x/model_check_points/CRAN_V2/CARN_adam_checkpoint.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:292f2be9ea173861e4a7f6cf580f04fe9a1fc6c78fdac6f182cbc051ea50791e +size 31734614 diff --git a/Waifu2x/model_check_points/CRAN_V2/CARN_scheduler_last_iter.pt b/Waifu2x/model_check_points/CRAN_V2/CARN_scheduler_last_iter.pt new file mode 100644 index 0000000000000000000000000000000000000000..986b64ceae3ba605c7c92083f3f7eeb0fad3de24 --- /dev/null +++ b/Waifu2x/model_check_points/CRAN_V2/CARN_scheduler_last_iter.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ba2302e523d32bfeb9b542a9dc6aa5ecdb45babc793892153245d6c69ae23433 +size 151 diff --git a/Waifu2x/model_check_points/CRAN_V2/CRAN_V2_02_28_2019.pt 
b/Waifu2x/model_check_points/CRAN_V2/CRAN_V2_02_28_2019.pt new file mode 100644 index 0000000000000000000000000000000000000000..c169a54565bed42df21002ada73ad28683b80b53 --- /dev/null +++ b/Waifu2x/model_check_points/CRAN_V2/CRAN_V2_02_28_2019.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b74e163d829f6f587e3fdb0b645342e494416accb1962cf0973354de5ec157ea +size 49895595 diff --git a/Waifu2x/model_check_points/CRAN_V2/ReadME.md b/Waifu2x/model_check_points/CRAN_V2/ReadME.md index e432a1c7bd5de45f086ba1fd6b06ca712a1d806b..e4d8946f40420f35c109506dfc14f67bfb1f3eab 100644 --- a/Waifu2x/model_check_points/CRAN_V2/ReadME.md +++ b/Waifu2x/model_check_points/CRAN_V2/ReadME.md @@ -1,34 +1,41 @@ -# Resume & Use Model Check Points +# Model Specifications -This folder contains check points for models and their weights. They are generated from [PyTorch's pickle](https://pytorch.org/docs/master/notes/serialization.html). -Model specifications are in each folder's ReadME. - -Pickle names with "model" contain the entire models, and they can be used as an freeze module by calling the "forward_checkpoint" function to generate images. - -Example: ```python -import torch -# No need to reconstruct the model -model = torch.load("./DCSCN/DCSCN_model_387epos_L12_noise_1.pt") -x = torch.randn((1,3,10,10)), torch.randn((1,3,20,20)) -out = model.forward_checkpoint(a) -``` +model_cran_v2 = CARN_V2(color_channels=3, mid_channels=64, conv=nn.Conv2d, + single_conv_size=3, single_conv_group=1, + scale=2, activation=nn.LeakyReLU(0.1), + SEBlock=True, repeat_blocks=3, atrous=(1, 1, 1)) + +model_cran_v2 = network_to_half(model_cran_v2) +checkpoint = "CARN_model_checkpoint.pt" +model_cran_v2.load_state_dict(torch.load(checkpoint, 'cpu')) +model_cran_v2 = model_cran_v2.float() # if use cpu -Pickle names with "weights" are model weights, and they are named dictionaries. +```` -Example: -```python -model = DCSCN(*) # the setting must be the same to load check points weights. 
-model.load_state_dict(torch.load("./DCSCN/DCSCN_weights_387epos_L12_noise_1.pt")) -# then you can resume the model training -``` +To use pre-trained model for training -Model check poins in Upconv_7 and vgg_7 are from [waifu2x's repo](https://github.com/nagadomi/waifu2x/tree/master/models). To load weights into a model, please use ```load_pre_train_weights``` function. - -Example: ```python -model = UpConv_7() -model.load_pre_train_weights(json_file=...) -# then the model is ready to use -``` + +model = CARN_V2(color_channels=3, mid_channels=64, conv=nn.Conv2d, + single_conv_size=3, single_conv_group=1, + scale=2, activation=nn.LeakyReLU(0.1), + SEBlock=True, repeat_blocks=3, atrous=(1, 1, 1)) + +model = network_to_half(model) +model = model.cuda() +model.load_state_dict(torch.load("CARN_model_checkpoint.pt")) + +learning_rate = 1e-4 +weight_decay = 1e-6 +optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, amsgrad=True) +optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0, verbose=False) +optimizer.load_state_dict(torch.load("CARN_adam_checkpoint.pt")) + +last_iter = torch.load("CARN_scheduler_last_iter") # -1 if start from new +scheduler = CyclicLR(optimizer.optimizer, base_lr=1e-4, max_lr=4e-4, + step_size=3 * total_batch, mode="triangular", + last_batch_iteration=last_iter) + +``` \ No newline at end of file diff --git a/Waifu2x/model_check_points/CRAN_V2/test_loss.pt b/Waifu2x/model_check_points/CRAN_V2/test_loss.pt new file mode 100644 index 0000000000000000000000000000000000000000..81472c8a38c74458858a775f9e111850a29dc077 --- /dev/null +++ b/Waifu2x/model_check_points/CRAN_V2/test_loss.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:93f644a6a3f6636035980855f56ef3dbc8784679371b06b81e0e4d06067c142d +size 43507 diff --git a/Waifu2x/model_check_points/CRAN_V2/test_psnr.pt b/Waifu2x/model_check_points/CRAN_V2/test_psnr.pt new file mode 100644 index 
0000000000000000000000000000000000000000..2137d20b91a68d2d42383e262ecba2d04375410e --- /dev/null +++ b/Waifu2x/model_check_points/CRAN_V2/test_psnr.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae8f8d1a3d175e76dcbcdcf0cede898e8f2cf169f3eec14eeb75a4e19d8e2d6b +size 42563 diff --git a/Waifu2x/model_check_points/CRAN_V2/test_ssim.pt b/Waifu2x/model_check_points/CRAN_V2/test_ssim.pt new file mode 100644 index 0000000000000000000000000000000000000000..fe592a88e77983e336ff530ef524193daf1faa97 --- /dev/null +++ b/Waifu2x/model_check_points/CRAN_V2/test_ssim.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:763ff936f536b12b37b351c09f3c1290fb2188399aea3d9ce3cf069bd0d135e7 +size 43515 diff --git a/Waifu2x/model_check_points/CRAN_V2/train_loss.pt b/Waifu2x/model_check_points/CRAN_V2/train_loss.pt new file mode 100644 index 0000000000000000000000000000000000000000..18e1b7bbd8ccdbb1acbc59ec7cac77147033a845 --- /dev/null +++ b/Waifu2x/model_check_points/CRAN_V2/train_loss.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:85a86e94cd689adff04c4b22bf2534d17aa52af5e7309a82bc2a4f5c6c144900 +size 15564175 diff --git a/Waifu2x/model_check_points/CRAN_V2/train_psnr.pt b/Waifu2x/model_check_points/CRAN_V2/train_psnr.pt new file mode 100644 index 0000000000000000000000000000000000000000..b5be8ed37f75bfad954aa2d8410b7e7f92622a43 --- /dev/null +++ b/Waifu2x/model_check_points/CRAN_V2/train_psnr.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d1e88646b74a054ddf20ba41368a01162e35d9c88ac72f392a6ba08a5c7ef3b +size 15564175 diff --git a/Waifu2x/model_check_points/CRAN_V2/train_ssim.pt b/Waifu2x/model_check_points/CRAN_V2/train_ssim.pt new file mode 100644 index 0000000000000000000000000000000000000000..2c62e228000cf2b9dae876319b4db3670f8c8f39 --- /dev/null +++ b/Waifu2x/model_check_points/CRAN_V2/train_ssim.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:6b8da8bc73f64997c5b2d15d6161b11dbd172258a62c88572c032feb73bd022b +size 15564175 diff --git a/Waifu2x/model_check_points/DCSCN/DCSCN_model_387epos_L12_noise_1.pt b/Waifu2x/model_check_points/DCSCN/DCSCN_model_387epos_L12_noise_1.pt new file mode 100644 index 0000000000000000000000000000000000000000..6b5ef28281d6b943af95948727e78ae4a8e0f21c --- /dev/null +++ b/Waifu2x/model_check_points/DCSCN/DCSCN_model_387epos_L12_noise_1.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7aaf293584618b446868910a173de4eed2e054f33e325f9c93cabacb0937e6d5 +size 7585347 diff --git a/Waifu2x/model_check_points/DCSCN/DCSCN_weights_387epos_L12_noise_1.pt b/Waifu2x/model_check_points/DCSCN/DCSCN_weights_387epos_L12_noise_1.pt new file mode 100644 index 0000000000000000000000000000000000000000..28e6ab54e90e64e522e9aa86c73e45f2b2a46adf --- /dev/null +++ b/Waifu2x/model_check_points/DCSCN/DCSCN_weights_387epos_L12_noise_1.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8faddf6e3bf6acf688642a99da23d5626a6173c1eb92d2cdd26a5d3dd6a73da4 +size 7568033 diff --git a/Waifu2x/model_check_points/DCSCN/DCSCN_weights_45epos_L8_noise_1.pt b/Waifu2x/model_check_points/DCSCN/DCSCN_weights_45epos_L8_noise_1.pt new file mode 100644 index 0000000000000000000000000000000000000000..91a610829f8da16fcad80af6799532ab2be7d241 --- /dev/null +++ b/Waifu2x/model_check_points/DCSCN/DCSCN_weights_45epos_L8_noise_1.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b8c7b3c6c4bc1b8d48186352f9d74b685210ca8a372a06bd8718c2d20e0769e +size 9746842 diff --git a/Waifu2x/model_check_points/DCSCN/ReadME.md b/Waifu2x/model_check_points/DCSCN/ReadME.md new file mode 100644 index 0000000000000000000000000000000000000000..4f5272d362409bf0831c8c6cef263676e075f581 --- /dev/null +++ b/Waifu2x/model_check_points/DCSCN/ReadME.md @@ -0,0 +1,13 @@ +# Model Specifications + +## 12 Layers Model + +```python +model = DCSCN(color_channel=3, + up_scale=2, + 
feature_layers=12, + first_feature_filters=196, + last_feature_filters=48, + reconstruction_filters=64, + up_sampler_filters=32) +```` diff --git a/Waifu2x/model_check_points/ESPCN/ESPCN_7_weights_14epos.pk b/Waifu2x/model_check_points/ESPCN/ESPCN_7_weights_14epos.pk new file mode 100644 index 0000000000000000000000000000000000000000..c598b84f77d5e41790362811fa04fbe93acc0735 --- /dev/null +++ b/Waifu2x/model_check_points/ESPCN/ESPCN_7_weights_14epos.pk @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60088b9b7865535dae982af5f6ca2e361ecb6ce9ee1cc43c8ce4f6b1e1a4abe7 +size 5388762 diff --git a/Waifu2x/model_check_points/Upconv_7/anime.7z b/Waifu2x/model_check_points/Upconv_7/anime.7z new file mode 100644 index 0000000000000000000000000000000000000000..98bf146ddfee50f4696aab5198420bb642b9ac6f --- /dev/null +++ b/Waifu2x/model_check_points/Upconv_7/anime.7z @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0b4514f546498bf8966dd74e806d2f4034573809f91ca02659710d666235266d +size 19867323 diff --git a/Waifu2x/model_check_points/Upconv_7/photo.7z b/Waifu2x/model_check_points/Upconv_7/photo.7z new file mode 100644 index 0000000000000000000000000000000000000000..b12c49056cf635ca41a0c4790960e32ed12551f8 --- /dev/null +++ b/Waifu2x/model_check_points/Upconv_7/photo.7z @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7a173165da9b2b101f8964c55ce2472b3ce15a7a6f742804037e5c7a5a321ae +size 19872894 diff --git a/Waifu2x/model_check_points/vgg_7/art.7z b/Waifu2x/model_check_points/vgg_7/art.7z new file mode 100644 index 0000000000000000000000000000000000000000..20f749298928a42536e0ef6f35650f94fb232d89 --- /dev/null +++ b/Waifu2x/model_check_points/vgg_7/art.7z @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae5e88101e4b5591e795ffa8661b36c4986bf9ce9e762a9e21d9f268a2a8effe +size 10456728 diff --git a/Waifu2x/model_check_points/vgg_7/art_y.7z b/Waifu2x/model_check_points/vgg_7/art_y.7z new file 
mode 100644 index 0000000000000000000000000000000000000000..a9ad64846a2dfc80bfbf75f69a1c804a0e974299 --- /dev/null +++ b/Waifu2x/model_check_points/vgg_7/art_y.7z @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f24fcbf0e0d2a9d9242e3188fe8fb3de82d77da82a5228664be4dc2a69aef7a +size 8281792 diff --git a/Waifu2x/model_check_points/vgg_7/photo.7z b/Waifu2x/model_check_points/vgg_7/photo.7z new file mode 100644 index 0000000000000000000000000000000000000000..5082ed78c86b441f213e3c60944dceef88ca03d9 --- /dev/null +++ b/Waifu2x/model_check_points/vgg_7/photo.7z @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a96d475054665d050c370f3786097690920523c71231ef276ab2c7d011d305b1 +size 10459233 diff --git a/Waifu2x/model_check_points/vgg_7/ukbench.7z b/Waifu2x/model_check_points/vgg_7/ukbench.7z new file mode 100644 index 0000000000000000000000000000000000000000..3441c4114f64bc6849a8c0d3cf2ece1a9fc64943 --- /dev/null +++ b/Waifu2x/model_check_points/vgg_7/ukbench.7z @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:05f6e10f467b10ab66a9a4d41443a7f280e67925eb50c96fc8e43287ce56e205 +size 2088088 diff --git a/app.py b/app.py index 1c732340756e7706fe399c68d6908469fe489fcb..b2972485ca159a98d583305f3c3a3044997b426b 100644 --- a/app.py +++ b/app.py @@ -1,20 +1,43 @@ +""" +Super Resolution Anime Diffusion - Enhanced WebUI + +This is an enhanced version of the original Super Resolution Anime Diffusion project by yangheng95. +The WebUI has been improved with modern Gradio API implementation, better user experience, +and comprehensive documentation. 
+ +Key Contributions: +- Updated to use modern Gradio Blocks API for better interface organization +- Added tabbed interface for Text-to-Image, Image-to-Image, and Gallery views +- Improved error handling and user feedback with progress indicators +- Enhanced UI styling with custom CSS and responsive design +- Better parameter organization with collapsible accordions +- Real-time system information display + +Instructions: +1. Choose between Text-to-Image or Image-to-Image tabs +2. Select a model from the dropdown (or provide custom model path) +3. Enter your prompt and adjust parameters as needed +4. For Image-to-Image: upload a base image to transform +5. Configure super-resolution settings (method and scale factor) +6. Click Generate to create high-quality anime images with automatic upscaling + +Original Author: yangheng95 +Original Repository: https://github.com/yangheng95/SuperResolutionAnimeDiffusion +License: Creative ML Open RAIL-M +Enhanced WebUI by AI Assistant +""" + import os -import random +import sys import zipfile -import findfile +from typing import Optional, List, Tuple +from datetime import datetime +import time +import psutil + import PIL.Image import autocuda -from pyabsa.utils.pyabsa_utils import fprint - -try: - for z_file in findfile.find_cwd_files(and_key=['.zip'], - exclude_key=['.ignore', 'git', 'SuperResolutionAnimeDiffusion'], - recursive=10): - fprint(f"Extracting {z_file}...") - with zipfile.ZipFile(z_file, 'r') as zip_ref: - zip_ref.extractall(os.path.dirname(z_file)) -except Exception as e: - os.system('unzip random_examples.zip') +import findfile from diffusers import ( AutoencoderKL, @@ -27,59 +50,95 @@ import gradio as gr import torch from PIL import Image import utils -import datetime -import time -import psutil from Waifu2x.magnify import ImageMagnifier from RealESRGANv030.interface import realEsrgan -magnifier = ImageMagnifier() +sys.path.append(os.path.dirname(__file__)) # Ensure current directory is in path 
+os.environ["PYTHONPATH"] = os.path.dirname(__file__) + +# Application Configuration +APP_TITLE = "π¨ Super Resolution Anime Diffusion" +APP_DESCRIPTION = """ +Generate high-quality anime images with automatic super resolution enhancement. +Combines Stable Diffusion models with advanced upscaling techniques (RealESRGAN & Waifu2x). +""" +CONTRIBUTION_INFO = """ +### π€ Enhanced Features +This interface improves upon the original work with: +- **Modern UI**: Clean tabbed interface with Gradio Blocks +- **Better UX**: Progress tracking and real-time feedback +- **Enhanced Parameters**: Organized controls with descriptions +- **Gallery View**: Browse and manage generated images +- **Error Handling**: Comprehensive error reporting and recovery +""" + +INSTRUCTIONS = """ +### π How to Use +1. **Select Mode**: Choose Text-to-Image or Image-to-Image tab +2. **Pick Model**: Select from available models or use custom path +3. **Create Prompt**: Describe your desired image (use negative prompt to avoid elements) +4. **Upload Image**: For img2img mode, provide base image +5. **Adjust Settings**: Fine-tune resolution, steps, and guidance +6. **Set Upscaling**: Choose super-resolution method and scale +7. **Generate**: Click the generate button and wait for results! 
+""" + +COPYRIGHT_INFO = """ +**Original Author**: [yangheng95](https://github.com/yangheng95) | +**Repository**: [SuperResolutionAnimeDiffusion](https://github.com/yangheng95/SuperResolutionAnimeDiffusion) | +**License**: Creative ML Open RAIL-M | **Enhanced by**: AI Assistant +""" + +DEFAULT_NEGATIVE_PROMPT = "bad result, worst, random, invalid, inaccurate, imperfect, blurry, deformed, disfigured, mutation, mutated, ugly, out of focus, bad anatomy, text, error, extra digit, fewer digits, worst quality, low quality, normal quality, noise, jpeg artifact, compression artifact, signature, watermark, username, logo, low resolution, worst resolution, bad resolution, normal resolution, bad detail, bad details, bad lighting, bad shadow, bad shading, bad background, worst background" + +# Initialization +magnifier = ImageMagnifier() start_time = time.time() is_colab = utils.is_google_colab() - -CUDA_VISIBLE_DEVICES = "" device = autocuda.auto_cuda() - dtype = torch.float16 if device != "cpu" else torch.float32 - +# Extract zip files if needed +for z_file in findfile.find_cwd_files(and_key=['.zip'], exclude_key=['.ignore'], recursive=1): + try: + with zipfile.ZipFile(z_file, 'r') as zip_ref: + zip_ref.extractall() + except Exception as e: + print(f"Warning: Could not extract {z_file}: {e}") class Model: - def __init__(self, name, path="", prefix=""): + """Model configuration class""" + def __init__(self, name: str, path: str = "", prefix: str = ""): self.name = name self.path = path self.prefix = prefix self.pipe_t2i = None self.pipe_i2i = None - +# Model configurations models = [ - # Model("anything v3", "Linaqruf/anything-v3.0", "anything v3 style"), - Model("anything v5", "stablediffusionapi/anything-v5", "anything v5 style"), + Model("Anything v4.5", "xyn-ai/anything-v4.0", "anything v4.5 style"), ] -# Model("Spider-Verse", "nitrosocke/spider-verse-diffusion", "spiderverse style "), -# Model("Balloon Art", "Fictiverse/Stable_Diffusion_BalloonArt_Model", "BalloonArt "), 
-# Model("Elden Ring", "nitrosocke/elden-ring-diffusion", "elden ring style "), -# Model("Tron Legacy", "dallinmackay/Tron-Legacy-diffusion", "trnlgcy ") -# Model("PokΓ©mon", "lambdalabs/sd-pokemon-diffusers", ""), -# Model("Pony Diffusion", "AstraliteHeart/pony-diffusion", ""), -# Model("Robo Diffusion", "nousr/robo-diffusion", ""), - -scheduler = DPMSolverMultistepScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - num_train_timesteps=1000, - trained_betas=None, - predict_epsilon=True, - thresholding=False, - algorithm_type="dpmsolver++", - solver_type="midpoint", - solver_order=2, - # lower_order_final=True, -) +# Scheduler configuration +scheduler = DPMSolverMultistepScheduler.from_config({ + "beta_start": 0.00085, + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "num_train_timesteps": 1000, + "trained_betas": None, + "prediction_type": "epsilon", + "thresholding": False, + "algorithm_type": "dpmsolver++", + "solver_type": "midpoint", + "solver_order": 2, + "use_karras_sigmas": False, + "timestep_spacing": "leading", + "steps_offset": 1 +}) + +# Global state custom_model = None if is_colab: models.insert(0, Model("Custom model")) @@ -88,177 +147,198 @@ if is_colab: last_mode = "txt2img" current_model = models[1] if is_colab else models[0] current_model_path = current_model.path +pipe = None -if is_colab: - pipe = StableDiffusionPipeline.from_pretrained( - current_model.path, - torch_dtype=dtype, - scheduler=scheduler, - safety_checker=lambda images, clip_input: (images, False), - ) +def initialize_models(): + """Initialize diffusion models with error handling""" + global pipe -else: # download all models - print(f"{datetime.datetime.now()} Downloading vae...") - vae = AutoencoderKL.from_pretrained( - current_model.path, subfolder="vae", torch_dtype=dtype - ) - for model in models: + if is_colab: try: - print(f"{datetime.datetime.now()} Downloading {model.name} model...") - unet = UNet2DConditionModel.from_pretrained( 
- model.path, subfolder="unet", torch_dtype=dtype - ) - model.pipe_t2i = StableDiffusionPipeline.from_pretrained( - model.path, - unet=unet, - vae=vae, - torch_dtype=dtype, - scheduler=scheduler, - safety_checker=None, - ) - model.pipe_i2i = StableDiffusionImg2ImgPipeline.from_pretrained( - model.path, - unet=unet, - vae=vae, + pipe = StableDiffusionPipeline.from_pretrained( + current_model.path, torch_dtype=dtype, scheduler=scheduler, safety_checker=None, ) except Exception as e: - print( - f"{datetime.datetime.now()} Failed to load model " - + model.name - + ": " - + str(e) + print(f"Failed to initialize model: {e}") + return + else: + print(f"{datetime.now()} Loading models...") + try: + vae = AutoencoderKL.from_pretrained( + current_model.path, subfolder="vae", torch_dtype=dtype ) - models.remove(model) - pipe = models[0].pipe_t2i - -# model.pipe_i2i = torch.compile(model.pipe_i2i) -# model.pipe_t2i = torch.compile(model.pipe_t2i) -if torch.cuda.is_available(): - pipe = pipe.to(device) + for model in models[:]: + try: + print(f"Loading {model.name}...") + unet = UNet2DConditionModel.from_pretrained( + model.path, subfolder="unet", torch_dtype=dtype + ) + model.pipe_t2i = StableDiffusionPipeline.from_pretrained( + model.path, + unet=unet, + vae=vae, + torch_dtype=dtype, + scheduler=scheduler, + safety_checker=None, + ) + model.pipe_i2i = StableDiffusionImg2ImgPipeline.from_pretrained( + model.path, + unet=unet, + vae=vae, + torch_dtype=dtype, + scheduler=scheduler, + safety_checker=None, + ) + print(f"β {model.name} loaded successfully") + except Exception as e: + print(f"β Failed to load {model.name}: {e}") + models.remove(model) -# device = "GPU π₯" if torch.cuda.is_available() else "CPU π₯Ά" - + if models: + pipe = models[0].pipe_t2i + except Exception as e: + print(f"Failed to initialize models: {e}") + return + + if torch.cuda.is_available() and pipe: + pipe = pipe.to(device) + +def get_system_info() -> str: + """Get system information""" + gpu_name = "CPU" 
+ if torch.cuda.is_available(): + gpu_name = torch.cuda.get_device_name() + + memory = psutil.virtual_memory() + return f"π₯οΈ Device: {gpu_name} | πΎ RAM: {memory.available // (1024**3):.1f}GB" + +def error_str(error: Exception, title: str = "Error") -> str: + """Format error messages""" + return f"### β {title}\n```\n{str(error)}\n```" + +def custom_model_changed(path: str) -> str: + """Handle custom model path changes""" + if custom_model and path.strip(): + models[0].path = path.strip() + global current_model + current_model = models[0] + return "β Custom model path updated" + return "β Please enter a valid model path" + +def on_model_change(model_name: str) -> Tuple[gr.update, gr.update]: + """Handle model selection changes""" + selected_model = next((m for m in models if m.name == model_name), None) + + if selected_model and selected_model != models[0] if custom_model else True: + prefix_text = f'Prompt (automatically prefixed with "{selected_model.prefix}")' + is_custom = False + else: + prefix_text = "Enter prompt (remember to include model-specific prefix)" + is_custom = True -def error_str(error, title="Error"): return ( - f"""#### {title} - {error}""" - if error - else "" + gr.update(visible=is_custom), + gr.update(placeholder=prefix_text), ) - -def custom_model_changed(path): - models[0].path = path +def generate_image( + mode: str, + model_name: str, + prompt: str, + negative_prompt: str, + width: int, + height: int, + guidance_scale: float, + num_steps: int, + seed: int, + image: Optional[PIL.Image.Image], + strength: float, + scale_method: str, + scale_factor: int, + progress=gr.Progress() +) -> Tuple[Optional[PIL.Image.Image], str]: + """Main image generation function""" + + if progress: + progress(0, desc="Starting generation...") + + # Validation + if not prompt.strip(): + return None, "β Please enter a prompt" + + if mode == "img2img" and image is None: + return None, "β Please upload an image for Image-to-Image mode" + + # Find model global 
current_model - current_model = models[0] - + selected_model = next((m for m in models if m.name == model_name), None) + if not selected_model: + return None, error_str(ValueError(f"Model '{model_name}' not found")) -def on_model_change(model_name): - prefix = ( - 'Enter prompt. "' - + next((m.prefix for m in models if m.name == model_name), None) - + '" is prefixed automatically' - if model_name != models[0].name - else "Don't forget to use the custom model prefix in the prompt!" - ) + current_model = selected_model - return ( - gr.update(visible=model_name == models[0].name), - gr.update(placeholder=prefix), - ) + if progress: + progress(0.1, desc=f"Using {model_name}") - -def inference( - model_name, - prompt, - guidance, - steps, - width=512, - height=512, - seed=0, - img=None, - strength=0.5, - neg_prompt="", - scale="ESRGAN4x", - scale_factor=2, -): - fprint(psutil.virtual_memory()) # print memory usage - - fprint(f"Prompt: {prompt}") - global current_model - for model in models: - if model.name == model_name: - current_model = model - model_path = current_model.path - - generator = torch.Generator(device).manual_seed(seed) if seed != 0 else None + # Setup generator + if seed <= 0: + seed = torch.randint(0, 2**32-1, (1,)).item() + generator = torch.Generator(device).manual_seed(seed) try: - if img is not None: - return ( - img_to_img( - model_path, - prompt, - neg_prompt, - img, - strength, - guidance, - steps, - width, - height, - generator, - scale, - scale_factor, - ), - None, + if mode == "img2img": + result_image = img_to_img( + current_model.path, prompt, negative_prompt, image, strength, + guidance_scale, num_steps, width, height, generator, + scale_method, scale_factor, progress ) else: - return ( - txt_to_img( - model_path, - prompt, - neg_prompt, - guidance, - steps, - width, - height, - generator, - scale, - scale_factor, - ), - None, + result_image = txt_to_img( + current_model.path, prompt, negative_prompt, guidance_scale, + num_steps, width, 
height, generator, scale_method, scale_factor, + progress ) - except Exception as e: - return None, error_str(e) - # if img is not None: - # return img_to_img(model_path, prompt, neg_prompt, img, strength, guidance, steps, width, height, - # generator, scale, scale_factor), None - # else: - # return txt_to_img(model_path, prompt, neg_prompt, guidance, steps, width, height, generator, scale, scale_factor), None + if progress: + progress(1.0, desc="Complete!") + + # Save result + timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + os.makedirs("imgs", exist_ok=True) + filename = f"imgs/result-{timestamp}.png" + result_image.save(filename) + + info = f"""### β Generation Complete +- **Mode**: {mode} +- **Model**: {model_name} +- **Resolution**: {result_image.size[0]}x{result_image.size[1]} +- **Scale**: {scale_factor}x ({scale_method}) +- **Seed**: {seed} +- **Saved**: {filename}""" + + return result_image, info + + except Exception as e: + print(f"Generation error: {e}") + return None, error_str(e, "Generation Failed") def txt_to_img( - model_path, - prompt, - neg_prompt, - guidance, - steps, - width, - height, - generator, - scale, - scale_factor, -): - print(f"{datetime.datetime.now()} txt_to_img, model: {current_model.name}") - - global last_mode - global pipe - global current_model_path + model_path: str, prompt: str, neg_prompt: str, guidance: float, + steps: int, width: int, height: int, generator, scale: str, + scale_factor: int, progress +) -> PIL.Image.Image: + """Text-to-image generation""" + + global last_mode, pipe, current_model_path + + if progress: + progress(0.2, desc="Loading pipeline...") + + # Load pipeline if needed if model_path != current_model_path or last_mode != "txt2img": current_model_path = model_path @@ -267,70 +347,63 @@ def txt_to_img( current_model_path, torch_dtype=dtype, scheduler=scheduler, - safety_checker=lambda images, clip_input: (images, False), + safety_checker=None, ) else: - # pipe = pipe.to("cpu") pipe = 
current_model.pipe_t2i if torch.cuda.is_available(): pipe = pipe.to(device) last_mode = "txt2img" - prompt = current_model.prefix + prompt + if progress: + progress(0.4, desc="Generating image...") + + # Add model prefix + full_prompt = f"{current_model.prefix}, {prompt}" if current_model.prefix else prompt + result = pipe( - prompt, + full_prompt, negative_prompt=neg_prompt, - # num_images_per_prompt=n_images, num_inference_steps=int(steps), guidance_scale=guidance, width=width, height=height, generator=generator, - ) + ).images[0] - # result.images[0] = magnifier.magnify(result.images[0], scale_factor=scale_factor) - # enhance resolution + if progress: + progress(0.7, desc="Applying super resolution...") + + # Apply super resolution if scale_factor > 1: - if scale == "ESRGAN4x": - fp32 = True if device == "cpu" else False - result.images[0] = realEsrgan( - input_dir=result.images[0], + if scale == "RealESRGAN": + fp32 = device == "cpu" + result = realEsrgan( + input_dir=result, suffix="", output_dir="imgs", fp32=fp32, outscale=scale_factor, )[0] - else: - result.images[0] = magnifier.magnify( - result.images[0], scale_factor=scale_factor - ) - # save image - result.images[0].save( - "imgs/result-{}.png".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")) - ) - return replace_nsfw_images(result) + else: # Waifu2x + result = magnifier.magnify(result, scale_factor=scale_factor) + return result def img_to_img( - model_path, - prompt, - neg_prompt, - img, - strength, - guidance, - steps, - width, - height, - generator, - scale, - scale_factor, -): - fprint(f"{datetime.datetime.now()} img_to_img, model: {model_path}") - - global last_mode - global pipe - global current_model_path + model_path: str, prompt: str, neg_prompt: str, img: PIL.Image.Image, + strength: float, guidance: float, steps: int, width: int, height: int, + generator, scale: str, scale_factor: int, progress +) -> PIL.Image.Image: + """Image-to-image generation""" + + global last_mode, pipe, 
current_model_path + + if progress: + progress(0.2, desc="Loading pipeline...") + + # Load pipeline if needed if model_path != current_model_path or last_mode != "img2img": current_model_path = model_path @@ -339,263 +412,396 @@ def img_to_img( current_model_path, torch_dtype=dtype, scheduler=scheduler, - safety_checker=lambda images, clip_input: (images, False), + safety_checker=None, ) else: - # pipe = pipe.to("cpu") pipe = current_model.pipe_i2i if torch.cuda.is_available(): pipe = pipe.to(device) last_mode = "img2img" - prompt = current_model.prefix + prompt + # Resize input image + if progress: + progress(0.3, desc="Processing input image...") + ratio = min(height / img.height, width / img.width) img = img.resize((int(img.width * ratio), int(img.height * ratio)), Image.LANCZOS) + + # Add model prefix + full_prompt = f"{current_model.prefix}, {prompt}" if current_model.prefix else prompt + + if progress: + progress(0.4, desc="Transforming image...") + result = pipe( - prompt, + full_prompt, negative_prompt=neg_prompt, - # num_images_per_prompt=n_images, image=img, num_inference_steps=int(steps), strength=strength, guidance_scale=guidance, - # width=width, - # height=height, generator=generator, - ) + ).images[0] + + if progress: + progress(0.7, desc="Applying super resolution...") + + # Apply super resolution if scale_factor > 1: - if scale == "ESRGAN4x": - fp32 = True if device == "cpu" else False - result.images[0] = realEsrgan( - input_dir=result.images[0], + if scale == "RealESRGAN": + fp32 = device == "cpu" + result = realEsrgan( + input_dir=result, suffix="", output_dir="imgs", fp32=fp32, outscale=scale_factor, )[0] - else: - result.images[0] = magnifier.magnify( - result.images[0], scale_factor=scale_factor - ) - # save image - result.images[0].save( - "imgs/result-{}.png".format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")) - ) - return replace_nsfw_images(result) + else: # Waifu2x + result = magnifier.magnify(result, scale_factor=scale_factor) + + 
return result + +def load_example_images() -> List[str]: + """Load example images for gallery""" + example_images = [] + for f_img in findfile.find_cwd_files(".png", recursive=2): + if "result-" in os.path.basename(f_img) or "random_examples" in f_img: + example_images.append(f_img) + return example_images[:12] # Limit examples + +# Custom CSS for styling +custom_css = """ +.gradio-container { + font-family: 'Segoe UI', system-ui, sans-serif; + max-width: 1400px; + margin: 0 auto; +} + +.header-section { + text-align: center; + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); + color: white; + padding: 2rem; + border-radius: 15px; + margin-bottom: 2rem; +} + +.info-card { + background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); + color: white; + padding: 1.5rem; + border-radius: 10px; + margin: 1rem 0; +} + +.status-info { + background: #e8f5e8; + border-left: 4px solid #4CAF50; + padding: 1rem; + border-radius: 5px; + margin: 1rem 0; +} + +.generate-btn { + background: linear-gradient(45deg, #FF6B6B, #4ECDC4) !important; + border: none !important; + border-radius: 25px !important; + padding: 15px 30px !important; + font-size: 16px !important; + font-weight: bold !important; + color: white !important; + transition: all 0.3s ease !important; +} + +.generate-btn:hover { + transform: translateY(-2px) !important; + box-shadow: 0 10px 20px rgba(0,0,0,0.2) !important; +} +""" +def create_interface(): + """Create the Gradio interface""" -def replace_nsfw_images(results): - if is_colab: - return results.images[0] - if hasattr(results, "nsfw_content_detected") and results.nsfw_content_detected: - for i in range(len(results.images)): - if results.nsfw_content_detected[i]: - results.images[i] = Image.open("nsfw.png") - return results.images[0] + with gr.Blocks(title=APP_TITLE, css=custom_css) as demo: + # Header + with gr.Row(): + gr.HTML(f""" +
{APP_DESCRIPTION}
+constants.SAFETENSORS_MAX_HEADER_LENGTH: + raise SafetensorsParsingError( + f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision " + f"'{revision or constants.DEFAULT_REVISION}'): safetensors header is too big. Maximum supported size is " + f"{constants.SAFETENSORS_MAX_HEADER_LENGTH} bytes (got {metadata_size})." + ) + + # 3.a. Get metadata from payload + if metadata_size <= 100000: + metadata_as_bytes = response.content[8 : 8 + metadata_size] + else: # 3.b. Request full metadata + response = get_session().get(url, headers={**_headers, "range": f"bytes=8-{metadata_size+7}"}) + hf_raise_for_status(response) + metadata_as_bytes = response.content + + # 4. Parse json header + try: + metadata_as_dict = json.loads(metadata_as_bytes.decode(errors="ignore")) + except json.JSONDecodeError as e: + raise SafetensorsParsingError( + f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision " + f"'{revision or constants.DEFAULT_REVISION}'): header is not json-encoded string. Please make sure this is a " + "correctly formatted safetensors file." + ) from e + + try: + return SafetensorsFileMetadata( + metadata=metadata_as_dict.get("__metadata__", {}), + tensors={ + key: TensorInfo( + dtype=tensor["dtype"], + shape=tensor["shape"], + data_offsets=tuple(tensor["data_offsets"]), # type: ignore + ) + for key, tensor in metadata_as_dict.items() + if key != "__metadata__" + }, + ) + except (KeyError, IndexError) as e: + raise SafetensorsParsingError( + f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision " + f"'{revision or constants.DEFAULT_REVISION}'): header format not recognized. Please make sure this is a correctly" + " formatted safetensors file." 
+ ) from e + + @validate_hf_hub_args + def create_branch( + self, + repo_id: str, + *, + branch: str, + revision: Optional[str] = None, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + exist_ok: bool = False, + ) -> None: + """ + Create a new branch for a repo on the Hub, starting from the specified revision (defaults to `main`). + To find a revision suiting your needs, you can use [`list_repo_refs`] or [`list_repo_commits`]. + + Args: + repo_id (`str`): + The repository in which the branch will be created. + Example: `"user/my-cool-model"`. + + branch (`str`): + The name of the branch to create. + + revision (`str`, *optional*): + The git revision to create the branch from. It can be a branch name or + the OID/SHA of a commit, as a hexadecimal string. Defaults to the head + of the `"main"` branch. + + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if creating a branch on a dataset or + space, `None` or `"model"` if tagging a model. Default is `None`. + + exist_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if branch already exists. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + [`~utils.BadRequestError`]: + If invalid reference for a branch. Ex: `refs/pr/5` or 'refs/foo/bar'. + [`~utils.HfHubHTTPError`]: + If the branch already exists on the repo (error 409) and `exist_ok` is + set to `False`. 
+ """ + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + branch = quote(branch, safe="") + + # Prepare request + branch_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/branch/{branch}" + headers = self._build_hf_headers(token=token) + payload = {} + if revision is not None: + payload["startingPoint"] = revision + + # Create branch + response = get_session().post(url=branch_url, headers=headers, json=payload) + try: + hf_raise_for_status(response) + except HfHubHTTPError as e: + if exist_ok and e.response.status_code == 409: + return + elif exist_ok and e.response.status_code == 403: + # No write permission on the namespace but branch might already exist + try: + refs = self.list_repo_refs(repo_id=repo_id, repo_type=repo_type, token=token) + for branch_ref in refs.branches: + if branch_ref.name == branch: + return # Branch already exists => do not raise + except HfHubHTTPError: + pass # We raise the original error if the branch does not exist + raise + + @validate_hf_hub_args + def delete_branch( + self, + repo_id: str, + *, + branch: str, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + ) -> None: + """ + Delete a branch from a repo on the Hub. + + Args: + repo_id (`str`): + The repository in which a branch will be deleted. + Example: `"user/my-cool-model"`. + + branch (`str`): + The name of the branch to delete. + + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if creating a branch on a dataset or + space, `None` or `"model"` if tagging a model. Default is `None`. 
+ + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + [`~utils.HfHubHTTPError`]: + If trying to delete a protected branch. Ex: `main` cannot be deleted. + [`~utils.HfHubHTTPError`]: + If trying to delete a branch that does not exist. + + """ + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + branch = quote(branch, safe="") + + # Prepare request + branch_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/branch/{branch}" + headers = self._build_hf_headers(token=token) + + # Delete branch + response = get_session().delete(url=branch_url, headers=headers) + hf_raise_for_status(response) + + @validate_hf_hub_args + def create_tag( + self, + repo_id: str, + *, + tag: str, + tag_message: Optional[str] = None, + revision: Optional[str] = None, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + exist_ok: bool = False, + ) -> None: + """ + Tag a given commit of a repo on the Hub. + + Args: + repo_id (`str`): + The repository in which a commit will be tagged. + Example: `"user/my-cool-model"`. + + tag (`str`): + The name of the tag to create. + + tag_message (`str`, *optional*): + The description of the tag to create. + + revision (`str`, *optional*): + The git revision to tag. It can be a branch name or the OID/SHA of a + commit, as a hexadecimal string. Shorthands (7 first characters) are + also supported. Defaults to the head of the `"main"` branch. + + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if tagging a dataset or + space, `None` or `"model"` if tagging a model. Default is + `None`. 
+ + exist_ok (`bool`, *optional*, defaults to `False`): + If `True`, do not raise an error if tag already exists. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + [`~utils.RevisionNotFoundError`]: + If revision is not found (error 404) on the repo. + [`~utils.HfHubHTTPError`]: + If the branch already exists on the repo (error 409) and `exist_ok` is + set to `False`. + """ + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION + + # Prepare request + tag_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tag/{revision}" + headers = self._build_hf_headers(token=token) + payload = {"tag": tag} + if tag_message is not None: + payload["message"] = tag_message + + # Tag + response = get_session().post(url=tag_url, headers=headers, json=payload) + try: + hf_raise_for_status(response) + except HfHubHTTPError as e: + if not (e.response.status_code == 409 and exist_ok): + raise + + @validate_hf_hub_args + def delete_tag( + self, + repo_id: str, + *, + tag: str, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + ) -> None: + """ + Delete a tag from a repo on the Hub. + + Args: + repo_id (`str`): + The repository in which a tag will be deleted. + Example: `"user/my-cool-model"`. + + tag (`str`): + The name of the tag to delete. + + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if tagging a dataset or space, `None` or + `"model"` if tagging a model. Default is `None`. 
+ + Raises: + [`~utils.RepositoryNotFoundError`]: + If repository is not found (error 404): wrong repo_id/repo_type, private + but not authenticated or repo does not exist. + [`~utils.RevisionNotFoundError`]: + If tag is not found. + """ + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + tag = quote(tag, safe="") + + # Prepare request + tag_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tag/{tag}" + headers = self._build_hf_headers(token=token) + + # Un-tag + response = get_session().delete(url=tag_url, headers=headers) + hf_raise_for_status(response) + + @validate_hf_hub_args + def get_full_repo_name( + self, + model_id: str, + *, + organization: Optional[str] = None, + token: Union[bool, str, None] = None, + ): + """ + Returns the repository name for a given model ID and optional + organization. + + Args: + model_id (`str`): + The name of the model. + organization (`str`, *optional*): + If passed, the repository name will be in the organization + namespace instead of the user namespace. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `str`: The repository name in the user's namespace + ({username}/{model_id}) if no organization is passed, and under the + organization namespace ({organization}/{model_id}) otherwise. 
+ """ + if organization is None: + if "/" in model_id: + username = model_id.split("/")[0] + else: + username = self.whoami(token=token)["name"] # type: ignore + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + @validate_hf_hub_args + def get_repo_discussions( + self, + repo_id: str, + *, + author: Optional[str] = None, + discussion_type: Optional[constants.DiscussionTypeFilter] = None, + discussion_status: Optional[constants.DiscussionStatusFilter] = None, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> Iterator[Discussion]: + """ + Fetches Discussions and Pull Requests for the given repo. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + author (`str`, *optional*): + Pass a value to filter by discussion author. `None` means no filter. + Default is `None`. + discussion_type (`str`, *optional*): + Set to `"pull_request"` to fetch only pull requests, `"discussion"` + to fetch only discussions. Set to `"all"` or `None` to fetch both. + Default is `None`. + discussion_status (`str`, *optional*): + Set to `"open"` (respectively `"closed"`) to fetch only open + (respectively closed) discussions. Set to `"all"` or `None` + to fetch both. + Default is `None`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if fetching from a dataset or + space, `None` or `"model"` if fetching from a model. Default is + `None`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterator[Discussion]`: An iterator of [`Discussion`] objects. 
+ + Example: + Collecting all discussions of a repo in a list: + + ```python + >>> from huggingface_hub import get_repo_discussions + >>> discussions_list = list(get_repo_discussions(repo_id="bert-base-uncased")) + ``` + + Iterating over discussions of a repo: + + ```python + >>> from huggingface_hub import get_repo_discussions + >>> for discussion in get_repo_discussions(repo_id="bert-base-uncased"): + ... print(discussion.num, discussion.title) + ``` + """ + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + if discussion_type is not None and discussion_type not in constants.DISCUSSION_TYPES: + raise ValueError(f"Invalid discussion_type, must be one of {constants.DISCUSSION_TYPES}") + if discussion_status is not None and discussion_status not in constants.DISCUSSION_STATUS: + raise ValueError(f"Invalid discussion_status, must be one of {constants.DISCUSSION_STATUS}") + + headers = self._build_hf_headers(token=token) + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions" + + params: Dict[str, Union[str, int]] = {} + if discussion_type is not None: + params["type"] = discussion_type + if discussion_status is not None: + params["status"] = discussion_status + if author is not None: + params["author"] = author + + def _fetch_discussion_page(page_index: int): + params["p"] = page_index + resp = get_session().get(path, headers=headers, params=params) + hf_raise_for_status(resp) + paginated_discussions = resp.json() + total = paginated_discussions["count"] + start = paginated_discussions["start"] + discussions = paginated_discussions["discussions"] + has_next = (start + len(discussions)) < total + return discussions, has_next + + has_next, page_index = True, 0 + + while has_next: + discussions, has_next = _fetch_discussion_page(page_index=page_index) + for discussion in discussions: + yield Discussion( + 
title=discussion["title"], + num=discussion["num"], + author=discussion.get("author", {}).get("name", "deleted"), + created_at=parse_datetime(discussion["createdAt"]), + status=discussion["status"], + repo_id=discussion["repo"]["name"], + repo_type=discussion["repo"]["type"], + is_pull_request=discussion["isPullRequest"], + endpoint=self.endpoint, + ) + page_index = page_index + 1 + + @validate_hf_hub_args + def get_discussion_details( + self, + repo_id: str, + discussion_num: int, + *, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> DiscussionWithDetails: + """Fetches a Discussion's / Pull Request 's details from the Hub. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: [`DiscussionWithDetails`] + ++ + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. 
+ + + """ + if not isinstance(discussion_num, int) or discussion_num <= 0: + raise ValueError("Invalid discussion_num, must be a positive integer") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions/{discussion_num}" + headers = self._build_hf_headers(token=token) + resp = get_session().get(path, params={"diff": "1"}, headers=headers) + hf_raise_for_status(resp) + + discussion_details = resp.json() + is_pull_request = discussion_details["isPullRequest"] + + target_branch = discussion_details["changes"]["base"] if is_pull_request else None + conflicting_files = discussion_details["filesWithConflicts"] if is_pull_request else None + merge_commit_oid = discussion_details["changes"].get("mergeCommitId", None) if is_pull_request else None + + return DiscussionWithDetails( + title=discussion_details["title"], + num=discussion_details["num"], + author=discussion_details.get("author", {}).get("name", "deleted"), + created_at=parse_datetime(discussion_details["createdAt"]), + status=discussion_details["status"], + repo_id=discussion_details["repo"]["name"], + repo_type=discussion_details["repo"]["type"], + is_pull_request=discussion_details["isPullRequest"], + events=[deserialize_event(evt) for evt in discussion_details["events"]], + conflicting_files=conflicting_files, + target_branch=target_branch, + merge_commit_oid=merge_commit_oid, + diff=discussion_details.get("diff"), + endpoint=self.endpoint, + ) + + @validate_hf_hub_args + def create_discussion( + self, + repo_id: str, + title: str, + *, + token: Union[bool, str, None] = None, + description: Optional[str] = None, + repo_type: Optional[str] = None, + pull_request: bool = False, + ) -> DiscussionWithDetails: + """Creates a Discussion or Pull Request. 
+ + Pull Requests created programmatically will be in `"draft"` status. + + Creating a Pull Request with changes can also be done at once with [`HfApi.create_commit`]. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + title (`str`): + The title of the discussion. It can be up to 200 characters long, + and must be at least 3 characters long. Leading and trailing whitespaces + will be stripped. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + description (`str`, *optional*): + An optional description for the Pull Request. + Defaults to `"Discussion opened with the huggingface_hub Python library"` + pull_request (`bool`, *optional*): + Whether to create a Pull Request or discussion. If `True`, creates a Pull Request. + If `False`, creates a discussion. Defaults to `False`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + + Returns: [`DiscussionWithDetails`] + ++ + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. 
+ + """ + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + + if description is not None: + description = description.strip() + description = ( + description + if description + else ( + f"{'Pull Request' if pull_request else 'Discussion'} opened with the" + " [huggingface_hub Python" + " library](https://huggingface.co/docs/huggingface_hub)" + ) + ) + + headers = self._build_hf_headers(token=token) + resp = get_session().post( + f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions", + json={ + "title": title.strip(), + "description": description, + "pullRequest": pull_request, + }, + headers=headers, + ) + hf_raise_for_status(resp) + num = resp.json()["num"] + return self.get_discussion_details( + repo_id=repo_id, + repo_type=repo_type, + discussion_num=num, + token=token, + ) + + @validate_hf_hub_args + def create_pull_request( + self, + repo_id: str, + title: str, + *, + token: Union[bool, str, None] = None, + description: Optional[str] = None, + repo_type: Optional[str] = None, + ) -> DiscussionWithDetails: + """Creates a Pull Request . Pull Requests created programmatically will be in `"draft"` status. + + Creating a Pull Request with changes can also be done at once with [`HfApi.create_commit`]; + + This is a wrapper around [`HfApi.create_discussion`]. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + title (`str`): + The title of the discussion. It can be up to 200 characters long, + and must be at least 3 characters long. Leading and trailing whitespaces + will be stripped. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). 
+ To disable authentication, pass `False`. + description (`str`, *optional*): + An optional description for the Pull Request. + Defaults to `"Discussion opened with the huggingface_hub Python library"` + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + + Returns: [`DiscussionWithDetails`] + ++ + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + + """ + return self.create_discussion( + repo_id=repo_id, + title=title, + token=token, + description=description, + repo_type=repo_type, + pull_request=True, + ) + + def _post_discussion_changes( + self, + *, + repo_id: str, + discussion_num: int, + resource: str, + body: Optional[dict] = None, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + ) -> requests.Response: + """Internal utility to POST changes to a Discussion or Pull Request""" + if not isinstance(discussion_num, int) or discussion_num <= 0: + raise ValueError("Invalid discussion_num, must be a positive integer") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + repo_id = f"{repo_type}s/{repo_id}" + + path = f"{self.endpoint}/api/{repo_id}/discussions/{discussion_num}/{resource}" + + headers = self._build_hf_headers(token=token) + resp = requests.post(path, headers=headers, json=body) + hf_raise_for_status(resp) + return resp + + @validate_hf_hub_args + def 
comment_discussion( + self, + repo_id: str, + discussion_num: int, + comment: str, + *, + token: Union[bool, str, None] = None, + repo_type: Optional[str] = None, + ) -> DiscussionComment: + """Creates a new comment on the given Discussion. + + Args: + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + discussion_num (`int`): + The number of the Discussion or Pull Request . Must be a strictly positive integer. + comment (`str`): + The content of the comment to create. Comments support markdown formatting. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if uploading to a dataset or + space, `None` or `"model"` if uploading to a model. Default is + `None`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`DiscussionComment`]: the newly created comment + + + Examples: + ```python + + >>> comment = \"\"\" + ... Hello @otheruser! + ... + ... # This is a title + ... + ... **This is bold**, *this is italic* and ~this is strikethrough~ + ... And [this](http://url) is a link + ... \"\"\" + + >>> HfApi().comment_discussion( + ... repo_id="username/repo_name", + ... discussion_num=34 + ... comment=comment + ... ) + # DiscussionComment(id='deadbeef0000000', type='comment', ...) + + ``` + ++ + Raises the following errors: + + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the HuggingFace API returned an error + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. 
@validate_hf_hub_args
def rename_discussion(
    self,
    repo_id: str,
    discussion_num: int,
    new_title: str,
    *,
    token: Union[bool, str, None] = None,
    repo_type: Optional[str] = None,
) -> DiscussionTitleChange:
    """Give a Discussion or Pull Request a new title.

    Args:
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        discussion_num (`int`):
            The number of the Discussion or Pull Request. Must be a strictly positive integer.
        new_title (`str`):
            The new title for the discussion.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if the discussion belongs to a dataset or
            space, `None` or `"model"` if it belongs to a model. Default is `None`.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        [`DiscussionTitleChange`]: the title change event

    Examples:
        ```python
        >>> new_title = "New title, fixing a typo"
        >>> HfApi().rename_discussion(
        ...     repo_id="username/repo_name",
        ...     discussion_num=34,
        ...     new_title=new_title,
        ... )
        # DiscussionTitleChange(id='deadbeef0000000', type='title-change', ...)
        ```

    Raises:
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            If the HuggingFace API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError):
            If some parameter value is invalid.
        [`~utils.RepositoryNotFoundError`]:
            If the repository cannot be found. This may be because it doesn't exist,
            or because it is set to `private` and you do not have access.
    """
    title_payload = {"title": new_title}
    response = self._post_discussion_changes(
        repo_id=repo_id,
        repo_type=repo_type,
        discussion_num=discussion_num,
        token=token,
        resource="title",
        body=title_payload,
    )
    # The server answers with the serialized title-change event under "newTitle".
    raw_event = response.json()["newTitle"]
    return deserialize_event(raw_event)  # type: ignore
@validate_hf_hub_args
def merge_pull_request(
    self,
    repo_id: str,
    discussion_num: int,
    *,
    token: Union[bool, str, None] = None,
    comment: Optional[str] = None,
    repo_type: Optional[str] = None,
):
    """Merge a Pull Request.

    Args:
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        discussion_num (`int`):
            The number of the Discussion or Pull Request. Must be a strictly positive integer.
        comment (`str`, *optional*):
            An optional comment to post with the merge. It is stripped of surrounding
            whitespace and omitted entirely when empty or whitespace-only.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if the Pull Request belongs to a dataset or
            space, `None` or `"model"` if it belongs to a model. Default is `None`.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        `None`: the server response is not returned to the caller.

    Raises:
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            If the HuggingFace API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError):
            If some parameter value is invalid.
        [`~utils.RepositoryNotFoundError`]:
            If the repository cannot be found. This may be because it doesn't exist,
            or because it is set to `private` and you do not have access.
    """
    # Only send a body when there is a non-blank comment to attach to the merge.
    stripped_comment = comment.strip() if comment else ""
    merge_payload = {"comment": stripped_comment} if stripped_comment else None
    self._post_discussion_changes(
        repo_id=repo_id,
        repo_type=repo_type,
        discussion_num=discussion_num,
        token=token,
        resource="merge",
        body=merge_payload,
    )
@validate_hf_hub_args
def hide_discussion_comment(
    self,
    repo_id: str,
    discussion_num: int,
    comment_id: str,
    *,
    token: Union[bool, str, None] = None,
    repo_type: Optional[str] = None,
) -> DiscussionComment:
    """Hide a comment on a Discussion / Pull Request.

    Warning: hidden comments' content cannot be retrieved anymore. Hiding a comment is
    irreversible. A `UserWarning` saying so is emitted on every call, before the request is sent.

    Args:
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        discussion_num (`int`):
            The number of the Discussion or Pull Request. Must be a strictly positive integer.
        comment_id (`str`):
            The ID of the comment to hide. It is lower-cased before being sent.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if the discussion belongs to a dataset or
            space, `None` or `"model"` if it belongs to a model. Default is `None`.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        [`DiscussionComment`]: the hidden comment

    Raises:
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            If the HuggingFace API returned an error.
        [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError):
            If some parameter value is invalid.
        [`~utils.RepositoryNotFoundError`]:
            If the repository cannot be found. This may be because it doesn't exist,
            or because it is set to `private` and you do not have access.
    """
    irreversible_notice = "Hidden comments' content cannot be retrieved anymore. Hiding a comment is irreversible."
    warnings.warn(irreversible_notice, UserWarning)

    hide_resource = f"comment/{comment_id.lower()}/hide"
    response = self._post_discussion_changes(
        repo_id=repo_id,
        repo_type=repo_type,
        discussion_num=discussion_num,
        token=token,
        resource=hide_resource,
    )
    # The updated (now hidden) comment is returned under "updatedComment".
    raw_comment = response.json()["updatedComment"]
    return deserialize_event(raw_comment)  # type: ignore
@validate_hf_hub_args
def delete_space_secret(self, repo_id: str, key: str, *, token: Union[bool, str, None] = None) -> None:
    """Remove a secret from a Space.

    Secrets allow to set secret keys or tokens to a Space without hardcoding them.
    For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets.

    The HTTP response is checked with `hf_raise_for_status`; nothing is returned.

    Args:
        repo_id (`str`):
            ID of the repo to update. Example: `"bigcode/in-the-stack"`.
        key (`str`):
            Secret key. Example: `"GITHUB_API_KEY"`.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.
    """
    session = get_session()
    secrets_url = f"{self.endpoint}/api/spaces/{repo_id}/secrets"
    request_headers = self._build_hf_headers(token=token)
    # The key to delete travels in the JSON body of the DELETE request.
    response = session.delete(secrets_url, headers=request_headers, json={"key": key})
    hf_raise_for_status(response)
@validate_hf_hub_args
def add_space_variable(
    self,
    repo_id: str,
    key: str,
    value: str,
    *,
    description: Optional[str] = None,
    token: Union[bool, str, None] = None,
) -> Dict[str, SpaceVariable]:
    """Add or update a variable in a Space.

    Variables allow to set environment variables to a Space without hardcoding them.
    For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables

    Args:
        repo_id (`str`):
            ID of the repo to update. Example: `"bigcode/in-the-stack"`.
        key (`str`):
            Variable key. Example: `"MODEL_REPO_ID"`
        value (`str`):
            Variable value. Example: `"the_model_repo_id"`.
        description (`str`, *optional*):
            Description of the variable. Example: `"Model Repo ID of the implemented model"`.
            Only sent to the server when it is not `None`.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        `Dict[str, SpaceVariable]`: one [`SpaceVariable`] per entry of the JSON response,
        keyed by the entry's key.
    """
    request_body = {"key": key, "value": value}
    if description is not None:
        request_body["description"] = description

    response = get_session().post(
        f"{self.endpoint}/api/spaces/{repo_id}/variables",
        headers=self._build_hf_headers(token=token),
        json=request_body,
    )
    hf_raise_for_status(response)

    # Wrap every raw entry of the response into a SpaceVariable.
    variables = {}
    for name, raw_variable in response.json().items():
        variables[name] = SpaceVariable(name, raw_variable)
    return variables
@validate_hf_hub_args
def get_space_runtime(self, repo_id: str, *, token: Union[bool, str, None] = None) -> SpaceRuntime:
    """Fetch runtime information about a Space.

    Args:
        repo_id (`str`):
            ID of the repo to query. Example: `"bigcode/in-the-stack"`.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware.
    """
    session = get_session()
    runtime_url = f"{self.endpoint}/api/spaces/{repo_id}/runtime"
    response = session.get(runtime_url, headers=self._build_hf_headers(token=token))
    hf_raise_for_status(response)
    runtime_info = response.json()
    return SpaceRuntime(runtime_info)
@validate_hf_hub_args
def set_space_sleep_time(
    self, repo_id: str, sleep_time: int, *, token: Union[bool, str, None] = None
) -> SpaceRuntime:
    """Set a custom sleep time for a Space running on upgraded hardware.

    Your Space will go to sleep after X seconds of inactivity. You are not billed when your Space is in "sleep"
    mode. If a new visitor lands on your Space, it will "wake it up". Only upgraded hardware can have a
    configurable sleep time. To know more about the sleep stage, please refer to
    https://huggingface.co/docs/hub/spaces-gpus#sleep-time.

    It is also possible to set a custom sleep time when requesting hardware with [`request_space_hardware`].

    Args:
        repo_id (`str`):
            ID of the repo to update. Example: `"bigcode/in-the-stack"`.
        sleep_time (`int`):
            Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want
            your Space to pause (default behavior for upgraded hardware). For free hardware, you can't configure
            the sleep time (value is fixed to 48 hours of inactivity).
            See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware.
    """
    response = get_session().post(
        f"{self.endpoint}/api/spaces/{repo_id}/sleeptime",
        headers=self._build_hf_headers(token=token),
        json={"seconds": sleep_time},
    )
    hf_raise_for_status(response)
    runtime = SpaceRuntime(response.json())

    # The requested hardware takes precedence over the current one when deciding whether to warn.
    effective_hardware = runtime.requested_hardware or runtime.hardware
    if effective_hardware != SpaceHardware.CPU_BASIC:
        return runtime

    # The request went through, but the reported hardware is the free tier: tell the user
    # that the sleep time is not configurable there.
    warnings.warn(
        "If your Space runs on the default 'cpu-basic' hardware, it will go to sleep if inactive for more"
        " than 48 hours. This value is not configurable. If you don't want your Space to deactivate or if"
        " you want to set a custom sleep time, you need to upgrade to a paid Hardware.",
        UserWarning,
    )
    return runtime
@validate_hf_hub_args
def restart_space(
    self, repo_id: str, *, token: Union[bool, str, None] = None, factory_reboot: bool = False
) -> SpaceRuntime:
    """Restart your Space.

    This is the only way to programmatically restart a Space if you've put it on Pause (see [`pause_space`]). You
    must be the owner of the Space to restart it. If you are using an upgraded hardware, your account will be
    billed as soon as the Space is restarted. You can trigger a restart no matter the current state of a Space.

    For more details, please visit [the docs](https://huggingface.co/docs/hub/spaces-gpus#pause).

    Args:
        repo_id (`str`):
            ID of the Space to restart. Example: `"Salesforce/BLIP2"`.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.
        factory_reboot (`bool`, *optional*):
            If `True`, the Space will be rebuilt from scratch without caching any requirements.
            Sent to the server as the query parameter `factory=true`.

    Returns:
        [`SpaceRuntime`]: Runtime information about your Space.

    Raises:
        [`~utils.RepositoryNotFoundError`]:
            If your Space is not found (error 404). Most probably wrong repo_id or your space is private but you
            are not authenticated.
        [`~utils.HfHubHTTPError`]:
            403 Forbidden: only the owner of a Space can restart it. If you want to restart a Space that you don't
            own, either ask the owner by opening a Discussion or duplicate the Space.
        [`~utils.BadRequestError`]:
            If your Space is a static Space. Static Spaces are always running and never billed. If you want to hide
            a static Space, you can set it to private.
    """
    # No query parameters at all for a regular restart.
    query_params = {"factory": "true"} if factory_reboot else {}
    response = get_session().post(
        f"{self.endpoint}/api/spaces/{repo_id}/restart",
        headers=self._build_hf_headers(token=token),
        params=query_params,
    )
    hf_raise_for_status(response)
    return SpaceRuntime(response.json())
+ storage (`SpaceStorage` or `str`, *optional*): + Choice of persistent storage tier. Example: `"small"`. See [`SpaceStorage`] for a complete list. + sleep_time (`int`, *optional*): + Number of seconds of inactivity to wait before a Space is put to sleep. Set to `-1` if you don't want + your Space to sleep (default behavior for upgraded hardware). For free hardware, you can't configure + the sleep time (value is fixed to 48 hours of inactivity). + See https://huggingface.co/docs/hub/spaces-gpus#sleep-time for more details. + secrets (`List[Dict[str, str]]`, *optional*): + A list of secret keys to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets. + variables (`List[Dict[str, str]]`, *optional*): + A list of public environment variables to set in your Space. Each item is in the form `{"key": ..., "value": ..., "description": ...}` where description is optional. + For more details, see https://huggingface.co/docs/hub/spaces-overview#managing-secrets-and-environment-variables. + + Returns: + [`RepoUrl`]: URL to the newly created repo. Value is a subclass of `str` containing + attributes like `endpoint`, `repo_type` and `repo_id`. + + Raises: + [`~utils.RepositoryNotFoundError`]: + If one of `from_id` or `to_id` cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + If the HuggingFace API returned an error + + Example: + ```python + >>> from huggingface_hub import duplicate_space + + # Duplicate a Space to your account + >>> duplicate_space("multimodalart/dreambooth-training") + RepoUrl('https://huggingface.co/spaces/nateraw/dreambooth-training',...) + + # Can set custom destination id and visibility flag. 
@validate_hf_hub_args
def request_space_storage(
    self,
    repo_id: str,
    storage: SpaceStorage,
    *,
    token: Union[bool, str, None] = None,
) -> SpaceRuntime:
    """Request persistent storage for a Space.

    It is not possible to decrease persistent storage after it is granted. To do so, you must delete it
    via [`delete_space_storage`].

    Args:
        repo_id (`str`):
            ID of the Space to update. Example: `"open-llm-leaderboard/open_llm_leaderboard"`.
        storage (`str` or [`SpaceStorage`]):
            Storage tier. Either 'small', 'medium', or 'large'.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        [`SpaceRuntime`]: Runtime information about a Space including Space stage and hardware.
    """
    session = get_session()
    storage_url = f"{self.endpoint}/api/spaces/{repo_id}/storage"
    tier_request: Dict[str, SpaceStorage] = {"tier": storage}
    response = session.post(
        storage_url,
        headers=self._build_hf_headers(token=token),
        json=tier_request,
    )
    hf_raise_for_status(response)
    return SpaceRuntime(response.json())
+ + """ + r = get_session().delete( + f"{self.endpoint}/api/spaces/{repo_id}/storage", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(r) + return SpaceRuntime(r.json()) + + ####################### + # Inference Endpoints # + ####################### + + def list_inference_endpoints( + self, namespace: Optional[str] = None, *, token: Union[bool, str, None] = None + ) -> List[InferenceEndpoint]: + """Lists all inference endpoints for the given namespace. + + Args: + namespace (`str`, *optional*): + The namespace to list endpoints for. Defaults to the current user. Set to `"*"` to list all endpoints + from all namespaces (i.e. personal namespace and all orgs the user belongs to). + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + List[`InferenceEndpoint`]: A list of all inference endpoints for the given namespace. + + Example: + ```python + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> api.list_inference_endpoints() + [InferenceEndpoint(name='my-endpoint', ...), ...] 
+ ``` + """ + # Special case: list all endpoints for all namespaces the user has access to + if namespace == "*": + user = self.whoami(token=token) + + # List personal endpoints first + endpoints: List[InferenceEndpoint] = list_inference_endpoints(namespace=self._get_namespace(token=token)) + + # Then list endpoints for all orgs the user belongs to and ignore 401 errors (no billing or no access) + for org in user.get("orgs", []): + try: + endpoints += list_inference_endpoints(namespace=org["name"], token=token) + except HfHubHTTPError as error: + if error.response.status_code == 401: # Either no billing or user don't have access) + logger.debug("Cannot list Inference Endpoints for org '%s': %s", org["name"], error) + pass + + return endpoints + + # Normal case: list endpoints for a specific namespace + namespace = namespace or self._get_namespace(token=token) + + response = get_session().get( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return [ + InferenceEndpoint.from_raw(endpoint, namespace=namespace, token=token) + for endpoint in response.json()["items"] + ] + + def create_inference_endpoint( + self, + name: str, + *, + repository: str, + framework: str, + accelerator: str, + instance_size: str, + instance_type: str, + region: str, + vendor: str, + account_id: Optional[str] = None, + min_replica: int = 0, + max_replica: int = 1, + scale_to_zero_timeout: int = 15, + revision: Optional[str] = None, + task: Optional[str] = None, + custom_image: Optional[Dict] = None, + secrets: Optional[Dict[str, str]] = None, + type: InferenceEndpointType = InferenceEndpointType.PROTECTED, + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> InferenceEndpoint: + """Create a new Inference Endpoint. + + Args: + name (`str`): + The unique name for the new Inference Endpoint. 
+ repository (`str`): + The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`). + framework (`str`): + The machine learning framework used for the model (e.g. `"custom"`). + accelerator (`str`): + The hardware accelerator to be used for inference (e.g. `"cpu"`). + instance_size (`str`): + The size or type of the instance to be used for hosting the model (e.g. `"x4"`). + instance_type (`str`): + The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`). + region (`str`): + The cloud region in which the Inference Endpoint will be created (e.g. `"us-east-1"`). + vendor (`str`): + The cloud provider or vendor where the Inference Endpoint will be hosted (e.g. `"aws"`). + account_id (`str`, *optional*): + The account ID used to link a VPC to a private Inference Endpoint (if applicable). + min_replica (`int`, *optional*): + The minimum number of replicas (instances) to keep running for the Inference Endpoint. Defaults to 0. + max_replica (`int`, *optional*): + The maximum number of replicas (instances) to scale to for the Inference Endpoint. Defaults to 1. + scale_to_zero_timeout (`int`, *optional*): + The duration in minutes before an inactive endpoint is scaled to zero. Defaults to 15. + revision (`str`, *optional*): + The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`). + task (`str`, *optional*): + The task on which to deploy the model (e.g. `"text-classification"`). + custom_image (`Dict`, *optional*): + A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an + Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples). + secrets (`Dict[str, str]`, *optional*): + Secret values to inject in the container environment. + type ([`InferenceEndpointType`], *optional*): + The type of the Inference Endpoint, which can be `"protected"` (default), `"public"` or `"private"`. 
+ namespace (`str`, *optional*): + The namespace where the Inference Endpoint will be created. Defaults to the current user's namespace. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the created Inference Endpoint. + + Example: + ```python + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> endpoint = api.create_inference_endpoint( + ... "my-endpoint-name", + ... repository="gpt2", + ... framework="pytorch", + ... task="text-generation", + ... accelerator="cpu", + ... vendor="aws", + ... region="us-east-1", + ... type="protected", + ... instance_size="x2", + ... instance_type="intel-icl", + ... ) + >>> endpoint + InferenceEndpoint(name='my-endpoint-name', status="pending",...) + + # Run inference on the endpoint + >>> endpoint.client.text_generation(...) + "..." + ``` + + ```python + # Start an Inference Endpoint running Zephyr-7b-beta on TGI + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> endpoint = api.create_inference_endpoint( + ... "aws-zephyr-7b-beta-0486", + ... repository="HuggingFaceH4/zephyr-7b-beta", + ... framework="pytorch", + ... task="text-generation", + ... accelerator="gpu", + ... vendor="aws", + ... region="us-east-1", + ... type="protected", + ... instance_size="x1", + ... instance_type="nvidia-a10g", + ... custom_image={ + ... "health_route": "/health", + ... "env": { + ... "MAX_BATCH_PREFILL_TOKENS": "2048", + ... "MAX_INPUT_LENGTH": "1024", + ... "MAX_TOTAL_TOKENS": "1512", + ... "MODEL_ID": "/repository" + ... }, + ... "url": "ghcr.io/huggingface/text-generation-inference:1.1.0", + ... }, + ... secrets={"MY_SECRET_KEY": "secret_value"}, + ... 
) + + ``` + """ + namespace = namespace or self._get_namespace(token=token) + + image = {"custom": custom_image} if custom_image is not None else {"huggingface": {}} + payload: Dict = { + "accountId": account_id, + "compute": { + "accelerator": accelerator, + "instanceSize": instance_size, + "instanceType": instance_type, + "scaling": { + "maxReplica": max_replica, + "minReplica": min_replica, + "scaleToZeroTimeout": scale_to_zero_timeout, + }, + }, + "model": { + "framework": framework, + "repository": repository, + "revision": revision, + "task": task, + "image": image, + }, + "name": name, + "provider": { + "region": region, + "vendor": vendor, + }, + "type": type, + } + if secrets: + payload["model"]["secrets"] = secrets + response = get_session().post( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def get_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None + ) -> InferenceEndpoint: + """Get information about an Inference Endpoint. + + Args: + name (`str`): + The name of the Inference Endpoint to retrieve information about. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the requested Inference Endpoint. 
+ + Example: + ```python + >>> from huggingface_hub import HfApi + >>> api = HfApi() + >>> endpoint = api.get_inference_endpoint("my-text-to-image") + >>> endpoint + InferenceEndpoint(name='my-text-to-image', ...) + + # Get status + >>> endpoint.status + 'running' + >>> endpoint.url + 'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud' + + # Run inference + >>> endpoint.client.text_to_image(...) + ``` + """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().get( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def update_inference_endpoint( + self, + name: str, + *, + # Compute update + accelerator: Optional[str] = None, + instance_size: Optional[str] = None, + instance_type: Optional[str] = None, + min_replica: Optional[int] = None, + max_replica: Optional[int] = None, + scale_to_zero_timeout: Optional[int] = None, + # Model update + repository: Optional[str] = None, + framework: Optional[str] = None, + revision: Optional[str] = None, + task: Optional[str] = None, + custom_image: Optional[Dict] = None, + secrets: Optional[Dict[str, str]] = None, + # Other + namespace: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> InferenceEndpoint: + """Update an Inference Endpoint. + + This method allows the update of either the compute configuration, the deployed model, or both. All arguments are + optional but at least one must be provided. + + For convenience, you can also update an Inference Endpoint using [`InferenceEndpoint.update`]. + + Args: + name (`str`): + The name of the Inference Endpoint to update. + + accelerator (`str`, *optional*): + The hardware accelerator to be used for inference (e.g. `"cpu"`). 
+ instance_size (`str`, *optional*): + The size or type of the instance to be used for hosting the model (e.g. `"x4"`). + instance_type (`str`, *optional*): + The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`). + min_replica (`int`, *optional*): + The minimum number of replicas (instances) to keep running for the Inference Endpoint. + max_replica (`int`, *optional*): + The maximum number of replicas (instances) to scale to for the Inference Endpoint. + scale_to_zero_timeout (`int`, *optional*): + The duration in minutes before an inactive endpoint is scaled to zero. + + repository (`str`, *optional*): + The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`). + framework (`str`, *optional*): + The machine learning framework used for the model (e.g. `"custom"`). + revision (`str`, *optional*): + The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`). + task (`str`, *optional*): + The task on which to deploy the model (e.g. `"text-classification"`). + custom_image (`Dict`, *optional*): + A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an + Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples). + secrets (`Dict[str, str]`, *optional*): + Secret values to inject in the container environment. + namespace (`str`, *optional*): + The namespace where the Inference Endpoint will be updated. Defaults to the current user's namespace. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the updated Inference Endpoint. 
+ """ + namespace = namespace or self._get_namespace(token=token) + + # Populate only the fields that are not None + payload: Dict = defaultdict(lambda: defaultdict(dict)) + if accelerator is not None: + payload["compute"]["accelerator"] = accelerator + if instance_size is not None: + payload["compute"]["instanceSize"] = instance_size + if instance_type is not None: + payload["compute"]["instanceType"] = instance_type + if max_replica is not None: + payload["compute"]["scaling"]["maxReplica"] = max_replica + if min_replica is not None: + payload["compute"]["scaling"]["minReplica"] = min_replica + if scale_to_zero_timeout is not None: + payload["compute"]["scaling"]["scaleToZeroTimeout"] = scale_to_zero_timeout + if repository is not None: + payload["model"]["repository"] = repository + if framework is not None: + payload["model"]["framework"] = framework + if revision is not None: + payload["model"]["revision"] = revision + if task is not None: + payload["model"]["task"] = task + if custom_image is not None: + payload["model"]["image"] = {"custom": custom_image} + if secrets is not None: + payload["model"]["secrets"] = secrets + + response = get_session().put( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + headers=self._build_hf_headers(token=token), + json=payload, + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def delete_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None + ) -> None: + """Delete an Inference Endpoint. + + This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable + to pause it with [`pause_inference_endpoint`] or scale it to zero with [`scale_to_zero_inference_endpoint`]. + + For convenience, you can also delete an Inference Endpoint using [`InferenceEndpoint.delete`]. 
+ + Args: + name (`str`): + The name of the Inference Endpoint to delete. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + """ + namespace = namespace or self._get_namespace(token=token) + response = get_session().delete( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + def pause_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None + ) -> InferenceEndpoint: + """Pause an Inference Endpoint. + + A paused Inference Endpoint will not be charged. It can be resumed at any time using [`resume_inference_endpoint`]. + This is different than scaling the Inference Endpoint to zero with [`scale_to_zero_inference_endpoint`], which + would be automatically restarted when a request is made to it. + + For convenience, you can also pause an Inference Endpoint using [`InferenceEndpoint.pause`]. + + Args: + name (`str`): + The name of the Inference Endpoint to pause. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the paused Inference Endpoint. 
+ """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().post( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/pause", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def resume_inference_endpoint( + self, + name: str, + *, + namespace: Optional[str] = None, + running_ok: bool = True, + token: Union[bool, str, None] = None, + ) -> InferenceEndpoint: + """Resume an Inference Endpoint. + + For convenience, you can also resume an Inference Endpoint using [`InferenceEndpoint.resume`]. + + Args: + name (`str`): + The name of the Inference Endpoint to resume. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + running_ok (`bool`, *optional*): + If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to + `True`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the resumed Inference Endpoint. 
+ """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().post( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/resume", + headers=self._build_hf_headers(token=token), + ) + try: + hf_raise_for_status(response) + except HfHubHTTPError as error: + # If already running (and it's ok), then fetch current status and return + if running_ok and error.response.status_code == 400 and "already running" in error.response.text: + return self.get_inference_endpoint(name, namespace=namespace, token=token) + # Otherwise, raise the error + raise + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def scale_to_zero_inference_endpoint( + self, name: str, *, namespace: Optional[str] = None, token: Union[bool, str, None] = None + ) -> InferenceEndpoint: + """Scale Inference Endpoint to zero. + + An Inference Endpoint scaled to zero will not be charged. It will be resumed on the next request to it, with a + cold start delay. This is different than pausing the Inference Endpoint with [`pause_inference_endpoint`], which + would require a manual resume with [`resume_inference_endpoint`]. + + For convenience, you can also scale an Inference Endpoint to zero using [`InferenceEndpoint.scale_to_zero`]. + + Args: + name (`str`): + The name of the Inference Endpoint to scale to zero. + namespace (`str`, *optional*): + The namespace in which the Inference Endpoint is located. Defaults to the current user. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + [`InferenceEndpoint`]: information about the scaled-to-zero Inference Endpoint. 
+ """ + namespace = namespace or self._get_namespace(token=token) + + response = get_session().post( + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/scale-to-zero", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + return InferenceEndpoint.from_raw(response.json(), namespace=namespace, token=token) + + def _get_namespace(self, token: Union[bool, str, None] = None) -> str: + """Get the default namespace for the current user.""" + me = self.whoami(token=token) + if me["type"] == "user": + return me["name"] + else: + raise ValueError( + "Cannot determine default namespace. You must provide a 'namespace' as input or be logged in as a" + " user." + ) + + ######################## + # Collection Endpoints # + ######################## + @validate_hf_hub_args + def list_collections( + self, + *, + owner: Union[List[str], str, None] = None, + item: Union[List[str], str, None] = None, + sort: Optional[Literal["lastModified", "trending", "upvotes"]] = None, + limit: Optional[int] = None, + token: Union[bool, str, None] = None, + ) -> Iterable[Collection]: + """List collections on the Hugging Face Hub, given some filters. + ++ + When listing collections, the item list per collection is truncated to 4 items maximum. To retrieve all items + from a collection, you must use [`get_collection`]. + + + + Args: + owner (`List[str]` or `str`, *optional*): + Filter by owner's username. + item (`List[str]` or `str`, *optional*): + Filter collections containing a particular item. Example: `"models/teknium/OpenHermes-2.5-Mistral-7B"`, `"datasets/squad"` or `"papers/2311.12983"`. + sort (`Literal["lastModified", "trending", "upvotes"]`, *optional*): + Sort collections by last modified, trending or upvotes. + limit (`int`, *optional*): + Maximum number of collections to be returned. + token (Union[bool, str, None], optional): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[Collection]`: an iterable of [`Collection`] objects. + """ + # Construct the API endpoint + path = f"{self.endpoint}/api/collections" + headers = self._build_hf_headers(token=token) + params: Dict = {} + if owner is not None: + params.update({"owner": owner}) + if item is not None: + params.update({"item": item}) + if sort is not None: + params.update({"sort": sort}) + if limit is not None: + params.update({"limit": limit}) + + # Paginate over the results until limit is reached + items = paginate(path, headers=headers, params=params) + if limit is not None: + items = islice(items, limit) # Do not iterate over all pages + + # Parse as Collection and return + for position, collection_data in enumerate(items): + yield Collection(position=position, **collection_data) + + def get_collection(self, collection_slug: str, *, token: Union[bool, str, None] = None) -> Collection: + """Gets information about a Collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection of the Hub. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. 
+ + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import get_collection + >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") + >>> collection.title + 'Recent models' + >>> len(collection.items) + 37 + >>> collection.items[0] + CollectionItem( + item_object_id='651446103cd773a050bf64c2', + item_id='TheBloke/U-Amethyst-20B-AWQ', + item_type='model', + position=88, + note=None + ) + ``` + """ + r = get_session().get( + f"{self.endpoint}/api/collections/{collection_slug}", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(r) + return Collection(**{**r.json(), "endpoint": self.endpoint}) + + def create_collection( + self, + title: str, + *, + namespace: Optional[str] = None, + description: Optional[str] = None, + private: bool = False, + exists_ok: bool = False, + token: Union[bool, str, None] = None, + ) -> Collection: + """Create a new Collection on the Hub. + + Args: + title (`str`): + Title of the collection to create. Example: `"Recent models"`. + namespace (`str`, *optional*): + Namespace of the collection to create (username or org). Will default to the owner name. + description (`str`, *optional*): + Description of the collection to create. + private (`bool`, *optional*): + Whether the collection should be private or not. Defaults to `False` (i.e. public collection). + exists_ok (`bool`, *optional*): + If `True`, do not raise an error if collection already exists. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import create_collection + >>> collection = create_collection( + ... title="ICCV 2023", + ... 
description="Portfolio of models, papers and demos I presented at ICCV 2023", + ... ) + >>> collection.slug + "username/iccv-2023-64f9a55bb3115b4f513ec026" + ``` + """ + if namespace is None: + namespace = self.whoami(token)["name"] + + payload = { + "title": title, + "namespace": namespace, + "private": private, + } + if description is not None: + payload["description"] = description + + r = get_session().post( + f"{self.endpoint}/api/collections", headers=self._build_hf_headers(token=token), json=payload + ) + try: + hf_raise_for_status(r) + except HTTPError as err: + if exists_ok and err.response.status_code == 409: + # Collection already exists and `exists_ok=True` + slug = r.json()["slug"] + return self.get_collection(slug, token=token) + else: + raise + return Collection(**{**r.json(), "endpoint": self.endpoint}) + + def update_collection_metadata( + self, + collection_slug: str, + *, + title: Optional[str] = None, + description: Optional[str] = None, + position: Optional[int] = None, + private: Optional[bool] = None, + theme: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> Collection: + """Update metadata of a collection on the Hub. + + All arguments are optional. Only provided metadata will be updated. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + title (`str`): + Title of the collection to update. + description (`str`, *optional*): + Description of the collection to update. + position (`int`, *optional*): + New position of the collection in the list of collections of the user. + private (`bool`, *optional*): + Whether the collection should be private or not. + theme (`str`, *optional*): + Theme of the collection on the Hub. + token (Union[bool, str, None], optional): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import update_collection_metadata + >>> collection = update_collection_metadata( + ... collection_slug="username/iccv-2023-64f9a55bb3115b4f513ec026", + ... title="ICCV Oct. 2023", + ... description="Portfolio of models, datasets, papers and demos I presented at ICCV Oct. 2023", + ... private=False, + ... theme="pink", + ... ) + >>> collection.slug + "username/iccv-oct-2023-64f9a55bb3115b4f513ec026" + # ^collection slug got updated but not the trailing ID + ``` + """ + payload = { + "position": position, + "private": private, + "theme": theme, + "title": title, + "description": description, + } + r = get_session().patch( + f"{self.endpoint}/api/collections/{collection_slug}", + headers=self._build_hf_headers(token=token), + # Only send not-none values to the API + json={key: value for key, value in payload.items() if value is not None}, + ) + hf_raise_for_status(r) + return Collection(**{**r.json()["data"], "endpoint": self.endpoint}) + + def delete_collection( + self, collection_slug: str, *, missing_ok: bool = False, token: Union[bool, str, None] = None + ) -> None: + """Delete a collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection to delete. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + missing_ok (`bool`, *optional*): + If `True`, do not raise an error if collection doesn't exist. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. 
+ + Example: + + ```py + >>> from huggingface_hub import delete_collection + >>> collection = delete_collection("username/useless-collection-64f9a55bb3115b4f513ec026", missing_ok=True) + ``` + ++ + This is a non-revertible action. A deleted collection cannot be restored. + + + """ + r = get_session().delete( + f"{self.endpoint}/api/collections/{collection_slug}", headers=self._build_hf_headers(token=token) + ) + try: + hf_raise_for_status(r) + except HTTPError as err: + if missing_ok and err.response.status_code == 404: + # Collection doesn't exists and `missing_ok=True` + return + else: + raise + + def add_collection_item( + self, + collection_slug: str, + item_id: str, + item_type: CollectionItemType_T, + *, + note: Optional[str] = None, + exists_ok: bool = False, + token: Union[bool, str, None] = None, + ) -> Collection: + """Add an item to a collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + item_id (`str`): + ID of the item to add to the collection. It can be the ID of a repo on the Hub (e.g. `"facebook/bart-large-mnli"`) + or a paper id (e.g. `"2307.09288"`). + item_type (`str`): + Type of the item to add. Can be one of `"model"`, `"dataset"`, `"space"` or `"paper"`. + note (`str`, *optional*): + A note to attach to the item in the collection. The maximum size for a note is 500 characters. + exists_ok (`bool`, *optional*): + If `True`, do not raise an error if item already exists. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: [`Collection`] + + Raises: + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 403 if you only have read-only access to the repo. 
This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 404 if the item you try to add to the collection does not exist on the Hub. + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 409 if the item you try to add to the collection is already in the collection (and exists_ok=False) + + Example: + + ```py + >>> from huggingface_hub import add_collection_item + >>> collection = add_collection_item( + ... collection_slug="davanstrien/climate-64f99dc2a5067f6b65531bab", + ... item_id="pierre-loic/climate-news-articles", + ... item_type="dataset" + ... ) + >>> collection.items[-1].item_id + "pierre-loic/climate-news-articles" + # ^item got added to the collection on last position + + # Add item with a note + >>> add_collection_item( + ... collection_slug="davanstrien/climate-64f99dc2a5067f6b65531bab", + ... item_id="datasets/climate_fever", + ... item_type="dataset", + ... note="This dataset adopts the FEVER methodology that consists of 1,535 real-world claims regarding climate-change collected on the internet." + ... ) + (...) 
+ ``` + """ + payload: Dict[str, Any] = {"item": {"id": item_id, "type": item_type}} + if note is not None: + payload["note"] = note + r = get_session().post( + f"{self.endpoint}/api/collections/{collection_slug}/items", + headers=self._build_hf_headers(token=token), + json=payload, + ) + try: + hf_raise_for_status(r) + except HTTPError as err: + if exists_ok and err.response.status_code == 409: + # Item already exists and `exists_ok=True` + return self.get_collection(collection_slug, token=token) + else: + raise + return Collection(**{**r.json(), "endpoint": self.endpoint}) + + def update_collection_item( + self, + collection_slug: str, + item_object_id: str, + *, + note: Optional[str] = None, + position: Optional[int] = None, + token: Union[bool, str, None] = None, + ) -> None: + """Update an item in a collection. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + item_object_id (`str`): + ID of the item in the collection. This is not the id of the item on the Hub (repo_id or paper id). + It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0].item_object_id`. + note (`str`, *optional*): + A note to attach to the item in the collection. The maximum size for a note is 500 characters. + position (`int`, *optional*): + New position of the item in the collection. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Example: + + ```py + >>> from huggingface_hub import get_collection, update_collection_item + + # Get collection first + >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") + + # Update item based on its ID (add note + update position) + >>> update_collection_item( + ... 
def delete_collection_item(
    self,
    collection_slug: str,
    item_object_id: str,
    *,
    missing_ok: bool = False,
    token: Union[bool, str, None] = None,
) -> None:
    """Delete an item from a collection.

    Args:
        collection_slug (`str`):
            Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`.
        item_object_id (`str`):
            ID of the item in the collection. This is not the id of the item on the Hub (repo_id or paper id).
            It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0].item_object_id`.
        missing_ok (`bool`, *optional*):
            If `True`, do not raise an error if the item doesn't exist (HTTP 404).
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Example:

    ```py
    >>> from huggingface_hub import get_collection, delete_collection_item

    # Get collection first
    >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026")

    # Delete item based on its ID
    >>> delete_collection_item(
    ...     collection_slug="TheBloke/recent-models-64f9a55bb3115b4f513ec026",
    ...     item_object_id=collection.items[-1].item_object_id,
    ... )
    ```
    """
    response = get_session().delete(
        f"{self.endpoint}/api/collections/{collection_slug}/items/{item_object_id}",
        headers=self._build_hf_headers(token=token),
    )
    try:
        hf_raise_for_status(response)
    except HTTPError as error:
        # Re-raise everything except a "not found" that the caller asked us to tolerate.
        if not missing_ok or error.response.status_code != 404:
            raise
        # Item already deleted and `missing_ok=True`: nothing left to do.
@validate_hf_hub_args
def list_accepted_access_requests(
    self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
) -> List[AccessRequest]:
    """
    Get accepted access requests for a given gated repo.

    An accepted request means the user has requested access to the repo and the request has been accepted. The user
    can download any file of the repo. If the approval mode is automatic, this list should contain all requests by
    default. Accepted requests can be cancelled or rejected at any time using [`cancel_access_request`] and
    [`reject_access_request`]. A cancelled request will go back to the pending list while a rejected request will
    go to the rejected list. In both cases, the user will lose access to the repo.

    For more info about gated repos, see https://huggingface.co/docs/hub/models-gated.

    Args:
        repo_id (`str`):
            The id of the repo to get access requests for.
        repo_type (`str`, *optional*):
            The type of the repo to get access requests for. Must be one of `model`, `dataset` or `space`.
            Defaults to `model`.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        `List[AccessRequest]`: A list of [`AccessRequest`] objects. Each item contains a `username`, `fullname`,
        `email`, `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute
        will be populated with the user's answers.

    Raises:
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            HTTP 400 if the repo is not gated.
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write`
            or `admin` role in the organization the repo belongs to or if you passed a `read` token.

    Example:
    ```py
    >>> from huggingface_hub import list_accepted_access_requests

    >>> requests = list_accepted_access_requests("meta-llama/Llama-2-7b")
    >>> len(requests)
    411
    >>> requests
    [
        AccessRequest(
            username='clem',
            fullname='Clem 🤗',
            email='***',
            timestamp=datetime.datetime(2023, 11, 23, 18, 4, 53, 828000, tzinfo=datetime.timezone.utc),
            status='accepted',
            fields=None,
        ),
        ...
    ]
    ```
    """
    # Thin wrapper: delegate to the shared listing helper with the "accepted" status.
    return self._list_access_requests(repo_id, "accepted", repo_type=repo_type, token=token)
The user cannot download any file of the + repo. Rejected requests can be accepted or cancelled at any time using [`accept_access_request`] and + [`cancel_access_request`]. A cancelled request will go back to the pending list while an accepted request will + go to the accepted list. + + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to get access requests for. + repo_type (`str`, *optional*): + The type of the repo to get access requests for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `List[AccessRequest]`: A list of [`AccessRequest`] objects. Each time contains a `username`, `email`, + `status` and `timestamp` attribute. If the gated repo has a custom form, the `fields` attribute will + be populated with user's answers. + + Raises: + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 400 if the repo is not gated. + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + + Example: + ```py + >>> from huggingface_hub import list_rejected_access_requests + + >>> requests = list_rejected_access_requests("meta-llama/Llama-2-7b") + >>> len(requests) + 411 + >>> requests[0] + [ + AccessRequest( + username='clem', + fullname='Clem π€', + email='***', + timestamp=datetime.datetime(2023, 11, 23, 18, 4, 53, 828000, tzinfo=datetime.timezone.utc), + status='rejected', + fields=None, + ), + ... 
def _list_access_requests(
    self,
    repo_id: str,
    status: Literal["accepted", "rejected", "pending"],
    repo_type: Optional[str] = None,
    token: Union[bool, str, None] = None,
) -> List[AccessRequest]:
    """
    Fetch the access requests of a gated repo that are in the given `status`.

    Shared implementation behind the public `list_*_access_requests` methods.

    Args:
        repo_id (`str`):
            The id of the repo to get access requests for.
        status (`Literal["accepted", "rejected", "pending"]`):
            Which list of requests to retrieve; used as the last segment of the request URL.
        repo_type (`str`, *optional*):
            Must be a member of `constants.REPO_TYPES`. `None` is replaced by `constants.REPO_TYPE_MODEL`.
        token (Union[bool, str, None], optional):
            Forwarded to `_build_hf_headers` to build the request headers.

    Returns:
        `List[AccessRequest]`: One [`AccessRequest`] per entry of the JSON response.

    Raises:
        `ValueError`: If `repo_type` is not in `constants.REPO_TYPES`.
    """
    if repo_type not in constants.REPO_TYPES:
        raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}")
    resolved_type = constants.REPO_TYPE_MODEL if repo_type is None else repo_type

    response = get_session().get(
        f"{constants.ENDPOINT}/api/{resolved_type}s/{repo_id}/user-access-request/{status}",
        headers=self._build_hf_headers(token=token),
    )
    hf_raise_for_status(response)

    access_requests: List[AccessRequest] = []
    for entry in response.json():
        access_requests.append(
            AccessRequest(
                username=entry["user"]["user"],
                fullname=entry["user"]["fullname"],
                email=entry["user"].get("email"),
                status=entry["status"],
                timestamp=parse_datetime(entry["timestamp"]),
                # Only present when the gated repo defines custom form fields.
                fields=entry.get("fields"),
            )
        )
    return access_requests
@validate_hf_hub_args
def accept_access_request(
    self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
) -> None:
    """
    Accept an access request from a user for a given gated repo.

    Once the request is accepted, the user will be able to download any file of the repo and access the community
    tab. If the approval mode is automatic, you don't have to accept requests manually. An accepted request can be
    cancelled or rejected at any time using [`cancel_access_request`] and [`reject_access_request`].

    For more info about gated repos, see https://huggingface.co/docs/hub/models-gated.

    Args:
        repo_id (`str`):
            The id of the repo to accept the access request for.
        user (`str`):
            The username of the user whose access request should be accepted.
        repo_type (`str`, *optional*):
            The type of the repo to accept the access request for. Must be one of `model`, `dataset` or `space`.
            Defaults to `model`.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Raises:
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            HTTP 400 if the repo is not gated.
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write`
            or `admin` role in the organization the repo belongs to or if you passed a `read` token.
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            HTTP 404 if the user does not exist on the Hub.
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            HTTP 404 if the user access request cannot be found.
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            HTTP 404 if the user access request is already in the accepted list.
    """
    # Thin wrapper: the shared handler posts the new status ("accepted") for this user.
    self._handle_access_request(repo_id, user, "accepted", repo_type=repo_type, token=token)
+ + For more info about gated repos, see https://huggingface.co/docs/hub/models-gated. + + Args: + repo_id (`str`): + The id of the repo to reject access request for. + user (`str`): + The username of the user which access request should be rejected. + repo_type (`str`, *optional*): + The type of the repo to reject access request for. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Raises: + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 400 if the repo is not gated. + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write` + or `admin` role in the organization the repo belongs to or if you passed a `read` token. + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 404 if the user does not exist on the Hub. + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 404 if the user access request cannot be found. + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 404 if the user access request is already in the rejected list. 
@validate_hf_hub_args
def grant_access(
    self, repo_id: str, user: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
) -> None:
    """
    Grant access to a user for a given gated repo.

    Granting access doesn't require the user to send an access request themselves. The user is automatically
    added to the accepted list, meaning they can download the files. You can revoke the granted access at any time
    using [`cancel_access_request`] or [`reject_access_request`].

    For more info about gated repos, see https://huggingface.co/docs/hub/models-gated.

    Args:
        repo_id (`str`):
            The id of the repo to grant access to.
        user (`str`):
            The username of the user to grant access.
        repo_type (`str`, *optional*):
            The type of the repo to grant access to. Must be one of `model`, `dataset` or `space`.
            Defaults to `model`.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved
            token, which is the recommended method for authentication (see
            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Raises:
        `ValueError`:
            If `repo_type` is not one of `constants.REPO_TYPES`.
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            HTTP 400 if the repo is not gated.
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            HTTP 400 if the user already has access to the repo.
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            HTTP 403 if you only have read-only access to the repo. This can be the case if you don't have `write`
            or `admin` role in the organization the repo belongs to or if you passed a `read` token.
        [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
            HTTP 404 if the user does not exist on the Hub.
    """
    if repo_type not in constants.REPO_TYPES:
        raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}")
    if repo_type is None:
        repo_type = constants.REPO_TYPE_MODEL

    response = get_session().post(
        # Build the route from the resolved `repo_type`. It used to be hard-coded to `/api/models/`, which
        # silently ignored `repo_type` and sent dataset/space grants to the models route.
        f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/grant",
        headers=self._build_hf_headers(token=token),
        json={"user": user},
    )
    hf_raise_for_status(response)
    # NOTE(review): the signature is annotated `-> None`, yet the decoded body was returned by the previous
    # implementation. The return is kept so existing callers observe no change.
    return response.json()
@validate_hf_hub_args
def list_webhooks(self, *, token: Union[bool, str, None] = None) -> List[WebhookInfo]:
    """List all configured webhooks.

    Args:
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved token, which is the recommended
            method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        `List[WebhookInfo]`:
            List of webhook info objects, one per entry of the JSON response.

    Example:
        ```python
        >>> from huggingface_hub import list_webhooks
        >>> webhooks = list_webhooks()
        >>> len(webhooks)
        2
        >>> webhooks[0]
        WebhookInfo(
            id="654bbbc16f2ec14d77f109cc",
            watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")],
            url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548",
            secret="my-secret",
            domains=["repo", "discussion"],
            disabled=False,
        )
        ```
    """
    response = get_session().get(
        f"{constants.ENDPOINT}/api/settings/webhooks",
        headers=self._build_hf_headers(token=token),
    )
    hf_raise_for_status(response)

    webhooks: List[WebhookInfo] = []
    for entry in response.json():
        webhooks.append(
            WebhookInfo(
                id=entry["id"],
                url=entry["url"],
                watched=[
                    WebhookWatchedItem(type=target["type"], name=target["name"]) for target in entry["watched"]
                ],
                domains=entry["domains"],
                # `secret` is read with `.get`, so a missing key yields `None`.
                secret=entry.get("secret"),
                disabled=entry["disabled"],
            )
        )
    return webhooks
@validate_hf_hub_args
def update_webhook(
    self,
    webhook_id: str,
    *,
    url: Optional[str] = None,
    watched: Optional[List[Union[Dict, WebhookWatchedItem]]] = None,
    domains: Optional[List[constants.WEBHOOK_DOMAIN_T]] = None,
    secret: Optional[str] = None,
    token: Union[bool, str, None] = None,
) -> WebhookInfo:
    """Update an existing webhook.

    Note that every field (`watched`, `url`, `domains`, `secret`) is included in the request payload, even when
    left to its default. An omitted `watched` is sent as an empty list.

    Args:
        webhook_id (`str`):
            The unique identifier of the webhook to be updated.
        url (`str`, optional):
            The URL to which the payload will be sent.
        watched (`List[WebhookWatchedItem]`, optional):
            List of items to watch. It can be users, orgs, models, datasets, or spaces.
            Refer to [`WebhookWatchedItem`] for more details. Watched items can also be provided as plain
            dictionaries.
        domains (`List[Literal["repo", "discussion"]]`, optional):
            The domains to watch. This can include "repo", "discussion", or both.
        secret (`str`, optional):
            A secret to sign the payload with, providing an additional layer of security.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved token, which is the recommended
            method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        [`WebhookInfo`]:
            Info about the updated webhook.

    Example:
        ```python
        >>> from huggingface_hub import update_webhook
        >>> updated_payload = update_webhook(
        ...     webhook_id="654bbbc16f2ec14d77f109cc",
        ...     url="https://new.webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548",
        ...     watched=[{"type": "user", "name": "julien-c"}, {"type": "org", "name": "HuggingFaceH4"}],
        ...     domains=["repo"],
        ...     secret="my-secret",
        ... )
        >>> print(updated_payload)
        WebhookInfo(
            id="654bbbc16f2ec14d77f109cc",
            url="https://new.webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548",
            watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")],
            domains=["repo"],
            secret="my-secret",
            disabled=False,
        )
        ```
    """
    requested_items = [] if watched is None else watched
    # Dataclass instances are converted to plain dicts; dicts are passed through untouched.
    serialized_items = []
    for item in requested_items:
        serialized_items.append(asdict(item) if isinstance(item, WebhookWatchedItem) else item)

    response = get_session().post(
        f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}",
        json={"watched": serialized_items, "url": url, "domains": domains, "secret": secret},
        headers=self._build_hf_headers(token=token),
    )
    hf_raise_for_status(response)
    data = response.json()["webhook"]

    watched_items = []
    for target in data["watched"]:
        watched_items.append(WebhookWatchedItem(type=target["type"], name=target["name"]))

    return WebhookInfo(
        id=data["id"],
        url=data["url"],
        watched=watched_items,
        domains=data["domains"],
        secret=data.get("secret"),
        disabled=data["disabled"],
    )
@validate_hf_hub_args
def disable_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) -> WebhookInfo:
    """Disable a webhook (makes it "disabled").

    Args:
        webhook_id (`str`):
            The unique identifier of the webhook to disable.
        token (Union[bool, str, None], optional):
            A valid user access token (string). Defaults to the locally saved token, which is the recommended
            method for authentication (see https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
            To disable authentication, pass `False`.

    Returns:
        [`WebhookInfo`]:
            Info about the disabled webhook, built from the `"webhook"` entry of the response.

    Example:
        ```python
        >>> from huggingface_hub import disable_webhook
        >>> disabled_webhook = disable_webhook("654bbbc16f2ec14d77f109cc")
        >>> disabled_webhook
        WebhookInfo(
            id="654bbbc16f2ec14d77f109cc",
            url="https://webhook.site/a2176e82-5720-43ee-9e06-f91cb4c91548",
            watched=[WebhookWatchedItem(type="user", name="julien-c"), WebhookWatchedItem(type="org", name="HuggingFaceH4")],
            domains=["repo", "discussion"],
            secret="my-secret",
            disabled=True,
        )
        ```
    """
    response = get_session().post(
        f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}/disable",
        headers=self._build_hf_headers(token=token),
    )
    hf_raise_for_status(response)
    data = response.json()["webhook"]

    watched_items = []
    for target in data["watched"]:
        watched_items.append(WebhookWatchedItem(type=target["type"], name=target["name"]))

    return WebhookInfo(
        id=data["id"],
        url=data["url"],
        watched=watched_items,
        domains=data["domains"],
        # `secret` is read with `.get`, so a missing key yields `None`.
        secret=data.get("secret"),
        disabled=data["disabled"],
    )
+ + Returns: + `None` + + Example: + ```python + >>> from huggingface_hub import delete_webhook + >>> delete_webhook("654bbbc16f2ec14d77f109cc") + ``` + """ + response = get_session().delete( + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}", + headers=self._build_hf_headers(token=token), + ) + hf_raise_for_status(response) + + ############# + # Internals # + ############# + + def _build_hf_headers( + self, + token: Union[bool, str, None] = None, + is_write_action: bool = False, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, + ) -> Dict[str, str]: + """ + Alias for [`build_hf_headers`] that uses the token from [`HfApi`] client + when `token` is not provided. + """ + if token is None: + # Cannot do `token = token or self.token` as token can be `False`. + token = self.token + return build_hf_headers( + token=token, + is_write_action=is_write_action, + library_name=library_name or self.library_name, + library_version=library_version or self.library_version, + user_agent=user_agent or self.user_agent, + headers=self.headers, + ) + + def _prepare_folder_deletions( + self, + repo_id: str, + repo_type: Optional[str], + revision: Optional[str], + path_in_repo: str, + delete_patterns: Optional[Union[List[str], str]], + token: Union[bool, str, None] = None, + ) -> List[CommitOperationDelete]: + """Generate the list of Delete operations for a commit to delete files from a repo. + + List remote files and match them against the `delete_patterns` constraints. Returns a list of [`CommitOperationDelete`] + with the matching items. + + Note: `.gitattributes` file is essential to make a repo work properly on the Hub. This file will always be + kept even if it matches the `delete_patterns` constraints. 
+ """ + if delete_patterns is None: + # If no delete patterns, no need to list and filter remote files + return [] + + # List remote files + filenames = self.list_repo_files(repo_id=repo_id, revision=revision, repo_type=repo_type, token=token) + + # Compute relative path in repo + if path_in_repo and path_in_repo not in (".", "./"): + path_in_repo = path_in_repo.strip("/") + "/" # harmonize + relpath_to_abspath = { + file[len(path_in_repo) :]: file for file in filenames if file.startswith(path_in_repo) + } + else: + relpath_to_abspath = {file: file for file in filenames} + + # Apply filter on relative paths and return + return [ + CommitOperationDelete(path_in_repo=relpath_to_abspath[relpath], is_folder=False) + for relpath in filter_repo_objects(relpath_to_abspath.keys(), allow_patterns=delete_patterns) + if relpath_to_abspath[relpath] != ".gitattributes" + ] + + def _prepare_upload_folder_additions( + self, + folder_path: Union[str, Path], + path_in_repo: str, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + repo_type: Optional[str] = None, + token: Union[bool, str, None] = None, + ) -> List[CommitOperationAdd]: + """Generate the list of Add operations for a commit to upload a folder. + + Files not matching the `allow_patterns` (allowlist) and `ignore_patterns` (denylist) + constraints are discarded. + """ + + folder_path = Path(folder_path).expanduser().resolve() + if not folder_path.is_dir(): + raise ValueError(f"Provided path: '{folder_path}' is not a directory") + + # List files from folder + relpath_to_abspath = { + path.relative_to(folder_path).as_posix(): path + for path in sorted(folder_path.glob("**/*")) # sorted to be deterministic + if path.is_file() + } + + # Filter files + # Patterns are applied on the path relative to `folder_path`. `path_in_repo` is prefixed after the filtering. 
+ filtered_repo_objects = list( + filter_repo_objects( + relpath_to_abspath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns + ) + ) + + prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else "" + + # If updating a README.md file, make sure the metadata format is valid + # It's better to fail early than to fail after all the files have been hashed. + if "README.md" in filtered_repo_objects: + self._validate_yaml( + content=relpath_to_abspath["README.md"].read_text(), + repo_type=repo_type, + token=token, + ) + if len(filtered_repo_objects) > 30: + logger.info( + "It seems you are trying to upload a large folder at once. This might take some time and then fail if " + "the folder is too large. For such cases, it is recommended to upload in smaller batches or to use " + "`HfApi().upload_large_folder(...)`/`huggingface-cli upload-large-folder` instead. For more details, " + "check out https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#upload-a-large-folder." + ) + + logger.info(f"Start hashing {len(filtered_repo_objects)} files.") + operations = [ + CommitOperationAdd( + path_or_fileobj=relpath_to_abspath[relpath], # absolute path on disk + path_in_repo=prefix + relpath, # "absolute" path in repo + ) + for relpath in filtered_repo_objects + ] + logger.info(f"Finished hashing {len(filtered_repo_objects)} files.") + return operations + + def _validate_yaml(self, content: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None): + """ + Validate YAML from `README.md`, used before file hashing and upload. + + Args: + content (`str`): + Content of `README.md` to validate. + repo_type (`str`, *optional*): + The type of the repo to grant access to. Must be one of `model`, `dataset` or `space`. + Defaults to `model`. + token (Union[bool, str, None], optional): + A valid user access token (string). 
Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Raises: + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if YAML is invalid + """ + repo_type = repo_type if repo_type is not None else constants.REPO_TYPE_MODEL + headers = self._build_hf_headers(token=token) + + response = get_session().post( + f"{self.endpoint}/api/validate-yaml", + json={"content": content, "repoType": repo_type}, + headers=headers, + ) + # Handle warnings (example: empty metadata) + response_content = response.json() + message = "\n".join([f"- {warning.get('message')}" for warning in response_content.get("warnings", [])]) + if message: + warnings.warn(f"Warnings while validating metadata in README.md:\n{message}") + + # Raise on errors + try: + hf_raise_for_status(response) + except BadRequestError as e: + errors = response_content.get("errors", []) + message = "\n".join([f"- {error.get('message')}" for error in errors]) + raise ValueError(f"Invalid metadata in README.md.\n{message}") from e + + def get_user_overview(self, username: str, token: Union[bool, str, None] = None) -> User: + """ + Get an overview of a user on the Hub. + + Args: + username (`str`): + Username of the user to get an overview of. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `User`: A [`User`] object with the user's overview. + + Raises: + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 404 If the user does not exist on the Hub. 
+ """ + r = get_session().get( + f"{constants.ENDPOINT}/api/users/{username}/overview", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(r) + return User(**r.json()) + + def list_organization_members(self, organization: str, token: Union[bool, str, None] = None) -> Iterable[User]: + """ + List of members of an organization on the Hub. + + Args: + organization (`str`): + Name of the organization to get the members of. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[User]`: A list of [`User`] objects with the members of the organization. + + Raises: + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 404 If the organization does not exist on the Hub. + + """ + for member in paginate( + path=f"{constants.ENDPOINT}/api/organizations/{organization}/members", + params={}, + headers=self._build_hf_headers(token=token), + ): + yield User(**member) + + def list_user_followers(self, username: str, token: Union[bool, str, None] = None) -> Iterable[User]: + """ + Get the list of followers of a user on the Hub. + + Args: + username (`str`): + Username of the user to get the followers of. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[User]`: A list of [`User`] objects with the followers of the user. + + Raises: + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 404 If the user does not exist on the Hub. 
+ + """ + for follower in paginate( + path=f"{constants.ENDPOINT}/api/users/{username}/followers", + params={}, + headers=self._build_hf_headers(token=token), + ): + yield User(**follower) + + def list_user_following(self, username: str, token: Union[bool, str, None] = None) -> Iterable[User]: + """ + Get the list of users followed by a user on the Hub. + + Args: + username (`str`): + Username of the user to get the users followed by. + token (Union[bool, str, None], optional): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. + + Returns: + `Iterable[User]`: A list of [`User`] objects with the users followed by the user. + + Raises: + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): + HTTP 404 If the user does not exist on the Hub. + + """ + for followed_user in paginate( + path=f"{constants.ENDPOINT}/api/users/{username}/following", + params={}, + headers=self._build_hf_headers(token=token), + ): + yield User(**followed_user) + + def auth_check( + self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None + ) -> None: + """ + Check if the provided user token has access to a specific repository on the Hugging Face Hub. + + This method verifies whether the user, authenticated via the provided token, has access to the specified + repository. If the repository is not found or if the user lacks the required permissions to access it, + the method raises an appropriate exception. + + Args: + repo_id (`str`): + The repository to check for access. Format should be `"user/repo_name"`. + Example: `"user/my-cool-model"`. + + repo_type (`str`, *optional*): + The type of the repository. Should be one of `"model"`, `"dataset"`, or `"space"`. + If not specified, the default is `"model"`. 
+ + token `(Union[bool, str, None]`, *optional*): + A valid user access token. If not provided, the locally saved token will be used, which is the + recommended authentication method. Set to `False` to disable authentication. + Refer to: https://huggingface.co/docs/huggingface_hub/quick-start#authentication. + + Raises: + [`~utils.RepositoryNotFoundError`]: + Raised if the repository does not exist, is private, or the user does not have access. This can + occur if the `repo_id` or `repo_type` is incorrect or if the repository is private but the user + is not authenticated. + + [`~utils.GatedRepoError`]: + Raised if the repository exists but is gated and the user is not authorized to access it. + + Example: + Check if the user has access to a repository: + + ```python + >>> from huggingface_hub import auth_check + >>> from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError + + try: + auth_check("user/my-cool-model") + except GatedRepoError: + # Handle gated repository error + print("You do not have permission to access this gated repository.") + except RepositoryNotFoundError: + # Handle repository not found error + print("The repository was not found or you do not have access.") + ``` + + In this example: + - If the user has access, the method completes successfully. + - If the repository is gated or does not exist, appropriate exceptions are raised, allowing the user + to handle them accordingly. + """ + headers = self._build_hf_headers(token=token) + if repo_type is None: + repo_type = constants.REPO_TYPE_MODEL + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/auth-check" + r = get_session().get(path, headers=headers) + hf_raise_for_status(r) + + +def _parse_revision_from_pr_url(pr_url: str) -> str: + """Safely parse revision number from a PR url. 
+ + Example: + ```py + >>> _parse_revision_from_pr_url("https://huggingface.co/bigscience/bloom/discussions/2") + "refs/pr/2" + ``` + """ + re_match = re.match(_REGEX_DISCUSSION_URL, pr_url) + if re_match is None: + raise RuntimeError(f"Unexpected response from the hub, expected a Pull Request URL but got: '{pr_url}'") + return f"refs/pr/{re_match[1]}" + + +api = HfApi() + +whoami = api.whoami +auth_check = api.auth_check +get_token_permission = api.get_token_permission + +list_models = api.list_models +model_info = api.model_info + +list_datasets = api.list_datasets +dataset_info = api.dataset_info + +list_spaces = api.list_spaces +space_info = api.space_info + +repo_exists = api.repo_exists +revision_exists = api.revision_exists +file_exists = api.file_exists +repo_info = api.repo_info +list_repo_files = api.list_repo_files +list_repo_refs = api.list_repo_refs +list_repo_commits = api.list_repo_commits +list_repo_tree = api.list_repo_tree +get_paths_info = api.get_paths_info +list_metrics = api.list_metrics + +get_model_tags = api.get_model_tags +get_dataset_tags = api.get_dataset_tags + +create_commit = api.create_commit +create_repo = api.create_repo +delete_repo = api.delete_repo +update_repo_visibility = api.update_repo_visibility +update_repo_settings = api.update_repo_settings +super_squash_history = api.super_squash_history +move_repo = api.move_repo +upload_file = api.upload_file +upload_folder = api.upload_folder +delete_file = api.delete_file +delete_folder = api.delete_folder +delete_files = api.delete_files +create_commits_on_pr = api.create_commits_on_pr +upload_large_folder = api.upload_large_folder +preupload_lfs_files = api.preupload_lfs_files +create_branch = api.create_branch +delete_branch = api.delete_branch +create_tag = api.create_tag +delete_tag = api.delete_tag +get_full_repo_name = api.get_full_repo_name + +# Safetensors helpers +get_safetensors_metadata = api.get_safetensors_metadata +parse_safetensors_file_metadata = 
api.parse_safetensors_file_metadata + +# Background jobs +run_as_future = api.run_as_future + +# Activity API +list_liked_repos = api.list_liked_repos +list_repo_likers = api.list_repo_likers +like = api.like +unlike = api.unlike + +# Community API +get_discussion_details = api.get_discussion_details +get_repo_discussions = api.get_repo_discussions +create_discussion = api.create_discussion +create_pull_request = api.create_pull_request +change_discussion_status = api.change_discussion_status +comment_discussion = api.comment_discussion +edit_discussion_comment = api.edit_discussion_comment +rename_discussion = api.rename_discussion +merge_pull_request = api.merge_pull_request + +# Space API +add_space_secret = api.add_space_secret +delete_space_secret = api.delete_space_secret +get_space_variables = api.get_space_variables +add_space_variable = api.add_space_variable +delete_space_variable = api.delete_space_variable +get_space_runtime = api.get_space_runtime +request_space_hardware = api.request_space_hardware +set_space_sleep_time = api.set_space_sleep_time +pause_space = api.pause_space +restart_space = api.restart_space +duplicate_space = api.duplicate_space +request_space_storage = api.request_space_storage +delete_space_storage = api.delete_space_storage + +# Inference Endpoint API +list_inference_endpoints = api.list_inference_endpoints +create_inference_endpoint = api.create_inference_endpoint +get_inference_endpoint = api.get_inference_endpoint +update_inference_endpoint = api.update_inference_endpoint +delete_inference_endpoint = api.delete_inference_endpoint +pause_inference_endpoint = api.pause_inference_endpoint +resume_inference_endpoint = api.resume_inference_endpoint +scale_to_zero_inference_endpoint = api.scale_to_zero_inference_endpoint + +# Collections API +get_collection = api.get_collection +list_collections = api.list_collections +create_collection = api.create_collection +update_collection_metadata = api.update_collection_metadata 
+delete_collection = api.delete_collection +add_collection_item = api.add_collection_item +update_collection_item = api.update_collection_item +delete_collection_item = api.delete_collection_item +delete_collection_item = api.delete_collection_item + +# Access requests API +list_pending_access_requests = api.list_pending_access_requests +list_accepted_access_requests = api.list_accepted_access_requests +list_rejected_access_requests = api.list_rejected_access_requests +cancel_access_request = api.cancel_access_request +accept_access_request = api.accept_access_request +reject_access_request = api.reject_access_request +grant_access = api.grant_access + +# Webhooks API +create_webhook = api.create_webhook +disable_webhook = api.disable_webhook +delete_webhook = api.delete_webhook +enable_webhook = api.enable_webhook +get_webhook = api.get_webhook +list_webhooks = api.list_webhooks +update_webhook = api.update_webhook + + +# User API +get_user_overview = api.get_user_overview +list_organization_members = api.list_organization_members +list_user_followers = api.list_user_followers +list_user_following = api.list_user_following diff --git a/huggingface_hub/hf_file_system.py b/huggingface_hub/hf_file_system.py new file mode 100644 index 0000000000000000000000000000000000000000..a831b6c92952ce845891e3a5a90072e3fd161348 --- /dev/null +++ b/huggingface_hub/hf_file_system.py @@ -0,0 +1,901 @@ +import inspect +import os +import re +import tempfile +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime +from itertools import chain +from pathlib import Path +from typing import Any, Dict, List, NoReturn, Optional, Tuple, Union +from urllib.parse import quote, unquote + +import fsspec +from fsspec.callbacks import _DEFAULT_CALLBACK, NoOpCallback, TqdmCallback +from fsspec.utils import isfilelike +from requests import Response + +from . 
import constants +from ._commit_api import CommitOperationCopy, CommitOperationDelete +from .errors import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError +from .file_download import hf_hub_url, http_get +from .hf_api import HfApi, LastCommitInfo, RepoFile +from .utils import ( + HFValidationError, + hf_raise_for_status, + http_backoff, +) + + +# Regex used to match special revisions with "/" in them (see #1710) +SPECIAL_REFS_REVISION_REGEX = re.compile( + r""" + (^refs\/convert\/\w+) # `refs/convert/parquet` revisions + | + (^refs\/pr\/\d+) # PR revisions + """, + re.VERBOSE, +) + + +@dataclass +class HfFileSystemResolvedPath: + """Data structure containing information about a resolved Hugging Face file system path.""" + + repo_type: str + repo_id: str + revision: str + path_in_repo: str + # The part placed after '@' in the initial path. It can be a quoted or unquoted refs revision. + # Used to reconstruct the unresolved path to return to the user. + _raw_revision: Optional[str] = field(default=None, repr=False) + + def unresolve(self) -> str: + repo_path = constants.REPO_TYPES_URL_PREFIXES.get(self.repo_type, "") + self.repo_id + if self._raw_revision: + return f"{repo_path}@{self._raw_revision}/{self.path_in_repo}".rstrip("/") + elif self.revision != constants.DEFAULT_REVISION: + return f"{repo_path}@{safe_revision(self.revision)}/{self.path_in_repo}".rstrip("/") + else: + return f"{repo_path}/{self.path_in_repo}".rstrip("/") + + +class HfFileSystem(fsspec.AbstractFileSystem): + """ + Access a remote Hugging Face Hub repository as if were a local file system. + + Args: + token (`str` or `bool`, *optional*): + A valid user access token (string). Defaults to the locally saved + token, which is the recommended method for authentication (see + https://huggingface.co/docs/huggingface_hub/quick-start#authentication). + To disable authentication, pass `False`. 
+ + Usage: + + ```python + >>> from huggingface_hub import HfFileSystem + + >>> fs = HfFileSystem() + + >>> # List files + >>> fs.glob("my-username/my-model/*.bin") + ['my-username/my-model/pytorch_model.bin'] + >>> fs.ls("datasets/my-username/my-dataset", detail=False) + ['datasets/my-username/my-dataset/.gitattributes', 'datasets/my-username/my-dataset/README.md', 'datasets/my-username/my-dataset/data.json'] + + >>> # Read/write files + >>> with fs.open("my-username/my-model/pytorch_model.bin") as f: + ... data = f.read() + >>> with fs.open("my-username/my-model/pytorch_model.bin", "wb") as f: + ... f.write(data) + ``` + """ + + root_marker = "" + protocol = "hf" + + def __init__( + self, + *args, + endpoint: Optional[str] = None, + token: Union[bool, str, None] = None, + **storage_options, + ): + super().__init__(*args, **storage_options) + self.endpoint = endpoint or constants.ENDPOINT + self.token = token + self._api = HfApi(endpoint=endpoint, token=token) + # Maps (repo_type, repo_id, revision) to a 2-tuple with: + # * the 1st element indicating whether the repositoy and the revision exist + # * the 2nd element being the exception raised if the repository or revision doesn't exist + self._repo_and_revision_exists_cache: Dict[ + Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]] + ] = {} + + def _repo_and_revision_exist( + self, repo_type: str, repo_id: str, revision: Optional[str] + ) -> Tuple[bool, Optional[Exception]]: + if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache: + try: + self._api.repo_info( + repo_id, revision=revision, repo_type=repo_type, timeout=constants.HF_HUB_ETAG_TIMEOUT + ) + except (RepositoryNotFoundError, HFValidationError) as e: + self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e + self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e + except RevisionNotFoundError as e: + self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = 
False, e + self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None + else: + self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = True, None + self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None + return self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] + + def resolve_path(self, path: str, revision: Optional[str] = None) -> HfFileSystemResolvedPath: + def _align_revision_in_path_with_revision( + revision_in_path: Optional[str], revision: Optional[str] + ) -> Optional[str]: + if revision is not None: + if revision_in_path is not None and revision_in_path != revision: + raise ValueError( + f'Revision specified in path ("{revision_in_path}") and in `revision` argument ("{revision}")' + " are not the same." + ) + else: + revision = revision_in_path + return revision + + path = self._strip_protocol(path) + if not path: + # can't list repositories at root + raise NotImplementedError("Access to repositories lists is not implemented.") + elif path.split("/")[0] + "/" in constants.REPO_TYPES_URL_PREFIXES.values(): + if "/" not in path: + # can't list repositories at the repository type level + raise NotImplementedError("Access to repositories lists is not implemented.") + repo_type, path = path.split("/", 1) + repo_type = constants.REPO_TYPES_MAPPING[repo_type] + else: + repo_type = constants.REPO_TYPE_MODEL + if path.count("/") > 0: + if "@" in path: + repo_id, revision_in_path = path.split("@", 1) + if "/" in revision_in_path: + match = SPECIAL_REFS_REVISION_REGEX.search(revision_in_path) + if match is not None and revision in (None, match.group()): + # Handle `refs/convert/parquet` and PR revisions separately + path_in_repo = SPECIAL_REFS_REVISION_REGEX.sub("", revision_in_path).lstrip("/") + revision_in_path = match.group() + else: + revision_in_path, path_in_repo = revision_in_path.split("/", 1) + else: + path_in_repo = "" + revision = 
_align_revision_in_path_with_revision(unquote(revision_in_path), revision) + repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + _raise_file_not_found(path, err) + else: + revision_in_path = None + repo_id_with_namespace = "/".join(path.split("/")[:2]) + path_in_repo_with_namespace = "/".join(path.split("/")[2:]) + repo_id_without_namespace = path.split("/")[0] + path_in_repo_without_namespace = "/".join(path.split("/")[1:]) + repo_id = repo_id_with_namespace + path_in_repo = path_in_repo_with_namespace + repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + if isinstance(err, (RepositoryNotFoundError, HFValidationError)): + repo_id = repo_id_without_namespace + path_in_repo = path_in_repo_without_namespace + repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + _raise_file_not_found(path, err) + else: + _raise_file_not_found(path, err) + else: + repo_id = path + path_in_repo = "" + if "@" in path: + repo_id, revision_in_path = path.split("@", 1) + revision = _align_revision_in_path_with_revision(unquote(revision_in_path), revision) + else: + revision_in_path = None + repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + raise NotImplementedError("Access to repositories lists is not implemented.") + + revision = revision if revision is not None else constants.DEFAULT_REVISION + return HfFileSystemResolvedPath(repo_type, repo_id, revision, path_in_repo, _raw_revision=revision_in_path) + + def invalidate_cache(self, path: Optional[str] = None) -> None: + if not path: + self.dircache.clear() + self._repo_and_revision_exists_cache.clear() + else: + path = self.resolve_path(path).unresolve() + while path: + self.dircache.pop(path, None) + path = self._parent(path) + + def _open( + 
self, + path: str, + mode: str = "rb", + revision: Optional[str] = None, + block_size: Optional[int] = None, + **kwargs, + ) -> "HfFileSystemFile": + if "a" in mode: + raise NotImplementedError("Appending to remote files is not yet supported.") + if block_size == 0: + return HfFileSystemStreamFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs) + else: + return HfFileSystemFile(self, path, mode=mode, revision=revision, block_size=block_size, **kwargs) + + def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None: + resolved_path = self.resolve_path(path, revision=revision) + self._api.delete_file( + path_in_repo=resolved_path.path_in_repo, + repo_id=resolved_path.repo_id, + token=self.token, + repo_type=resolved_path.repo_type, + revision=resolved_path.revision, + commit_message=kwargs.get("commit_message"), + commit_description=kwargs.get("commit_description"), + ) + self.invalidate_cache(path=resolved_path.unresolve()) + + def rm( + self, + path: str, + recursive: bool = False, + maxdepth: Optional[int] = None, + revision: Optional[str] = None, + **kwargs, + ) -> None: + resolved_path = self.resolve_path(path, revision=revision) + paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth, revision=revision) + paths_in_repo = [self.resolve_path(path).path_in_repo for path in paths if not self.isdir(path)] + operations = [CommitOperationDelete(path_in_repo=path_in_repo) for path_in_repo in paths_in_repo] + commit_message = f"Delete {path} " + commit_message += "recursively " if recursive else "" + commit_message += f"up to depth {maxdepth} " if maxdepth is not None else "" + # TODO: use `commit_description` to list all the deleted paths? 
+ self._api.create_commit( + repo_id=resolved_path.repo_id, + repo_type=resolved_path.repo_type, + token=self.token, + operations=operations, + revision=resolved_path.revision, + commit_message=kwargs.get("commit_message", commit_message), + commit_description=kwargs.get("commit_description"), + ) + self.invalidate_cache(path=resolved_path.unresolve()) + + def ls( + self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs + ) -> List[Union[str, Dict[str, Any]]]: + """List the contents of a directory.""" + resolved_path = self.resolve_path(path, revision=revision) + path = resolved_path.unresolve() + kwargs = {"expand_info": detail, **kwargs} + try: + out = self._ls_tree(path, refresh=refresh, revision=revision, **kwargs) + except EntryNotFoundError: + # Path could be a file + if not resolved_path.path_in_repo: + _raise_file_not_found(path, None) + out = self._ls_tree(self._parent(path), refresh=refresh, revision=revision, **kwargs) + out = [o for o in out if o["name"] == path] + if len(out) == 0: + _raise_file_not_found(path, None) + return out if detail else [o["name"] for o in out] + + def _ls_tree( + self, + path: str, + recursive: bool = False, + refresh: bool = False, + revision: Optional[str] = None, + expand_info: bool = True, + ): + resolved_path = self.resolve_path(path, revision=revision) + path = resolved_path.unresolve() + root_path = HfFileSystemResolvedPath( + resolved_path.repo_type, + resolved_path.repo_id, + resolved_path.revision, + path_in_repo="", + _raw_revision=resolved_path._raw_revision, + ).unresolve() + + out = [] + if path in self.dircache and not refresh: + cached_path_infos = self.dircache[path] + out.extend(cached_path_infos) + dirs_not_in_dircache = [] + if recursive: + # Use BFS to traverse the cache and build the "recursive "output + # (The Hub uses a so-called "tree first" strategy for the tree endpoint but we sort the output to follow the spec so the result is (eventually) the same) + 
dirs_to_visit = deque( + [path_info for path_info in cached_path_infos if path_info["type"] == "directory"] + ) + while dirs_to_visit: + dir_info = dirs_to_visit.popleft() + if dir_info["name"] not in self.dircache: + dirs_not_in_dircache.append(dir_info["name"]) + else: + cached_path_infos = self.dircache[dir_info["name"]] + out.extend(cached_path_infos) + dirs_to_visit.extend( + [path_info for path_info in cached_path_infos if path_info["type"] == "directory"] + ) + + dirs_not_expanded = [] + if expand_info: + # Check if there are directories with non-expanded entries + dirs_not_expanded = [self._parent(o["name"]) for o in out if o["last_commit"] is None] + + if (recursive and dirs_not_in_dircache) or (expand_info and dirs_not_expanded): + # If the dircache is incomplete, find the common path of the missing and non-expanded entries + # and extend the output with the result of `_ls_tree(common_path, recursive=True)` + common_prefix = os.path.commonprefix(dirs_not_in_dircache + dirs_not_expanded) + # Get the parent directory if the common prefix itself is not a directory + common_path = ( + common_prefix.rstrip("/") + if common_prefix.endswith("/") + or common_prefix == root_path + or common_prefix in chain(dirs_not_in_dircache, dirs_not_expanded) + else self._parent(common_prefix) + ) + out = [o for o in out if not o["name"].startswith(common_path + "/")] + for cached_path in self.dircache: + if cached_path.startswith(common_path + "/"): + self.dircache.pop(cached_path, None) + self.dircache.pop(common_path, None) + out.extend( + self._ls_tree( + common_path, + recursive=recursive, + refresh=True, + revision=revision, + expand_info=expand_info, + ) + ) + else: + tree = self._api.list_repo_tree( + resolved_path.repo_id, + resolved_path.path_in_repo, + recursive=recursive, + expand=expand_info, + revision=resolved_path.revision, + repo_type=resolved_path.repo_type, + ) + for path_info in tree: + if isinstance(path_info, RepoFile): + cache_path_info = { + "name": 
root_path + "/" + path_info.path, + "size": path_info.size, + "type": "file", + "blob_id": path_info.blob_id, + "lfs": path_info.lfs, + "last_commit": path_info.last_commit, + "security": path_info.security, + } + else: + cache_path_info = { + "name": root_path + "/" + path_info.path, + "size": 0, + "type": "directory", + "tree_id": path_info.tree_id, + "last_commit": path_info.last_commit, + } + parent_path = self._parent(cache_path_info["name"]) + self.dircache.setdefault(parent_path, []).append(cache_path_info) + out.append(cache_path_info) + return out + + def walk(self, path, *args, **kwargs): + # Set expand_info=False by default to get a x10 speed boost + kwargs = {"expand_info": kwargs.get("detail", False), **kwargs} + path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve() + yield from super().walk(path, *args, **kwargs) + + def glob(self, path, **kwargs): + # Set expand_info=False by default to get a x10 speed boost + kwargs = {"expand_info": kwargs.get("detail", False), **kwargs} + path = self.resolve_path(path, revision=kwargs.get("revision")).unresolve() + return super().glob(path, **kwargs) + + def find( + self, + path: str, + maxdepth: Optional[int] = None, + withdirs: bool = False, + detail: bool = False, + refresh: bool = False, + revision: Optional[str] = None, + **kwargs, + ) -> Union[List[str], Dict[str, Dict[str, Any]]]: + if maxdepth: + return super().find( + path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, refresh=refresh, revision=revision, **kwargs + ) + resolved_path = self.resolve_path(path, revision=revision) + path = resolved_path.unresolve() + kwargs = {"expand_info": detail, **kwargs} + try: + out = self._ls_tree(path, recursive=True, refresh=refresh, revision=resolved_path.revision, **kwargs) + except EntryNotFoundError: + # Path could be a file + if self.info(path, revision=revision, **kwargs)["type"] == "file": + out = {path: {}} + else: + out = {} + else: + if not withdirs: + out = [o for o in out if 
o["type"] != "directory"] + else: + # If `withdirs=True`, include the directory itself to be consistent with the spec + path_info = self.info(path, revision=resolved_path.revision, **kwargs) + out = [path_info] + out if path_info["type"] == "directory" else out + out = {o["name"]: o for o in out} + names = sorted(out) + if not detail: + return names + else: + return {name: out[name] for name in names} + + def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwargs) -> None: + resolved_path1 = self.resolve_path(path1, revision=revision) + resolved_path2 = self.resolve_path(path2, revision=revision) + + same_repo = ( + resolved_path1.repo_type == resolved_path2.repo_type and resolved_path1.repo_id == resolved_path2.repo_id + ) + + if same_repo: + commit_message = f"Copy {path1} to {path2}" + self._api.create_commit( + repo_id=resolved_path1.repo_id, + repo_type=resolved_path1.repo_type, + revision=resolved_path2.revision, + commit_message=kwargs.get("commit_message", commit_message), + commit_description=kwargs.get("commit_description", ""), + operations=[ + CommitOperationCopy( + src_path_in_repo=resolved_path1.path_in_repo, + path_in_repo=resolved_path2.path_in_repo, + src_revision=resolved_path1.revision, + ) + ], + ) + else: + with self.open(path1, "rb", revision=resolved_path1.revision) as f: + content = f.read() + commit_message = f"Copy {path1} to {path2}" + self._api.upload_file( + path_or_fileobj=content, + path_in_repo=resolved_path2.path_in_repo, + repo_id=resolved_path2.repo_id, + token=self.token, + repo_type=resolved_path2.repo_type, + revision=resolved_path2.revision, + commit_message=kwargs.get("commit_message", commit_message), + commit_description=kwargs.get("commit_description"), + ) + self.invalidate_cache(path=resolved_path1.unresolve()) + self.invalidate_cache(path=resolved_path2.unresolve()) + + def modified(self, path: str, **kwargs) -> datetime: + info = self.info(path, **kwargs) + return info["last_commit"]["date"] + + 
def info(self, path: str, refresh: bool = False, revision: Optional[str] = None, **kwargs) -> Dict[str, Any]: + resolved_path = self.resolve_path(path, revision=revision) + path = resolved_path.unresolve() + expand_info = kwargs.get( + "expand_info", True + ) # don't expose it as a parameter in the public API to follow the spec + if not resolved_path.path_in_repo: + # Path is the root directory + out = { + "name": path, + "size": 0, + "type": "directory", + } + if expand_info: + last_commit = self._api.list_repo_commits( + resolved_path.repo_id, repo_type=resolved_path.repo_type, revision=resolved_path.revision + )[-1] + out = { + **out, + "tree_id": None, # TODO: tree_id of the root directory? + "last_commit": LastCommitInfo( + oid=last_commit.commit_id, title=last_commit.title, date=last_commit.created_at + ), + } + else: + out = None + parent_path = self._parent(path) + if not expand_info and parent_path not in self.dircache: + # Fill the cache with cheap call + self.ls(parent_path, expand_info=False) + if parent_path in self.dircache: + # Check if the path is in the cache + out1 = [o for o in self.dircache[parent_path] if o["name"] == path] + if not out1: + _raise_file_not_found(path, None) + out = out1[0] + if refresh or out is None or (expand_info and out and out["last_commit"] is None): + paths_info = self._api.get_paths_info( + resolved_path.repo_id, + resolved_path.path_in_repo, + expand=expand_info, + revision=resolved_path.revision, + repo_type=resolved_path.repo_type, + ) + if not paths_info: + _raise_file_not_found(path, None) + path_info = paths_info[0] + root_path = HfFileSystemResolvedPath( + resolved_path.repo_type, + resolved_path.repo_id, + resolved_path.revision, + path_in_repo="", + _raw_revision=resolved_path._raw_revision, + ).unresolve() + if isinstance(path_info, RepoFile): + out = { + "name": root_path + "/" + path_info.path, + "size": path_info.size, + "type": "file", + "blob_id": path_info.blob_id, + "lfs": path_info.lfs, + 
"last_commit": path_info.last_commit, + "security": path_info.security, + } + else: + out = { + "name": root_path + "/" + path_info.path, + "size": 0, + "type": "directory", + "tree_id": path_info.tree_id, + "last_commit": path_info.last_commit, + } + if not expand_info: + out = {k: out[k] for k in ["name", "size", "type"]} + assert out is not None + return out + + def exists(self, path, **kwargs): + """Is there a file at the given path""" + try: + self.info(path, **{**kwargs, "expand_info": False}) + return True + except: # noqa: E722 + # any exception allowed bar FileNotFoundError? + return False + + def isdir(self, path): + """Is this entry directory-like?""" + try: + return self.info(path, expand_info=False)["type"] == "directory" + except OSError: + return False + + def isfile(self, path): + """Is this entry file-like?""" + try: + return self.info(path, expand_info=False)["type"] == "file" + except: # noqa: E722 + return False + + def url(self, path: str) -> str: + """Get the HTTP URL of the given path""" + resolved_path = self.resolve_path(path) + url = hf_hub_url( + resolved_path.repo_id, + resolved_path.path_in_repo, + repo_type=resolved_path.repo_type, + revision=resolved_path.revision, + endpoint=self.endpoint, + ) + if self.isdir(path): + url = url.replace("/resolve/", "/tree/", 1) + return url + + def get_file(self, rpath, lpath, callback=_DEFAULT_CALLBACK, outfile=None, **kwargs) -> None: + """Copy single remote file to local.""" + revision = kwargs.get("revision") + unhandled_kwargs = set(kwargs.keys()) - {"revision"} + if not isinstance(callback, (NoOpCallback, TqdmCallback)) or len(unhandled_kwargs) > 0: + # for now, let's not handle custom callbacks + # and let's not handle custom kwargs + return super().get_file(rpath, lpath, callback=callback, outfile=outfile, **kwargs) + + # Taken from https://github.com/fsspec/filesystem_spec/blob/47b445ae4c284a82dd15e0287b1ffc410e8fc470/fsspec/spec.py#L883 + if isfilelike(lpath): + outfile = lpath + elif 
self.isdir(rpath): + os.makedirs(lpath, exist_ok=True) + return None + + if isinstance(lpath, (str, Path)): # otherwise, let's assume it's a file-like object + os.makedirs(os.path.dirname(lpath), exist_ok=True) + + # Open file if not already open + close_file = False + if outfile is None: + outfile = open(lpath, "wb") + close_file = True + initial_pos = outfile.tell() + + # Custom implementation of `get_file` to use `http_get`. + resolve_remote_path = self.resolve_path(rpath, revision=revision) + expected_size = self.info(rpath, revision=revision)["size"] + callback.set_size(expected_size) + try: + http_get( + url=hf_hub_url( + repo_id=resolve_remote_path.repo_id, + revision=resolve_remote_path.revision, + filename=resolve_remote_path.path_in_repo, + repo_type=resolve_remote_path.repo_type, + endpoint=self.endpoint, + ), + temp_file=outfile, + displayed_filename=rpath, + expected_size=expected_size, + resume_size=0, + headers=self._api._build_hf_headers(), + _tqdm_bar=callback.tqdm if isinstance(callback, TqdmCallback) else None, + ) + outfile.seek(initial_pos) + finally: + # Close file only if we opened it ourselves + if close_file: + outfile.close() + + @property + def transaction(self): + """A context within which files are committed together upon exit + + Requires the file class to implement `.commit()` and `.discard()` + for the normal and exception cases. 
+ """ + # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L231 + # See https://github.com/huggingface/huggingface_hub/issues/1733 + raise NotImplementedError("Transactional commits are not supported.") + + def start_transaction(self): + """Begin write transaction for deferring files, non-context version""" + # Taken from https://github.com/fsspec/filesystem_spec/blob/3fbb6fee33b46cccb015607630843dea049d3243/fsspec/spec.py#L241 + # See https://github.com/huggingface/huggingface_hub/issues/1733 + raise NotImplementedError("Transactional commits are not supported.") + + +class HfFileSystemFile(fsspec.spec.AbstractBufferedFile): + def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs): + try: + self.resolved_path = fs.resolve_path(path, revision=revision) + except FileNotFoundError as e: + if "w" in kwargs.get("mode", ""): + raise FileNotFoundError( + f"{e}.\nMake sure the repository and revision exist before writing data." + ) from e + raise + # avoid an unnecessary .info() call with expensive expand_info=True to instantiate .details + if kwargs.get("mode", "rb") == "rb": + self.details = fs.info(self.resolved_path.unresolve(), expand_info=False) + super().__init__(fs, self.resolved_path.unresolve(), **kwargs) + self.fs: HfFileSystem + + def __del__(self): + if not hasattr(self, "resolved_path"): + # Means that the constructor failed. Nothing to do. 
+ return + return super().__del__() + + def _fetch_range(self, start: int, end: int) -> bytes: + headers = { + "range": f"bytes={start}-{end - 1}", + **self.fs._api._build_hf_headers(), + } + url = hf_hub_url( + repo_id=self.resolved_path.repo_id, + revision=self.resolved_path.revision, + filename=self.resolved_path.path_in_repo, + repo_type=self.resolved_path.repo_type, + endpoint=self.fs.endpoint, + ) + r = http_backoff( + "GET", + url, + headers=headers, + retry_on_status_codes=(502, 503, 504), + timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, + ) + hf_raise_for_status(r) + return r.content + + def _initiate_upload(self) -> None: + self.temp_file = tempfile.NamedTemporaryFile(prefix="hffs-", delete=False) + + def _upload_chunk(self, final: bool = False) -> None: + self.buffer.seek(0) + block = self.buffer.read() + self.temp_file.write(block) + if final: + self.temp_file.close() + self.fs._api.upload_file( + path_or_fileobj=self.temp_file.name, + path_in_repo=self.resolved_path.path_in_repo, + repo_id=self.resolved_path.repo_id, + token=self.fs.token, + repo_type=self.resolved_path.repo_type, + revision=self.resolved_path.revision, + commit_message=self.kwargs.get("commit_message"), + commit_description=self.kwargs.get("commit_description"), + ) + os.remove(self.temp_file.name) + self.fs.invalidate_cache( + path=self.resolved_path.unresolve(), + ) + + def read(self, length=-1): + """Read remote file. + + If `length` is not provided or is -1, the entire file is downloaded and read. On POSIX systems and if + `hf_transfer` is not enabled, the file is loaded in memory directly. Otherwise, the file is downloaded to a + temporary file and read from there. 
+ """ + if self.mode == "rb" and (length is None or length == -1) and self.loc == 0: + with self.fs.open(self.path, "rb", block_size=0) as f: # block_size=0 enables fast streaming + return f.read() + return super().read(length) + + def url(self) -> str: + return self.fs.url(self.path) + + +class HfFileSystemStreamFile(fsspec.spec.AbstractBufferedFile): + def __init__( + self, + fs: HfFileSystem, + path: str, + mode: str = "rb", + revision: Optional[str] = None, + block_size: int = 0, + cache_type: str = "none", + **kwargs, + ): + if block_size != 0: + raise ValueError(f"HfFileSystemStreamFile only supports block_size=0 but got {block_size}") + if cache_type != "none": + raise ValueError(f"HfFileSystemStreamFile only supports cache_type='none' but got {cache_type}") + if "w" in mode: + raise ValueError(f"HfFileSystemStreamFile only supports reading but got mode='{mode}'") + try: + self.resolved_path = fs.resolve_path(path, revision=revision) + except FileNotFoundError as e: + if "w" in kwargs.get("mode", ""): + raise FileNotFoundError( + f"{e}.\nMake sure the repository and revision exist before writing data." 
+ ) from e + # avoid an unnecessary .info() call to instantiate .details + self.details = {"name": self.resolved_path.unresolve(), "size": None} + super().__init__( + fs, self.resolved_path.unresolve(), mode=mode, block_size=block_size, cache_type=cache_type, **kwargs + ) + self.response: Optional[Response] = None + self.fs: HfFileSystem + + def seek(self, loc: int, whence: int = 0): + if loc == 0 and whence == 1: + return + if loc == self.loc and whence == 0: + return + raise ValueError("Cannot seek streaming HF file") + + def read(self, length: int = -1): + read_args = (length,) if length >= 0 else () + if self.response is None or self.response.raw.isclosed(): + url = hf_hub_url( + repo_id=self.resolved_path.repo_id, + revision=self.resolved_path.revision, + filename=self.resolved_path.path_in_repo, + repo_type=self.resolved_path.repo_type, + endpoint=self.fs.endpoint, + ) + self.response = http_backoff( + "GET", + url, + headers=self.fs._api._build_hf_headers(), + retry_on_status_codes=(502, 503, 504), + stream=True, + timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, + ) + hf_raise_for_status(self.response) + try: + out = self.response.raw.read(*read_args) + except Exception: + self.response.close() + + # Retry by recreating the connection + url = hf_hub_url( + repo_id=self.resolved_path.repo_id, + revision=self.resolved_path.revision, + filename=self.resolved_path.path_in_repo, + repo_type=self.resolved_path.repo_type, + endpoint=self.fs.endpoint, + ) + self.response = http_backoff( + "GET", + url, + headers={"Range": "bytes=%d-" % self.loc, **self.fs._api._build_hf_headers()}, + retry_on_status_codes=(502, 503, 504), + stream=True, + timeout=constants.HF_HUB_DOWNLOAD_TIMEOUT, + ) + hf_raise_for_status(self.response) + try: + out = self.response.raw.read(*read_args) + except Exception: + self.response.close() + raise + self.loc += len(out) + return out + + def url(self) -> str: + return self.fs.url(self.path) + + def __del__(self): + if not hasattr(self, 
"resolved_path"): + # Means that the constructor failed. Nothing to do. + return + return super().__del__() + + def __reduce__(self): + return reopen, (self.fs, self.path, self.mode, self.blocksize, self.cache.name) + + +def safe_revision(revision: str) -> str: + return revision if SPECIAL_REFS_REVISION_REGEX.match(revision) else safe_quote(revision) + + +def safe_quote(s: str) -> str: + return quote(s, safe="") + + +def _raise_file_not_found(path: str, err: Optional[Exception]) -> NoReturn: + msg = path + if isinstance(err, RepositoryNotFoundError): + msg = f"{path} (repository not found)" + elif isinstance(err, RevisionNotFoundError): + msg = f"{path} (revision not found)" + elif isinstance(err, HFValidationError): + msg = f"{path} (invalid repository id)" + raise FileNotFoundError(msg) from err + + +def reopen(fs: HfFileSystem, path: str, mode: str, block_size: int, cache_type: str): + return fs.open(path, mode=mode, block_size=block_size, cache_type=cache_type) + + +# Add docstrings to the methods of HfFileSystem from fsspec.AbstractFileSystem +for name, function in inspect.getmembers(HfFileSystem, predicate=inspect.isfunction): + parent = getattr(fsspec.AbstractFileSystem, name, None) + if parent is not None and parent.__doc__ is not None: + parent_doc = parent.__doc__ + parent_doc = parent_doc.replace("Parameters\n ----------\n", "Args:\n") + parent_doc = parent_doc.replace("Returns\n -------\n", "Return:\n") + function.__doc__ = ( + ( + "\n_Docstring taken from " + f"[fsspec documentation](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.{name})._" + ) + + "\n\n" + + parent_doc + ) diff --git a/huggingface_hub/hub_mixin.py b/huggingface_hub/hub_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..b330bccabe5f643e70c6b770ba2c0d1f457565d4 --- /dev/null +++ b/huggingface_hub/hub_mixin.py @@ -0,0 +1,852 @@ +import inspect +import json +import os +import warnings +from dataclasses import asdict, 
dataclass, is_dataclass +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +from . import constants +from .errors import EntryNotFoundError, HfHubHTTPError +from .file_download import hf_hub_download +from .hf_api import HfApi +from .repocard import ModelCard, ModelCardData +from .utils import ( + SoftTemporaryDirectory, + is_jsonable, + is_safetensors_available, + is_simple_optional_type, + is_torch_available, + logging, + unwrap_simple_optional_type, + validate_hf_hub_args, +) + + +if TYPE_CHECKING: + pass + +if is_torch_available(): + import torch # type: ignore + +if is_safetensors_available(): + import packaging.version + import safetensors + from safetensors.torch import load_model as load_model_as_safetensor + from safetensors.torch import save_model as save_model_as_safetensor + + +logger = logging.get_logger(__name__) + +# Generic variable that is either ModelHubMixin or a subclass thereof +T = TypeVar("T", bound="ModelHubMixin") +# Generic variable to represent an args type +ARGS_T = TypeVar("ARGS_T") +ENCODER_T = Callable[[ARGS_T], Any] +DECODER_T = Callable[[Any], ARGS_T] +CODER_T = Tuple[ENCODER_T, DECODER_T] + + +DEFAULT_MODEL_CARD = """ +--- +# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 +# Doc / guide: https://huggingface.co/docs/hub/model-cards +{{ card_data }} +--- + +This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration: +- Library: {{ repo_url | default("[More Information Needed]", true) }} +- Docs: {{ docs_url | default("[More Information Needed]", true) }} +""" + + +@dataclass +class MixinInfo: + model_card_template: str + model_card_data: ModelCardData + repo_url: Optional[str] = None + docs_url: Optional[str] = None + + +class 
ModelHubMixin: + """ + A generic mixin to integrate ANY machine learning framework with the Hub. + + To integrate your framework, your model class must inherit from this class. Custom logic for saving/loading models + have to be overwritten in [`_from_pretrained`] and [`_save_pretrained`]. [`PyTorchModelHubMixin`] is a good example + of mixin integration with the Hub. Check out our [integration guide](../guides/integrations) for more instructions. + + When inheriting from [`ModelHubMixin`], you can define class-level attributes. These attributes are not passed to + `__init__` but to the class definition itself. This is useful to define metadata about the library integrating + [`ModelHubMixin`]. + + For more details on how to integrate the mixin with your library, checkout the [integration guide](../guides/integrations). + + Args: + repo_url (`str`, *optional*): + URL of the library repository. Used to generate model card. + docs_url (`str`, *optional*): + URL of the library documentation. Used to generate model card. + model_card_template (`str`, *optional*): + Template of the model card. Used to generate model card. Defaults to a generic template. + language (`str` or `List[str]`, *optional*): + Language supported by the library. Used to generate model card. + library_name (`str`, *optional*): + Name of the library integrating ModelHubMixin. Used to generate model card. + license (`str`, *optional*): + License of the library integrating ModelHubMixin. Used to generate model card. + E.g: "apache-2.0" + license_name (`str`, *optional*): + Name of the library integrating ModelHubMixin. Used to generate model card. + Only used if `license` is set to `other`. + E.g: "coqui-public-model-license". + license_link (`str`, *optional*): + URL to the license of the library integrating ModelHubMixin. Used to generate model card. + Only used if `license` is set to `other` and `license_name` is set. + E.g: "https://coqui.ai/cpml". 
+ pipeline_tag (`str`, *optional*): + Tag of the pipeline. Used to generate model card. E.g. "text-classification". + tags (`List[str]`, *optional*): + Tags to be added to the model card. Used to generate model card. E.g. ["x-custom-tag", "arxiv:2304.12244"] + coders (`Dict[Type, Tuple[Callable, Callable]]`, *optional*): + Dictionary of custom types and their encoders/decoders. Used to encode/decode arguments that are not + jsonable by default. E.g dataclasses, argparse.Namespace, OmegaConf, etc. + + Example: + + ```python + >>> from huggingface_hub import ModelHubMixin + + # Inherit from ModelHubMixin + >>> class MyCustomModel( + ... ModelHubMixin, + ... library_name="my-library", + ... tags=["x-custom-tag", "arxiv:2304.12244"], + ... repo_url="https://github.com/huggingface/my-cool-library", + ... docs_url="https://huggingface.co/docs/my-cool-library", + ... # ^ optional metadata to generate model card + ... ): + ... def __init__(self, size: int = 512, device: str = "cpu"): + ... # define how to initialize your model + ... super().__init__() + ... ... + ... + ... def _save_pretrained(self, save_directory: Path) -> None: + ... # define how to serialize your model + ... ... + ... + ... @classmethod + ... def from_pretrained( + ... cls: Type[T], + ... pretrained_model_name_or_path: Union[str, Path], + ... *, + ... force_download: bool = False, + ... resume_download: Optional[bool] = None, + ... proxies: Optional[Dict] = None, + ... token: Optional[Union[str, bool]] = None, + ... cache_dir: Optional[Union[str, Path]] = None, + ... local_files_only: bool = False, + ... revision: Optional[str] = None, + ... **model_kwargs, + ... ) -> T: + ... # define how to deserialize your model + ... ... 
+ + >>> model = MyCustomModel(size=256, device="gpu") + + # Save model weights to local directory + >>> model.save_pretrained("my-awesome-model") + + # Push model weights to the Hub + >>> model.push_to_hub("my-awesome-model") + + # Download and initialize weights from the Hub + >>> reloaded_model = MyCustomModel.from_pretrained("username/my-awesome-model") + >>> reloaded_model.size + 256 + + # Model card has been correctly populated + >>> from huggingface_hub import ModelCard + >>> card = ModelCard.load("username/my-awesome-model") + >>> card.data.tags + ["x-custom-tag", "pytorch_model_hub_mixin", "model_hub_mixin"] + >>> card.data.library_name + "my-library" + ``` + """ + + _hub_mixin_config: Optional[Union[dict, "DataclassInstance"]] = None + # ^ optional config attribute automatically set in `from_pretrained` + _hub_mixin_info: MixinInfo + # ^ information about the library integrating ModelHubMixin (used to generate model card) + _hub_mixin_inject_config: bool # whether `_from_pretrained` expects `config` or not + _hub_mixin_init_parameters: Dict[str, inspect.Parameter] # __init__ parameters + _hub_mixin_jsonable_default_values: Dict[str, Any] # default values for __init__ parameters + _hub_mixin_jsonable_custom_types: Tuple[Type, ...] # custom types that can be encoded/decoded + _hub_mixin_coders: Dict[Type, CODER_T] # encoders/decoders for custom types + # ^ internal values to handle config + + def __init_subclass__( + cls, + *, + # Generic info for model card + repo_url: Optional[str] = None, + docs_url: Optional[str] = None, + # Model card template + model_card_template: str = DEFAULT_MODEL_CARD, + # Model card metadata + language: Optional[List[str]] = None, + library_name: Optional[str] = None, + license: Optional[str] = None, + license_name: Optional[str] = None, + license_link: Optional[str] = None, + pipeline_tag: Optional[str] = None, + tags: Optional[List[str]] = None, + # How to encode/decode arguments with custom type into a JSON config? 
+ coders: Optional[ + Dict[Type, CODER_T] + # Key is a type. + # Value is a tuple (encoder, decoder). + # Example: {MyCustomType: (lambda x: x.value, lambda data: MyCustomType(data))} + ] = None, + # Deprecated arguments + languages: Optional[List[str]] = None, + ) -> None: + """Inspect __init__ signature only once when subclassing + handle modelcard.""" + super().__init_subclass__() + + # Will be reused when creating modelcard + tags = tags or [] + tags.append("model_hub_mixin") + + # Initialize MixinInfo if not existent + info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData()) + + # If parent class has a MixinInfo, inherit from it as a copy + if hasattr(cls, "_hub_mixin_info"): + # Inherit model card template from parent class if not explicitly set + if model_card_template == DEFAULT_MODEL_CARD: + info.model_card_template = cls._hub_mixin_info.model_card_template + + # Inherit from parent model card data + info.model_card_data = ModelCardData(**cls._hub_mixin_info.model_card_data.to_dict()) + + # Inherit other info + info.docs_url = cls._hub_mixin_info.docs_url + info.repo_url = cls._hub_mixin_info.repo_url + cls._hub_mixin_info = info + + if languages is not None: + warnings.warn( + "The `languages` argument is deprecated. Use `language` instead. 
This will be removed in `huggingface_hub>=0.27.0`.", + DeprecationWarning, + ) + language = languages + + # Update MixinInfo with metadata + if model_card_template is not None and model_card_template != DEFAULT_MODEL_CARD: + info.model_card_template = model_card_template + if repo_url is not None: + info.repo_url = repo_url + if docs_url is not None: + info.docs_url = docs_url + if language is not None: + info.model_card_data.language = language + if library_name is not None: + info.model_card_data.library_name = library_name + if license is not None: + info.model_card_data.license = license + if license_name is not None: + info.model_card_data.license_name = license_name + if license_link is not None: + info.model_card_data.license_link = license_link + if pipeline_tag is not None: + info.model_card_data.pipeline_tag = pipeline_tag + if tags is not None: + if info.model_card_data.tags is not None: + info.model_card_data.tags.extend(tags) + else: + info.model_card_data.tags = tags + + info.model_card_data.tags = sorted(set(info.model_card_data.tags)) + + # Handle encoders/decoders for args + cls._hub_mixin_coders = coders or {} + cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys()) + + # Inspect __init__ signature to handle config + cls._hub_mixin_init_parameters = dict(inspect.signature(cls.__init__).parameters) + cls._hub_mixin_jsonable_default_values = { + param.name: cls._encode_arg(param.default) + for param in cls._hub_mixin_init_parameters.values() + if param.default is not inspect.Parameter.empty and cls._is_jsonable(param.default) + } + cls._hub_mixin_inject_config = "config" in inspect.signature(cls._from_pretrained).parameters + + def __new__(cls, *args, **kwargs) -> "ModelHubMixin": + """Create a new instance of the class and handle config. + + 3 cases: + - If `self._hub_mixin_config` is already set, do nothing. + - If `config` is passed as a dataclass, set it as `self._hub_mixin_config`. 
+ - Otherwise, build `self._hub_mixin_config` from default values and passed values. + """ + instance = super().__new__(cls) + + # If `config` is already set, return early + if instance._hub_mixin_config is not None: + return instance + + # Infer passed values + passed_values = { + **{ + key: value + for key, value in zip( + # [1:] to skip `self` parameter + list(cls._hub_mixin_init_parameters)[1:], + args, + ) + }, + **kwargs, + } + + # If config passed as dataclass => set it and return early + if is_dataclass(passed_values.get("config")): + instance._hub_mixin_config = passed_values["config"] + return instance + + # Otherwise, build config from default + passed values + init_config = { + # default values + **cls._hub_mixin_jsonable_default_values, + # passed values + **{ + key: cls._encode_arg(value) # Encode custom types as jsonable value + for key, value in passed_values.items() + if instance._is_jsonable(value) # Only if jsonable or we have a custom encoder + }, + } + passed_config = init_config.pop("config", {}) + + # Populate `init_config` with provided config + if isinstance(passed_config, dict): + init_config.update(passed_config) + + # Set `config` attribute and return + if init_config != {}: + instance._hub_mixin_config = init_config + return instance + + @classmethod + def _is_jsonable(cls, value: Any) -> bool: + """Check if a value is JSON serializable.""" + if isinstance(value, cls._hub_mixin_jsonable_custom_types): + return True + return is_jsonable(value) + + @classmethod + def _encode_arg(cls, arg: Any) -> Any: + """Encode an argument into a JSON serializable format.""" + for type_, (encoder, _) in cls._hub_mixin_coders.items(): + if isinstance(arg, type_): + if arg is None: + return None + return encoder(arg) + return arg + + @classmethod + def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]: + """Decode a JSON serializable value into an argument.""" + if is_simple_optional_type(expected_type): + if value is None: + 
return None + expected_type = unwrap_simple_optional_type(expected_type) + # Dataclass => handle it + if is_dataclass(expected_type): + return _load_dataclass(expected_type, value) # type: ignore[return-value] + # Otherwise => check custom decoders + for type_, (_, decoder) in cls._hub_mixin_coders.items(): + if inspect.isclass(expected_type) and issubclass(expected_type, type_): + return decoder(value) + # Otherwise => don't decode + return value + + def save_pretrained( + self, + save_directory: Union[str, Path], + *, + config: Optional[Union[dict, "DataclassInstance"]] = None, + repo_id: Optional[str] = None, + push_to_hub: bool = False, + model_card_kwargs: Optional[Dict[str, Any]] = None, + **push_to_hub_kwargs, + ) -> Optional[str]: + """ + Save weights in local directory. + + Args: + save_directory (`str` or `Path`): + Path to directory in which the model weights and configuration will be saved. + config (`dict` or `DataclassInstance`, *optional*): + Model configuration specified as a key/value dictionary or a dataclass instance. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Huggingface Hub after saving it. + repo_id (`str`, *optional*): + ID of your repository on the Hub. Used only if `push_to_hub=True`. Will default to the folder name if + not provided. + model_card_kwargs (`Dict[str, Any]`, *optional*): + Additional arguments passed to the model card template to customize the model card. + push_to_hub_kwargs: + Additional key word arguments passed along to the [`~ModelHubMixin.push_to_hub`] method. + Returns: + `str` or `None`: url of the commit on the Hub if `push_to_hub=True`, `None` otherwise. + """ + save_directory = Path(save_directory) + save_directory.mkdir(parents=True, exist_ok=True) + + # Remove config.json if already exists. After `_save_pretrained` we don't want to overwrite config.json + # as it might have been saved by the custom `_save_pretrained` already. 
However we do want to overwrite + # an existing config.json if it was not saved by `_save_pretrained`. + config_path = save_directory / constants.CONFIG_NAME + config_path.unlink(missing_ok=True) + + # save model weights/files (framework-specific) + self._save_pretrained(save_directory) + + # save config (if provided and if not serialized yet in `_save_pretrained`) + if config is None: + config = self._hub_mixin_config + if config is not None: + if is_dataclass(config): + config = asdict(config) # type: ignore[arg-type] + if not config_path.exists(): + config_str = json.dumps(config, sort_keys=True, indent=2) + config_path.write_text(config_str) + + # save model card + model_card_path = save_directory / "README.md" + model_card_kwargs = model_card_kwargs if model_card_kwargs is not None else {} + if not model_card_path.exists(): # do not overwrite if already exists + self.generate_model_card(**model_card_kwargs).save(save_directory / "README.md") + + # push to the Hub if required + if push_to_hub: + kwargs = push_to_hub_kwargs.copy() # soft-copy to avoid mutating input + if config is not None: # kwarg for `push_to_hub` + kwargs["config"] = config + if repo_id is None: + repo_id = save_directory.name # Defaults to `save_directory` name + return self.push_to_hub(repo_id=repo_id, model_card_kwargs=model_card_kwargs, **kwargs) + return None + + def _save_pretrained(self, save_directory: Path) -> None: + """ + Overwrite this method in subclass to define how to save your model. + Check out our [integration guide](../guides/integrations) for instructions. + + Args: + save_directory (`str` or `Path`): + Path to directory in which the model weights and configuration will be saved. 
+ """ + raise NotImplementedError + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls: Type[T], + pretrained_model_name_or_path: Union[str, Path], + *, + force_download: bool = False, + resume_download: Optional[bool] = None, + proxies: Optional[Dict] = None, + token: Optional[Union[str, bool]] = None, + cache_dir: Optional[Union[str, Path]] = None, + local_files_only: bool = False, + revision: Optional[str] = None, + **model_kwargs, + ) -> T: + """ + Download a model from the Huggingface Hub and instantiate it. + + Args: + pretrained_model_name_or_path (`str`, `Path`): + - Either the `model_id` (string) of a model hosted on the Hub, e.g. `bigscience/bloom`. + - Or a path to a `directory` containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `../path/to/my_model_directory/`. + revision (`str`, *optional*): + Revision of the model on the Hub. Can be a branch name, a git tag or any commit id. + Defaults to the latest commit on `main` branch. + force_download (`bool`, *optional*, defaults to `False`): + Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding + the existing cache. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on every request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. By default, it will use the token + cached when running `huggingface-cli login`. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the local cached file if it exists. + model_kwargs (`Dict`, *optional*): + Additional kwargs to pass to the model during initialization. 
+ """ + model_id = str(pretrained_model_name_or_path) + config_file: Optional[str] = None + if os.path.isdir(model_id): + if constants.CONFIG_NAME in os.listdir(model_id): + config_file = os.path.join(model_id, constants.CONFIG_NAME) + else: + logger.warning(f"{constants.CONFIG_NAME} not found in {Path(model_id).resolve()}") + else: + try: + config_file = hf_hub_download( + repo_id=model_id, + filename=constants.CONFIG_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + except HfHubHTTPError as e: + logger.info(f"{constants.CONFIG_NAME} not found on the HuggingFace Hub: {str(e)}") + + # Read config + config = None + if config_file is not None: + with open(config_file, "r", encoding="utf-8") as f: + config = json.load(f) + + # Decode custom types in config + for key, value in config.items(): + if key in cls._hub_mixin_init_parameters: + expected_type = cls._hub_mixin_init_parameters[key].annotation + if expected_type is not inspect.Parameter.empty: + config[key] = cls._decode_arg(expected_type, value) + + # Populate model_kwargs from config + for param in cls._hub_mixin_init_parameters.values(): + if param.name not in model_kwargs and param.name in config: + model_kwargs[param.name] = config[param.name] + + # Check if `config` argument was passed at init + if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs: + # Decode `config` argument if it was passed + config_annotation = cls._hub_mixin_init_parameters["config"].annotation + config = cls._decode_arg(config_annotation, config) + + # Forward config to model initialization + model_kwargs["config"] = config + + # Inject config if `**kwargs` are expected + if is_dataclass(cls): + for key in cls.__dataclass_fields__: + if key not in model_kwargs and key in config: + model_kwargs[key] = config[key] + elif any(param.kind == inspect.Parameter.VAR_KEYWORD 
for param in cls._hub_mixin_init_parameters.values()): + for key, value in config.items(): + if key not in model_kwargs: + model_kwargs[key] = value + + # Finally, also inject if `_from_pretrained` expects it + if cls._hub_mixin_inject_config and "config" not in model_kwargs: + model_kwargs["config"] = config + + instance = cls._from_pretrained( + model_id=str(model_id), + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + **model_kwargs, + ) + + # Implicitly set the config as instance attribute if not already set by the class + # This way `config` will be available when calling `save_pretrained` or `push_to_hub`. + if config is not None and (getattr(instance, "_hub_mixin_config", None) in (None, {})): + instance._hub_mixin_config = config + + return instance + + @classmethod + def _from_pretrained( + cls: Type[T], + *, + model_id: str, + revision: Optional[str], + cache_dir: Optional[Union[str, Path]], + force_download: bool, + proxies: Optional[Dict], + resume_download: Optional[bool], + local_files_only: bool, + token: Optional[Union[str, bool]], + **model_kwargs, + ) -> T: + """Overwrite this method in subclass to define how to load your model from pretrained. + + Use [`hf_hub_download`] or [`snapshot_download`] to download files from the Hub before loading them. Most + args taken as input can be directly passed to those 2 methods. If needed, you can add more arguments to this + method using "model_kwargs". For example [`PyTorchModelHubMixin._from_pretrained`] takes as input a `map_location` + parameter to set on which device the model should be loaded. + + Check out our [integration guide](../guides/integrations) for more instructions. + + Args: + model_id (`str`): + ID of the model to load from the Huggingface Hub (e.g. `bigscience/bloom`). + revision (`str`, *optional*): + Revision of the model on the Hub. 
Can be a branch name, a git tag or any commit id. Defaults to the + latest commit on `main` branch. + force_download (`bool`, *optional*, defaults to `False`): + Whether to force (re-)downloading the model weights and configuration files from the Hub, overriding + the existing cache. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint (e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`). + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. By default, it will use the token + cached when running `huggingface-cli login`. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the local cached file if it exists. + model_kwargs: + Additional keyword arguments passed along to the [`~ModelHubMixin._from_pretrained`] method. + """ + raise NotImplementedError + + @validate_hf_hub_args + def push_to_hub( + self, + repo_id: str, + *, + config: Optional[Union[dict, "DataclassInstance"]] = None, + commit_message: str = "Push model using huggingface_hub.", + private: bool = False, + token: Optional[str] = None, + branch: Optional[str] = None, + create_pr: Optional[bool] = None, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + delete_patterns: Optional[Union[List[str], str]] = None, + model_card_kwargs: Optional[Dict[str, Any]] = None, + ) -> str: + """ + Upload model checkpoint to the Hub. + + Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use + `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more + details. + + Args: + repo_id (`str`): + ID of the repository to push to (example: `"username/my-model"`). 
+ config (`dict` or `DataclassInstance`, *optional*): + Model configuration specified as a key/value dictionary or a dataclass instance. + commit_message (`str`, *optional*): + Message to commit while pushing. + private (`bool`, *optional*, defaults to `False`): + Whether the repository created should be private. + token (`str`, *optional*): + The token to use as HTTP bearer authorization for remote files. By default, it will use the token + cached when running `huggingface-cli login`. + branch (`str`, *optional*): + The git branch on which to push the model. This defaults to `"main"`. + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request from `branch` with that commit. Defaults to `False`. + allow_patterns (`List[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are pushed. + ignore_patterns (`List[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not pushed. + delete_patterns (`List[str]` or `str`, *optional*): + If provided, remote files matching any of the patterns will be deleted from the repo. + model_card_kwargs (`Dict[str, Any]`, *optional*): + Additional arguments passed to the model card template to customize the model card. + + Returns: + The url of the commit of your model in the given repository. 
+ """ + api = HfApi(token=token) + repo_id = api.create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id + + # Push the files to the repo in a single commit + with SoftTemporaryDirectory() as tmp: + saved_path = Path(tmp) / repo_id + self.save_pretrained(saved_path, config=config, model_card_kwargs=model_card_kwargs) + return api.upload_folder( + repo_id=repo_id, + repo_type="model", + folder_path=saved_path, + commit_message=commit_message, + revision=branch, + create_pr=create_pr, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + delete_patterns=delete_patterns, + ) + + def generate_model_card(self, *args, **kwargs) -> ModelCard: + card = ModelCard.from_template( + card_data=self._hub_mixin_info.model_card_data, + template_str=self._hub_mixin_info.model_card_template, + repo_url=self._hub_mixin_info.repo_url, + docs_url=self._hub_mixin_info.docs_url, + **kwargs, + ) + return card + + +class PyTorchModelHubMixin(ModelHubMixin): + """ + Implementation of [`ModelHubMixin`] to provide model Hub upload/download capabilities to PyTorch models. The model + is set in evaluation mode by default using `model.eval()` (dropout modules are deactivated). To train the model, + you should first set it back in training mode with `model.train()`. + + See [`ModelHubMixin`] for more details on how to use the mixin. + + Example: + + ```python + >>> import torch + >>> import torch.nn as nn + >>> from huggingface_hub import PyTorchModelHubMixin + + >>> class MyModel( + ... nn.Module, + ... PyTorchModelHubMixin, + ... library_name="keras-nlp", + ... repo_url="https://github.com/keras-team/keras-nlp", + ... docs_url="https://keras.io/keras_nlp/", + ... # ^ optional metadata to generate model card + ... ): + ... def __init__(self, hidden_size: int = 512, vocab_size: int = 30000, output_size: int = 4): + ... super().__init__() + ... self.param = nn.Parameter(torch.rand(hidden_size, vocab_size)) + ... self.linear = nn.Linear(output_size, vocab_size) + + ... 
def forward(self, x): + ... return self.linear(x + self.param) + >>> model = MyModel(hidden_size=256) + + # Save model weights to local directory + >>> model.save_pretrained("my-awesome-model") + + # Push model weights to the Hub + >>> model.push_to_hub("my-awesome-model") + + # Download and initialize weights from the Hub + >>> model = MyModel.from_pretrained("username/my-awesome-model") + >>> model.hidden_size + 256 + ``` + """ + + def __init_subclass__(cls, *args, tags: Optional[List[str]] = None, **kwargs) -> None: + tags = tags or [] + tags.append("pytorch_model_hub_mixin") + kwargs["tags"] = tags + return super().__init_subclass__(*args, **kwargs) + + def _save_pretrained(self, save_directory: Path) -> None: + """Save weights from a Pytorch model to a local directory.""" + model_to_save = self.module if hasattr(self, "module") else self # type: ignore + save_model_as_safetensor(model_to_save, str(save_directory / constants.SAFETENSORS_SINGLE_FILE)) + + @classmethod + def _from_pretrained( + cls, + *, + model_id: str, + revision: Optional[str], + cache_dir: Optional[Union[str, Path]], + force_download: bool, + proxies: Optional[Dict], + resume_download: Optional[bool], + local_files_only: bool, + token: Union[str, bool, None], + map_location: str = "cpu", + strict: bool = False, + **model_kwargs, + ): + """Load Pytorch pretrained weights and return the loaded model.""" + model = cls(**model_kwargs) + if os.path.isdir(model_id): + print("Loading weights from local directory") + model_file = os.path.join(model_id, constants.SAFETENSORS_SINGLE_FILE) + return cls._load_as_safetensor(model, model_file, map_location, strict) + else: + try: + model_file = hf_hub_download( + repo_id=model_id, + filename=constants.SAFETENSORS_SINGLE_FILE, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + return cls._load_as_safetensor(model, 
model_file, map_location, strict) + except EntryNotFoundError: + model_file = hf_hub_download( + repo_id=model_id, + filename=constants.PYTORCH_WEIGHTS_NAME, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + token=token, + local_files_only=local_files_only, + ) + return cls._load_as_pickle(model, model_file, map_location, strict) + + @classmethod + def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: + state_dict = torch.load(model_file, map_location=torch.device(map_location), weights_only=True) + model.load_state_dict(state_dict, strict=strict) # type: ignore + model.eval() # type: ignore + return model + + @classmethod + def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: + if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"): # type: ignore [attr-defined] + load_model_as_safetensor(model, model_file, strict=strict) # type: ignore [arg-type] + if map_location != "cpu": + logger.warning( + "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors." + " This means that the model is loaded on 'cpu' first and then copied to the device." + " This leads to a slower loading time." + " Please update safetensors to version 0.4.3 or above for improved performance." + ) + model.to(map_location) # type: ignore [attr-defined] + else: + safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) # type: ignore [arg-type] + return model + + +def _load_dataclass(datacls: Type["DataclassInstance"], data: dict) -> "DataclassInstance": + """Load a dataclass instance from a dictionary. + + Fields not expected by the dataclass are ignored. 
+ """ + return datacls(**{k: v for k, v in data.items() if k in datacls.__dataclass_fields__}) diff --git a/huggingface_hub/inference/__init__.py b/huggingface_hub/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/huggingface_hub/inference/_client.py b/huggingface_hub/inference/_client.py new file mode 100644 index 0000000000000000000000000000000000000000..f6421766750bd58170b18d3ac6703d1ce20e4ff6 --- /dev/null +++ b/huggingface_hub/inference/_client.py @@ -0,0 +1,2789 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Related resources: +# https://huggingface.co/tasks +# https://huggingface.co/docs/huggingface.js/inference/README +# https://github.com/huggingface/huggingface.js/tree/main/packages/inference/src +# https://github.com/huggingface/text-generation-inference/tree/main/clients/python +# https://github.com/huggingface/text-generation-inference/blob/main/clients/python/text_generation/client.py +# https://huggingface.slack.com/archives/C03E4DQ9LAJ/p1680169099087869 +# https://github.com/huggingface/unity-api#tasks +# +# Some TODO: +# - add all tasks +# +# NOTE: the philosophy of this client is "let's make it as easy as possible to use it, even if less optimized". Some +# examples of how it translates: +# - Timeout / Server unavailable is handled by the client in a single "timeout" parameter. 
+# - Files can be provided as bytes, file paths, or URLs and the client will try to "guess" the type. +# - Images are parsed as PIL.Image for easier manipulation. +# - Provides a "recommended model" for each task => suboptimal but user-wise quicker to get a first script running. +# - Only the main parameters are publicly exposed. Power users can always read the docs for more options. +import base64 +import logging +import re +import time +import warnings +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Literal, + Optional, + Union, + overload, +) + +from requests import HTTPError +from requests.structures import CaseInsensitiveDict + +from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS +from huggingface_hub.errors import BadRequestError, InferenceTimeoutError +from huggingface_hub.inference._common import ( + TASKS_EXPECTING_IMAGES, + ContentT, + ModelStatus, + _b64_encode, + _b64_to_image, + _bytes_to_dict, + _bytes_to_image, + _bytes_to_list, + _fetch_recommended_models, + _get_unsupported_text_generation_kwargs, + _import_numpy, + _open_as_binary, + _set_unsupported_text_generation_kwargs, + _stream_chat_completion_response, + _stream_text_generation_response, + raise_text_generation_error, +) +from huggingface_hub.inference._generated.types import ( + AudioClassificationOutputElement, + AudioToAudioOutputElement, + AutomaticSpeechRecognitionOutput, + ChatCompletionInputGrammarType, + ChatCompletionInputTool, + ChatCompletionInputToolTypeClass, + ChatCompletionOutput, + ChatCompletionStreamOutput, + DocumentQuestionAnsweringOutputElement, + FillMaskOutputElement, + ImageClassificationOutputElement, + ImageSegmentationOutputElement, + ImageToTextOutput, + ObjectDetectionOutputElement, + QuestionAnsweringOutputElement, + SummarizationOutput, + TableQuestionAnsweringOutputElement, + TextClassificationOutputElement, + TextGenerationInputGrammarType, + TextGenerationOutput, + 
TextGenerationStreamOutput, + TokenClassificationOutputElement, + TranslationOutput, + VisualQuestionAnsweringOutputElement, + ZeroShotClassificationOutputElement, + ZeroShotImageClassificationOutputElement, +) +from huggingface_hub.utils import ( + build_hf_headers, + get_session, + hf_raise_for_status, +) +from huggingface_hub.utils._deprecation import _deprecate_positional_args + + +if TYPE_CHECKING: + import numpy as np + from PIL.Image import Image + +logger = logging.getLogger(__name__) + + +MODEL_KWARGS_NOT_USED_REGEX = re.compile(r"The following `model_kwargs` are not used by the model: \[(.*?)\]") + + +class InferenceClient: + """ + Initialize a new Inference Client. + + [`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used + seamlessly with either the (free) Inference API or self-hosted Inference Endpoints. + + Args: + model (`str`, `optional`): + The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct` + or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is + automatically selected for the task. + Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2 + arguments are mutually exclusive. If using `base_url` for chat completion, the `/chat/completions` suffix + path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) + documentation for details). When passing a URL as `model`, the client will not append any suffix path to it. + token (`str` or `bool`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + Pass `token=False` if you don't want to send your token to the server. + Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. 
Those 2 + arguments are mutually exclusive and have the exact same behavior. + timeout (`float`, `optional`): + The maximum number of seconds to wait for a response from the server. Loading a new model in Inference + API can take up to several minutes. Defaults to None, meaning it will loop until the server is available. + headers (`Dict[str, str]`, `optional`): + Additional headers to send to the server. By default only the authorization and user-agent headers are sent. + Values in this dictionary will override the default values. + cookies (`Dict[str, str]`, `optional`): + Additional cookies to send to the server. + proxies (`Any`, `optional`): + Proxies to use for the request. + base_url (`str`, `optional`): + Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`] + follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None. + api_key (`str`, `optional`): + Token to use for authentication. This is a duplicated argument from `token` to make [`InferenceClient`] + follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None. + """ + + @_deprecate_positional_args(version="0.26") + def __init__( + self, + model: Optional[str] = None, + *, + token: Union[str, bool, None] = None, + timeout: Optional[float] = None, + headers: Optional[Dict[str, str]] = None, + cookies: Optional[Dict[str, str]] = None, + proxies: Optional[Any] = None, + # OpenAI compatibility + base_url: Optional[str] = None, + api_key: Optional[str] = None, + ) -> None: + if model is not None and base_url is not None: + raise ValueError( + "Received both `model` and `base_url` arguments. Please provide only one of them." + " `base_url` is an alias for `model` to make the API compatible with OpenAI's client." + " If using `base_url` for chat completion, the `/chat/completions` suffix path will be appended to the base url." 
+ " When passing a URL as `model`, the client will not append any suffix path to it." + ) + if token is not None and api_key is not None: + raise ValueError( + "Received both `token` and `api_key` arguments. Please provide only one of them." + " `api_key` is an alias for `token` to make the API compatible with OpenAI's client." + " It has the exact same behavior as `token`." + ) + + self.model: Optional[str] = model + self.token: Union[str, bool, None] = token if token is not None else api_key + self.headers = CaseInsensitiveDict(build_hf_headers(token=self.token)) # 'authorization' + 'user-agent' + if headers is not None: + self.headers.update(headers) + self.cookies = cookies + self.timeout = timeout + self.proxies = proxies + + # OpenAI compatibility + self.base_url = base_url + + def __repr__(self): + return f"" + + @overload + def post( # type: ignore[misc] + self, + *, + json: Optional[Union[str, Dict, List]] = None, + data: Optional[ContentT] = None, + model: Optional[str] = None, + task: Optional[str] = None, + stream: Literal[False] = ..., + ) -> bytes: ... + + @overload + def post( # type: ignore[misc] + self, + *, + json: Optional[Union[str, Dict, List]] = None, + data: Optional[ContentT] = None, + model: Optional[str] = None, + task: Optional[str] = None, + stream: Literal[True] = ..., + ) -> Iterable[bytes]: ... + + @overload + def post( + self, + *, + json: Optional[Union[str, Dict, List]] = None, + data: Optional[ContentT] = None, + model: Optional[str] = None, + task: Optional[str] = None, + stream: bool = False, + ) -> Union[bytes, Iterable[bytes]]: ... + + def post( + self, + *, + json: Optional[Union[str, Dict, List]] = None, + data: Optional[ContentT] = None, + model: Optional[str] = None, + task: Optional[str] = None, + stream: bool = False, + ) -> Union[bytes, Iterable[bytes]]: + """ + Make a POST request to the inference server. 
+ + Args: + json (`Union[str, Dict, List]`, *optional*): + The JSON data to send in the request body, specific to each task. Defaults to None. + data (`Union[str, Path, bytes, BinaryIO]`, *optional*): + The content to send in the request body, specific to each task. + It can be raw bytes, a pointer to an opened file, a local file path, + or a URL to an online resource (image, audio file,...). If both `json` and `data` are passed, + `data` will take precedence. At least `json` or `data` must be provided. Defaults to None. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. Will override the model defined at the instance level. Defaults to None. + task (`str`, *optional*): + The task to perform on the inference. All available tasks can be found + [here](https://huggingface.co/tasks). Used only to default to a recommended model if `model` is not + provided. At least `model` or `task` must be provided. Defaults to None. + stream (`bool`, *optional*): + Whether to iterate over streaming APIs. + + Returns: + bytes: The raw bytes returned by the server. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. 
+ """ + url = self._resolve_url(model, task) + + if data is not None and json is not None: + warnings.warn("Ignoring `json` as `data` is passed as binary.") + + # Set Accept header if relevant + headers = self.headers.copy() + if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers: + headers["Accept"] = "image/png" + + t0 = time.time() + timeout = self.timeout + while True: + with _open_as_binary(data) as data_as_binary: + try: + response = get_session().post( + url, + json=json, + data=data_as_binary, + headers=headers, + cookies=self.cookies, + timeout=self.timeout, + stream=stream, + proxies=self.proxies, + ) + except TimeoutError as error: + # Convert any `TimeoutError` to a `InferenceTimeoutError` + raise InferenceTimeoutError(f"Inference call timed out: {url}") from error # type: ignore + + try: + hf_raise_for_status(response) + return response.iter_lines() if stream else response.content + except HTTPError as error: + if error.response.status_code == 422 and task is not None: + error.args = ( + f"{error.args[0]}\nMake sure '{task}' task is supported by the model.", + ) + error.args[1:] + if error.response.status_code == 503: + # If Model is unavailable, either raise a TimeoutError... + if timeout is not None and time.time() - t0 > timeout: + raise InferenceTimeoutError( + f"Model not loaded on the server: {url}. 
Please retry with a higher timeout (current:" + f" {self.timeout}).", + request=error.request, + response=error.response, + ) from error + # ...or wait 1s and retry + logger.info(f"Waiting for model to be loaded on the server: {error}") + time.sleep(1) + if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT): + headers["X-wait-for-model"] = "1" + if timeout is not None: + timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore + continue + raise + + def audio_classification( + self, + audio: ContentT, + *, + model: Optional[str] = None, + ) -> List[AudioClassificationOutputElement]: + """ + Perform audio classification on the provided audio content. + + Args: + audio (Union[str, Path, bytes, BinaryIO]): + The audio content to classify. It can be raw audio bytes, a local audio file, or a URL pointing to an + audio file. + model (`str`, *optional*): + The model to use for audio classification. Can be a model ID hosted on the Hugging Face Hub + or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for + audio classification will be used. + + Returns: + `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.audio_classification("audio.flac") + [ + AudioClassificationOutputElement(score=0.4976358711719513, label='hap'), + AudioClassificationOutputElement(score=0.3677836060523987, label='neu'), + ... 
+ ] + ``` + """ + response = self.post(data=audio, model=model, task="audio-classification") + return AudioClassificationOutputElement.parse_obj_as_list(response) + + def audio_to_audio( + self, + audio: ContentT, + *, + model: Optional[str] = None, + ) -> List[AudioToAudioOutputElement]: + """ + Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation). + + Args: + audio (Union[str, Path, bytes, BinaryIO]): + The audio content for the model. It can be raw audio bytes, a local audio file, or a URL pointing to an + audio file. + model (`str`, *optional*): + The model can be any model which takes an audio file and returns another audio file. Can be a model ID hosted on the Hugging Face Hub + or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for + audio_to_audio will be used. + + Returns: + `List[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob. + + Raises: + `InferenceTimeoutError`: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> audio_output = client.audio_to_audio("audio.flac") + >>> for i, item in enumerate(audio_output): + >>> with open(f"output_{i}.flac", "wb") as f: + f.write(item.blob) + ``` + """ + response = self.post(data=audio, model=model, task="audio-to-audio") + audio_output = AudioToAudioOutputElement.parse_obj_as_list(response) + for item in audio_output: + item.blob = base64.b64decode(item.blob) + return audio_output + + def automatic_speech_recognition( + self, + audio: ContentT, + *, + model: Optional[str] = None, + ) -> AutomaticSpeechRecognitionOutput: + """ + Perform automatic speech recognition (ASR or audio-to-text) on the given audio content. 
+ + Args: + audio (Union[str, Path, bytes, BinaryIO]): + The content to transcribe. It can be raw audio bytes, local audio file, or a URL to an audio file. + model (`str`, *optional*): + The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended model for ASR will be used. + + Returns: + [`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.automatic_speech_recognition("hello_world.flac").text + "hello world" + ``` + """ + response = self.post(data=audio, model=model, task="automatic-speech-recognition") + return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response) + + @overload + def chat_completion( # type: ignore + self, + messages: List[Dict[str, str]], + *, + model: Optional[str] = None, + stream: Literal[False] = False, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + ) -> ChatCompletionOutput: ... 
+ + @overload + def chat_completion( # type: ignore + self, + messages: List[Dict[str, str]], + *, + model: Optional[str] = None, + stream: Literal[True] = True, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + ) -> Iterable[ChatCompletionStreamOutput]: ... + + @overload + def chat_completion( + self, + messages: List[Dict[str, str]], + *, + model: Optional[str] = None, + stream: bool = False, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: ... 
+ + def chat_completion( + self, + messages: List[Dict[str, str]], + *, + model: Optional[str] = None, + stream: bool = False, + # Parameters from ChatCompletionInput (handled manually) + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + ) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]: + """ + A method for completing conversations using a specified language model. + + + + The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client. + Inputs and outputs are strictly the same and using either syntax will yield the same results. + Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility) + for more details about OpenAI's compatibility. + + + + Args: + messages (List[Union[`SystemMessage`, `UserMessage`, `AssistantMessage`]]): + Conversation history consisting of roles and content pairs. + model (`str`, *optional*): + The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used. + See https://huggingface.co/tasks/text-generation for more details. + + If `model` is a model ID, it is passed to the server as the `model` parameter. 
If you want to define a + custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`]. + frequency_penalty (`float`, *optional*): + Penalizes new tokens based on their existing frequency + in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0. + logit_bias (`List[float]`, *optional*): + Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens + (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, + the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, + but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should + result in a ban or exclusive selection of the relevant token. Defaults to None. + logprobs (`bool`, *optional*): + Whether to return log probabilities of the output tokens or not. If true, returns the log + probabilities of each output token returned in the content of message. + max_tokens (`int`, *optional*): + Maximum number of tokens allowed in the response. Defaults to 20. + n (`int`, *optional*): + UNUSED. + presence_penalty (`float`, *optional*): + Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the + text so far, increasing the model's likelihood to talk about new topics. + response_format ([`ChatCompletionInputGrammarType`], *optional*): + Grammar constraints. Can be either a JSONSchema or a regex. + seed (Optional[`int`], *optional*): + Seed for reproducible control flow. Defaults to None. + stop (Optional[`str`], *optional*): + Up to four strings which trigger the end of the response. + Defaults to None. + stream (`bool`, *optional*): + Enable realtime streaming of responses. Defaults to False. + temperature (`float`, *optional*): + Controls randomness of the generations. Lower values ensure + less random completions. 
Range: [0, 2]. Defaults to 1.0. + top_logprobs (`int`, *optional*): + An integer between 0 and 5 specifying the number of most likely tokens to return at each token + position, each with an associated log probability. logprobs must be set to true if this parameter is + used. + top_p (`float`, *optional*): + Fraction of the most likely next words to sample from. + Must be between 0 and 1. Defaults to 1.0. + tool_choice ([`ChatCompletionInputToolTypeClass`] or `str`, *optional*): + The tool to use for the completion. Defaults to "auto". + tool_prompt (`str`, *optional*): + A prompt to be appended before the tools. + tools (List of [`ChatCompletionInputTool`], *optional*): + A list of tools the model may call. Currently, only functions are supported as a tool. Use this to + provide a list of functions the model may generate JSON inputs for. + + Returns: + [`ChatCompletionOutput`] or Iterable of [`ChatCompletionStreamOutput`]: + Generated text returned from the server: + - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default). + - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`]. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. 
+ + Example: + + ```py + >>> from huggingface_hub import InferenceClient + >>> messages = [{"role": "user", "content": "What is the capital of France?"}] + >>> client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct") + >>> client.chat_completion(messages, max_tokens=100) + ChatCompletionOutput( + choices=[ + ChatCompletionOutputComplete( + finish_reason='eos_token', + index=0, + message=ChatCompletionOutputMessage( + role='assistant', + content='The capital of France is Paris.', + name=None, + tool_calls=None + ), + logprobs=None + ) + ], + created=1719907176, + id='', + model='meta-llama/Meta-Llama-3-8B-Instruct', + object='text_completion', + system_fingerprint='2.0.4-sha-f426a33', + usage=ChatCompletionOutputUsage( + completion_tokens=8, + prompt_tokens=17, + total_tokens=25 + ) + ) + ``` + + Example (stream=True): + ```py + >>> from huggingface_hub import InferenceClient + >>> messages = [{"role": "user", "content": "What is the capital of France?"}] + >>> client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct") + >>> for token in client.chat_completion(messages, max_tokens=10, stream=True): + ... print(token) + ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504) + ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504) + (...) 
+ ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504) + ``` + + Example using OpenAI's syntax: + ```py + # instead of `from openai import OpenAI` + from huggingface_hub import InferenceClient + + # instead of `client = OpenAI(...)` + client = InferenceClient( + base_url=..., + api_key=..., + ) + + output = client.chat.completions.create( + model="meta-llama/Meta-Llama-3-8B-Instruct", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Count to 10"}, + ], + stream=True, + max_tokens=1024, + ) + + for chunk in output: + print(chunk.choices[0].delta.content) + ``` + + Example using tools: + ```py + >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") + >>> messages = [ + ... { + ... "role": "system", + ... "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.", + ... }, + ... { + ... "role": "user", + ... "content": "What's the weather like the next 3 days in San Francisco, CA?", + ... }, + ... ] + >>> tools = [ + ... { + ... "type": "function", + ... "function": { + ... "name": "get_current_weather", + ... "description": "Get the current weather", + ... "parameters": { + ... "type": "object", + ... "properties": { + ... "location": { + ... "type": "string", + ... "description": "The city and state, e.g. San Francisco, CA", + ... }, + ... "format": { + ... "type": "string", + ... "enum": ["celsius", "fahrenheit"], + ... "description": "The temperature unit to use. Infer this from the users location.", + ... }, + ... }, + ... "required": ["location", "format"], + ... }, + ... }, + ... }, + ... { + ... "type": "function", + ... "function": { + ... "name": "get_n_day_weather_forecast", + ... "description": "Get an N-day weather forecast", + ... "parameters": { + ... 
"type": "object", + ... "properties": { + ... "location": { + ... "type": "string", + ... "description": "The city and state, e.g. San Francisco, CA", + ... }, + ... "format": { + ... "type": "string", + ... "enum": ["celsius", "fahrenheit"], + ... "description": "The temperature unit to use. Infer this from the users location.", + ... }, + ... "num_days": { + ... "type": "integer", + ... "description": "The number of days to forecast", + ... }, + ... }, + ... "required": ["location", "format", "num_days"], + ... }, + ... }, + ... }, + ... ] + + >>> response = client.chat_completion( + ... model="meta-llama/Meta-Llama-3-70B-Instruct", + ... messages=messages, + ... tools=tools, + ... tool_choice="auto", + ... max_tokens=500, + ... ) + >>> response.choices[0].message.tool_calls[0].function + ChatCompletionOutputFunctionDefinition( + arguments={ + 'location': 'San Francisco, CA', + 'format': 'fahrenheit', + 'num_days': 3 + }, + name='get_n_day_weather_forecast', + description=None + ) + ``` + + Example using response_format: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") + >>> messages = [ + ... { + ... "role": "user", + ... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?", + ... }, + ... ] + >>> response_format = { + ... "type": "json", + ... "value": { + ... "properties": { + ... "location": {"type": "string"}, + ... "activity": {"type": "string"}, + ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5}, + ... "animals": {"type": "array", "items": {"type": "string"}}, + ... }, + ... "required": ["location", "activity", "animals_seen", "animals"], + ... }, + ... } + >>> response = client.chat_completion( + ... messages=messages, + ... response_format=response_format, + ... 
max_tokens=500, + ) + >>> response.choices[0].message.content + '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}' + ``` + """ + model_url = self._resolve_chat_completion_url(model) + + # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing. + # If it's a ID on the Hub => use it. Otherwise, we use a random string. + model_id = model or self.model or "tgi" + if model_id.startswith(("http://", "https://")): + model_id = "tgi" # dummy value + + payload = dict( + model=model_id, + messages=messages, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + temperature=temperature, + tool_choice=tool_choice, + tool_prompt=tool_prompt, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + stream=stream, + ) + payload = {key: value for key, value in payload.items() if value is not None} + data = self.post(model=model_url, json=payload, stream=stream) + + if stream: + return _stream_chat_completion_response(data) # type: ignore[arg-type] + + return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type] + + def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str: + # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently. + # `self.base_url` and `self.model` takes precedence over 'model' argument only in `chat_completion`. 
+ model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation") + + # Resolve URL if it's a model ID + model_url = ( + model_id_or_url + if model_id_or_url.startswith(("http://", "https://")) + else self._resolve_url(model_id_or_url, task="text-generation") + ) + + # Strip trailing / + model_url = model_url.rstrip("/") + + # Append /chat/completions if not already present + if model_url.endswith("/v1"): + model_url += "/chat/completions" + + # Append /v1/chat/completions if not already present + if not model_url.endswith("/chat/completions"): + model_url += "/v1/chat/completions" + + return model_url + + def document_question_answering( + self, + image: ContentT, + question: str, + *, + model: Optional[str] = None, + ) -> List[DocumentQuestionAnsweringOutputElement]: + """ + Answer questions on document images. + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The input image for the context. It can be raw bytes, an image file, or a URL to an online image. + question (`str`): + Question to be answered. + model (`str`, *optional*): + The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used. + Defaults to None. + + Returns: + `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. 
+ + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?") + [DocumentQuestionAnsweringOutputElement(score=0.42515629529953003, answer='us-001', start=16, end=16)] + ``` + """ + payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} + response = self.post(json=payload, model=model, task="document-question-answering") + return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response) + + def feature_extraction( + self, + text: str, + *, + normalize: Optional[bool] = None, + prompt_name: Optional[str] = None, + truncate: Optional[bool] = None, + truncation_direction: Optional[Literal["Left", "Right"]] = None, + model: Optional[str] = None, + ) -> "np.ndarray": + """ + Generate embeddings for a given text. + + Args: + text (`str`): + The text to embed. + model (`str`, *optional*): + The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used. + Defaults to None. + normalize (`bool`, *optional*): + Whether to normalize the embeddings or not. Defaults to None. + Only available on server powered by Text-Embedding-Inference. + prompt_name (`str`, *optional*): + The name of the prompt that should be used by for encoding. If not set, no prompt will be applied. + Must be a key in the `Sentence Transformers` configuration `prompts` dictionary. + For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",...}, + then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?" + because the prompt text will be prepended before any text to encode. 
+ truncate (`bool`, *optional*): + Whether to truncate the embeddings or not. Defaults to None. + Only available on server powered by Text-Embedding-Inference. + truncation_direction (`Literal["Left", "Right"]`, *optional*): + Which side of the input should be truncated when `truncate=True` is passed. + + Returns: + `np.ndarray`: The embedding representing the input text as a float32 numpy array. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.feature_extraction("Hi, who are you?") + array([[ 2.424802 , 2.93384 , 1.1750331 , ..., 1.240499, -0.13776633, -0.7889173 ], + [-0.42943227, -0.6364878 , -1.693462 , ..., 0.41978157, -2.4336355 , 0.6162071 ], + ..., + [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32) + ``` + """ + payload: Dict = {"inputs": text} + if normalize is not None: + payload["normalize"] = normalize + if prompt_name is not None: + payload["prompt_name"] = prompt_name + if truncate is not None: + payload["truncate"] = truncate + if truncation_direction is not None: + payload["truncation_direction"] = truncation_direction + response = self.post(json=payload, model=model, task="feature-extraction") + np = _import_numpy() + return np.array(_bytes_to_dict(response), dtype="float32") + + def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutputElement]: + """ + Fill in a hole with a missing word (token to be precise). + + Args: + text (`str`): + a string to be filled from, must contain the [MASK] token (check model card for exact name of the mask). + model (`str`, *optional*): + The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. 
If not provided, the default recommended fill mask model will be used. + Defaults to None. + + Returns: + `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated + probability, token reference, and completed text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.fill_mask("The goal of life is.") + [ + FillMaskOutputElement(score=0.06897063553333282, token=11098, token_str=' happiness', sequence='The goal of life is happiness.'), + FillMaskOutputElement(score=0.06554922461509705, token=45075, token_str=' immortality', sequence='The goal of life is immortality.') + ] + ``` + """ + response = self.post(json={"inputs": text}, model=model, task="fill-mask") + return FillMaskOutputElement.parse_obj_as_list(response) + + def image_classification( + self, + image: ContentT, + *, + model: Optional[str] = None, + ) -> List[ImageClassificationOutputElement]: + """ + Perform image classification on the given image using the specified model. + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The image to classify. It can be raw bytes, an image file, or a URL to an online image. + model (`str`, *optional*): + The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a + deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used. + + Returns: + `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. 
+ `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") + [ImageClassificationOutputElement(score=0.9779096841812134, label='Blenheim spaniel'), ...] + ``` + """ + response = self.post(data=image, model=model, task="image-classification") + return ImageClassificationOutputElement.parse_obj_as_list(response) + + def image_segmentation( + self, + image: ContentT, + *, + model: Optional[str] = None, + ) -> List[ImageSegmentationOutputElement]: + """ + Perform image segmentation on the given image using the specified model. + + + + You must have `PIL` installed if you want to work with images (`pip install Pillow`). + + + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The image to segment. It can be raw bytes, an image file, or a URL to an online image. + model (`str`, *optional*): + The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a + deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used. + + Returns: + `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.image_segmentation("cat.jpg"): + [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=), ...] 
+ ``` + """ + response = self.post(data=image, model=model, task="image-segmentation") + output = ImageSegmentationOutputElement.parse_obj_as_list(response) + for item in output: + item.mask = _b64_to_image(item.mask) + return output + + def image_to_image( + self, + image: ContentT, + prompt: Optional[str] = None, + *, + negative_prompt: Optional[str] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + model: Optional[str] = None, + **kwargs, + ) -> "Image": + """ + Perform image-to-image translation using a specified model. + + + + You must have `PIL` installed if you want to work with images (`pip install Pillow`). + + + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The input image for translation. It can be raw bytes, an image file, or a URL to an online image. + prompt (`str`, *optional*): + The text prompt to guide the image generation. + negative_prompt (`str`, *optional*): + A negative prompt to guide the translation process. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*): + Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `Image`: The translated image. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. 
+ `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> image = client.image_to_image("cat.jpg", prompt="turn the cat into a tiger") + >>> image.save("tiger.jpg") + ``` + """ + parameters = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "height": height, + "width": width, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + **kwargs, + } + if all(parameter is None for parameter in parameters.values()): + # Either only an image to send => send as raw bytes + data = image + payload: Optional[Dict[str, Any]] = None + else: + # Or an image + some parameters => use base64 encoding + data = None + payload = {"inputs": _b64_encode(image)} + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + + response = self.post(json=payload, data=data, model=model, task="image-to-image") + return _bytes_to_image(response) + + def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput: + """ + Takes an input image and return text. + + Models can have very different outputs depending on your use case (image captioning, optical character recognition + (OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities. + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The input image to caption. It can be raw bytes, an image file, or a URL to an online image.. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + [`ImageToTextOutput`]: The generated text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. 
+ `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.image_to_text("cat.jpg") + 'a cat standing in a grassy field ' + >>> client.image_to_text("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") + 'a dog laying on the grass next to a flower pot ' + ``` + """ + response = self.post(data=image, model=model, task="image-to-text") + output = ImageToTextOutput.parse_obj(response) + return output[0] if isinstance(output, list) else output + + def list_deployed_models( + self, frameworks: Union[None, str, Literal["all"], List[str]] = None + ) -> Dict[str, List[str]]: + """ + List models deployed on the Serverless Inference API service. + + This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that + are supported and account for 95% of the hosted models. However, if you want a complete list of models you can + specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested + in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more + frameworks are checked, the more time it will take. + ++ + This endpoint method does not return a live list of all models available for the Serverless Inference API service. + It searches over a cached list of models that were recently available and the list may not be up to date. + If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`]. + + + ++ + This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to + check its availability, you can directly use [`~InferenceClient.get_model_status`]. + + + + Args: + frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*): + The frameworks to filter on. 
By default only a subset of the available frameworks are tested. If set to + "all", all available frameworks will be tested. It is also possible to provide a single framework or a + custom set of frameworks to check. + + Returns: + `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs. + + Example: + ```python + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + + # Discover zero-shot-classification models currently deployed + >>> models = client.list_deployed_models() + >>> models["zero-shot-classification"] + ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...] + + # List from only 1 framework + >>> client.list_deployed_models("text-generation-inference") + {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...} + ``` + """ + # Resolve which frameworks to check + if frameworks is None: + frameworks = MAIN_INFERENCE_API_FRAMEWORKS + elif frameworks == "all": + frameworks = ALL_INFERENCE_API_FRAMEWORKS + elif isinstance(frameworks, str): + frameworks = [frameworks] + frameworks = list(set(frameworks)) + + # Fetch them iteratively + models_by_task: Dict[str, List[str]] = {} + + def _unpack_response(framework: str, items: List[Dict]) -> None: + for model in items: + if framework == "sentence-transformers": + # Model running with the `sentence-transformers` framework can work with both tasks even if not + # branded as such in the API response + models_by_task.setdefault("feature-extraction", []).append(model["model_id"]) + models_by_task.setdefault("sentence-similarity", []).append(model["model_id"]) + else: + models_by_task.setdefault(model["task"], []).append(model["model_id"]) + + for framework in frameworks: + response = get_session().get(f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=self.headers) + hf_raise_for_status(response) + _unpack_response(framework, response.json()) + + # Sort alphabetically for discoverability and return + for task, 
models in models_by_task.items(): + models_by_task[task] = sorted(set(models), key=lambda x: x.lower()) + return models_by_task + + def object_detection( + self, + image: ContentT, + *, + model: Optional[str] = None, + ) -> List[ObjectDetectionOutputElement]: + """ + Perform object detection on the given image using the specified model. + ++ + You must have `PIL` installed if you want to work with images (`pip install Pillow`). + + + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The image to detect objects on. It can be raw bytes, an image file, or a URL to an online image. + model (`str`, *optional*): + The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a + deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used. + + Returns: + `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + `ValueError`: + If the request output is not a List. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.object_detection("people.jpg"): + [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...] + ``` + """ + # detect objects + response = self.post(data=image, model=model, task="object-detection") + return ObjectDetectionOutputElement.parse_obj_as_list(response) + + def question_answering( + self, question: str, context: str, *, model: Optional[str] = None + ) -> QuestionAnsweringOutputElement: + """ + Retrieve the answer to a question from a given text. + + Args: + question (`str`): + Question to be answered. 
+ context (`str`): + The context of the question. + model (`str`): + The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. + + Returns: + [`QuestionAnsweringOutputElement`]: an question answering output containing the score, start index, end index, and answer. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.") + QuestionAnsweringOutputElement(score=0.9326562285423279, start=11, end=16, answer='Clara') + ``` + """ + + payload: Dict[str, Any] = {"question": question, "context": context} + response = self.post( + json=payload, + model=model, + task="question-answering", + ) + return QuestionAnsweringOutputElement.parse_obj_as_instance(response) + + def sentence_similarity( + self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None + ) -> List[float]: + """ + Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings. + + Args: + sentence (`str`): + The main sentence to compare to others. + other_sentences (`List[str]`): + The list of sentences to compare to. + model (`str`, *optional*): + The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used. + Defaults to None. + + Returns: + `List[float]`: The embedding representing the input text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. 
+ `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.sentence_similarity( + ... "Machine learning is so easy.", + ... other_sentences=[ + ... "Deep learning is so straightforward.", + ... "This is so difficult, like rocket science.", + ... "I can't believe how much I struggled with this.", + ... ], + ... ) + [0.7785726189613342, 0.45876261591911316, 0.2906220555305481] + ``` + """ + response = self.post( + json={"inputs": {"source_sentence": sentence, "sentences": other_sentences}}, + model=model, + task="sentence-similarity", + ) + return _bytes_to_list(response) + + def summarization( + self, + text: str, + *, + parameters: Optional[Dict[str, Any]] = None, + model: Optional[str] = None, + ) -> SummarizationOutput: + """ + Generate a summary of a given text using a specified model. + + Args: + text (`str`): + The input text to summarize. + parameters (`Dict[str, Any]`, *optional*): + Additional parameters for summarization. Check out this [page](https://huggingface.co/docs/api-inference/detailed_parameters#summarization-task) + for more details. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + [`SummarizationOutput`]: The generated summary text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. 
+ + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.summarization("The Eiffel tower...") + SummarizationOutput(generated_text="The Eiffel tower is one of the most famous landmarks in the world....") + ``` + """ + payload: Dict[str, Any] = {"inputs": text} + if parameters is not None: + payload["parameters"] = parameters + response = self.post(json=payload, model=model, task="summarization") + return SummarizationOutput.parse_obj_as_list(response)[0] + + def table_question_answering( + self, table: Dict[str, Any], query: str, *, model: Optional[str] = None + ) -> TableQuestionAnsweringOutputElement: + """ + Retrieve the answer to a question from information given in a table. + + Args: + table (`str`): + A table of data represented as a dict of lists where entries are headers and the lists are all the + values, all lists must have the same size. + query (`str`): + The query in plain text that you want to ask the table. + model (`str`): + The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face + Hub or a URL to a deployed Inference Endpoint. + + Returns: + [`TableQuestionAnsweringOutputElement`]: a table question answering output containing the answer, coordinates, cells and the aggregator used. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> query = "How many stars does the transformers repository have?" 
+ >>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]} + >>> client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq") + TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE') + ``` + """ + response = self.post( + json={ + "query": query, + "table": table, + }, + model=model, + task="table-question-answering", + ) + return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response) + + def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]: + """ + Classifying a target category (a group) based on a set of attributes. + + Args: + table (`Dict[str, Any]`): + Set of attributes to classify. + model (`str`, *optional*): + The model to use for the tabular classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended tabular classification model will be used. + Defaults to None. + + Returns: + `List`: a list of labels, one per row in the initial table. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> table = { + ... "fixed_acidity": ["7.4", "7.8", "10.3"], + ... "volatile_acidity": ["0.7", "0.88", "0.32"], + ... "citric_acid": ["0", "0", "0.45"], + ... "residual_sugar": ["1.9", "2.6", "6.4"], + ... "chlorides": ["0.076", "0.098", "0.073"], + ... "free_sulfur_dioxide": ["11", "25", "5"], + ... "total_sulfur_dioxide": ["34", "67", "13"], + ... "density": ["0.9978", "0.9968", "0.9976"], + ... "pH": ["3.51", "3.2", "3.23"], + ... "sulphates": ["0.56", "0.68", "0.82"], + ... "alcohol": ["9.4", "9.8", "12.6"], + ... 
} + >>> client.tabular_classification(table=table, model="julien-c/wine-quality") + ["5", "5", "5"] + ``` + """ + response = self.post(json={"table": table}, model=model, task="tabular-classification") + return _bytes_to_list(response) + + def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]: + """ + Predicting a numerical target value given a set of attributes/features in a table. + + Args: + table (`Dict[str, Any]`): + Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical. + model (`str`, *optional*): + The model to use for the tabular regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended tabular regression model will be used. + Defaults to None. + + Returns: + `List`: a list of predicted numerical target values. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> table = { + ... "Height": ["11.52", "12.48", "12.3778"], + ... "Length1": ["23.2", "24", "23.9"], + ... "Length2": ["25.4", "26.3", "26.5"], + ... "Length3": ["30", "31.2", "31.1"], + ... "Species": ["Bream", "Bream", "Bream"], + ... "Width": ["4.02", "4.3056", "4.6961"], + ... } + >>> client.tabular_regression(table, model="scikit-learn/Fish-Weight") + [110, 120, 130] + ``` + """ + response = self.post(json={"table": table}, model=model, task="tabular-regression") + return _bytes_to_list(response) + + def text_classification(self, text: str, *, model: Optional[str] = None) -> List[TextClassificationOutputElement]: + """ + Perform text classification (e.g. sentiment-analysis) on the given text. + + Args: + text (`str`): + A string to be classified. 
+ model (`str`, *optional*): + The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used. + Defaults to None. + + Returns: + `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.text_classification("I like you") + [ + TextClassificationOutputElement(label='POSITIVE', score=0.9998695850372314), + TextClassificationOutputElement(label='NEGATIVE', score=0.0001304351753788069), + ] + ``` + """ + response = self.post(json={"inputs": text}, model=model, task="text-classification") + return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value] + + @overload + def text_generation( # type: ignore + self, + prompt: str, + *, + details: Literal[False] = ..., + stream: Literal[False] = ..., + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = False, # Manual default value + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = False, # Manual default value + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + 
temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> str: ... + + @overload + def text_generation( # type: ignore + self, + prompt: str, + *, + details: Literal[True] = ..., + stream: Literal[False] = ..., + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = False, # Manual default value + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = False, # Manual default value + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> TextGenerationOutput: ... 
+ + @overload + def text_generation( # type: ignore + self, + prompt: str, + *, + details: Literal[False] = ..., + stream: Literal[True] = ..., + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = False, # Manual default value + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = False, # Manual default value + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> Iterable[str]: ... 
+ + @overload + def text_generation( # type: ignore + self, + prompt: str, + *, + details: Literal[True] = ..., + stream: Literal[True] = ..., + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = False, # Manual default value + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = False, # Manual default value + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> Iterable[TextGenerationStreamOutput]: ... 
+ + @overload + def text_generation( + self, + prompt: str, + *, + details: Literal[True] = ..., + stream: bool = ..., + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = False, # Manual default value + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = False, # Manual default value + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> Union[TextGenerationOutput, Iterable[TextGenerationStreamOutput]]: ... 
+ + def text_generation( + self, + prompt: str, + *, + details: bool = False, + stream: bool = False, + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = False, # Manual default value + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = False, # Manual default value + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]: + """ + Given a prompt, generate the following text. + + API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the + go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the + default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but + not exactly the same. This method is compatible with both approaches but some parameters are only available for + `text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process + continues correctly. + + To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference. 
+ ++ + If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method. + It accepts a list of messages instead of a single text prompt and handles the chat templating for you. + + + + Args: + prompt (`str`): + Input text. + details (`bool`, *optional*): + By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens, + probabilities, seed, finish reason, etc.). Only available for models running on with the + `text-generation-inference` backend. + stream (`bool`, *optional*): + By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of + tokens to be returned. Only available for models running on with the `text-generation-inference` + backend. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + adapter_id (`str`, *optional*): + Lora adapter id. + best_of (`int`, *optional*): + Generate best_of sequences and return the one if the highest token logprobs. + decoder_input_details (`bool`, *optional*): + Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken + into account. Defaults to `False`. + do_sample (`bool`, *optional*): + Activate logits sampling + frequency_penalty (`float`, *optional*): + Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in + the text so far, decreasing the model's likelihood to repeat the same line verbatim. + grammar ([`TextGenerationInputGrammarType`], *optional*): + Grammar constraints. Can be either a JSONSchema or a regex. + max_new_tokens (`int`, *optional*): + Maximum number of generated tokens + repetition_penalty (`float`, *optional*): + The parameter for repetition penalty. 1.0 means no penalty. 
See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + return_full_text (`bool`, *optional*): + Whether to prepend the prompt to the generated text + seed (`int`, *optional*): + Random sampling seed + stop (`List[str]`, *optional*): + Stop generating tokens if a member of `stop` is generated. + stop_sequences (`List[str]`, *optional*): + Deprecated argument. Use `stop` instead. + temperature (`float`, *optional*): + The value used to module the logits distribution. + top_n_tokens (`int`, *optional*): + Return information about the `top_n_tokens` most likely tokens at each generation step, instead of + just the sampled token. + top_k (`int`, *optional`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`, *optional`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. + truncate (`int`, *optional`): + Truncate inputs tokens to the given size. 
+ typical_p (`float`, *optional`): + Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + watermark (`bool`, *optional`): + Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + + Returns: + `Union[str, TextGenerationOutput, Iterable[str], Iterable[TextGenerationStreamOutput]]`: + Generated text returned from the server: + - if `stream=False` and `details=False`, the generated text is returned as a `str` (default) + - if `stream=True` and `details=False`, the generated text is returned token by token as a `Iterable[str]` + - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.TextGenerationOutput`] + - if `details=True` and `stream=True`, the generated text is returned token by token as a iterable of [`~huggingface_hub.TextGenerationStreamOutput`] + + Raises: + `ValidationError`: + If input values are not valid. No HTTP call is made to the server. + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + + # Case 1: generate text + >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12) + '100% open source and built to be easy to use.' + + # Case 2: iterate over the generated tokens. Useful for large generation. + >>> for token in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True): + ... print(token) + 100 + % + open + source + and + built + to + be + easy + to + use + . + + # Case 3: get more details about the generation process. 
+ >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True) + TextGenerationOutput( + generated_text='100% open source and built to be easy to use.', + details=TextGenerationDetails( + finish_reason='length', + generated_tokens=12, + seed=None, + prefill=[ + TextGenerationPrefillOutputToken(id=487, text='The', logprob=None), + TextGenerationPrefillOutputToken(id=53789, text=' hugging', logprob=-13.171875), + (...) + TextGenerationPrefillOutputToken(id=204, text=' ', logprob=-7.0390625) + ], + tokens=[ + TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), + TokenElement(id=16, text='%', logprob=-0.0463562, special=False), + (...) + TokenElement(id=25, text='.', logprob=-0.5703125, special=False) + ], + best_of_sequences=None + ) + ) + + # Case 4: iterate over the generated tokens with more details. + # Last object is more complete, containing the full generated text and the finish reason. + >>> for details in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True): + ... print(details) + ... 
+ TextGenerationStreamOutput(token=TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement( + id=25, + text='.', + logprob=-0.5703125, + special=False), + generated_text='100% open source and built to be easy to use.', + details=TextGenerationStreamOutputStreamDetails(finish_reason='length', generated_tokens=12, seed=None) + ) + + # Case 5: generate constrained output using grammar + >>> response = client.text_generation( + ... 
prompt="I saw a puppy a cat and a raccoon during my bike ride in the park", + ... model="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1", + ... max_new_tokens=100, + ... repetition_penalty=1.3, + ... grammar={ + ... "type": "json", + ... "value": { + ... "properties": { + ... "location": {"type": "string"}, + ... "activity": {"type": "string"}, + ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5}, + ... "animals": {"type": "array", "items": {"type": "string"}}, + ... }, + ... "required": ["location", "activity", "animals_seen", "animals"], + ... }, + ... }, + ... ) + >>> json.loads(response) + { + "activity": "bike riding", + "animals": ["puppy", "cat", "raccoon"], + "animals_seen": 3, + "location": "park" + } + ``` + """ + if decoder_input_details and not details: + warnings.warn( + "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that" + " the output from the server will be truncated." + ) + decoder_input_details = False + + if stop_sequences is not None: + warnings.warn( + "`stop_sequences` is a deprecated argument for `text_generation` task" + " and will be removed in version '0.28.0'. 
Use `stop` instead.", + FutureWarning, + ) + if stop is None: + stop = stop_sequences # use deprecated arg if provided + + # Build payload + parameters = { + "adapter_id": adapter_id, + "best_of": best_of, + "decoder_input_details": decoder_input_details, + "details": details, + "do_sample": do_sample, + "frequency_penalty": frequency_penalty, + "grammar": grammar, + "max_new_tokens": max_new_tokens, + "repetition_penalty": repetition_penalty, + "return_full_text": return_full_text, + "seed": seed, + "stop": stop if stop is not None else [], + "temperature": temperature, + "top_k": top_k, + "top_n_tokens": top_n_tokens, + "top_p": top_p, + "truncate": truncate, + "typical_p": typical_p, + "watermark": watermark, + } + parameters = {k: v for k, v in parameters.items() if v is not None} + payload = { + "inputs": prompt, + "parameters": parameters, + "stream": stream, + } + + # Remove some parameters if not a TGI server + unsupported_kwargs = _get_unsupported_text_generation_kwargs(model) + if len(unsupported_kwargs) > 0: + # The server does not support some parameters + # => means it is not a TGI server + # => remove unsupported parameters and warn the user + + ignored_parameters = [] + for key in unsupported_kwargs: + if parameters.get(key): + ignored_parameters.append(key) + parameters.pop(key, None) + if len(ignored_parameters) > 0: + warnings.warn( + "API endpoint/model for text-generation is not served via TGI. Ignoring following parameters:" + f" {', '.join(ignored_parameters)}.", + UserWarning, + ) + if details: + warnings.warn( + "API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will" + " be ignored meaning only the generated text will be returned.", + UserWarning, + ) + details = False + if stream: + raise ValueError( + "API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream." + " Please pass `stream=False` as input." 
+ ) + + # Handle errors separately for more precise error messages + try: + bytes_output = self.post(json=payload, model=model, task="text-generation", stream=stream) # type: ignore + except HTTPError as e: + match = MODEL_KWARGS_NOT_USED_REGEX.search(str(e)) + if isinstance(e, BadRequestError) and match: + unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")] + _set_unsupported_text_generation_kwargs(model, unused_params) + return self.text_generation( # type: ignore + prompt=prompt, + details=details, + stream=stream, + model=model, + adapter_id=adapter_id, + best_of=best_of, + decoder_input_details=decoder_input_details, + do_sample=do_sample, + frequency_penalty=frequency_penalty, + grammar=grammar, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + return_full_text=return_full_text, + seed=seed, + stop=stop, + temperature=temperature, + top_k=top_k, + top_n_tokens=top_n_tokens, + top_p=top_p, + truncate=truncate, + typical_p=typical_p, + watermark=watermark, + ) + raise_text_generation_error(e) + + # Parse output + if stream: + return _stream_text_generation_response(bytes_output, details) # type: ignore + + data = _bytes_to_dict(bytes_output) # type: ignore[arg-type] + + # Data can be a single element (dict) or an iterable of dicts where we select the first element of. + if isinstance(data, list): + data = data[0] + + return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"] + + def text_to_image( + self, + prompt: str, + *, + negative_prompt: Optional[str] = None, + height: Optional[float] = None, + width: Optional[float] = None, + num_inference_steps: Optional[float] = None, + guidance_scale: Optional[float] = None, + model: Optional[str] = None, + **kwargs, + ) -> "Image": + """ + Generate an image based on a given text using a specified model. + ++ + You must have `PIL` installed if you want to work with images (`pip install Pillow`). 
+ + + + Args: + prompt (`str`): + The prompt to generate an image from. + negative_prompt (`str`, *optional*): + An optional negative prompt for the image generation. + height (`float`, *optional*): + The height in pixels of the image to generate. + width (`float`, *optional*): + The width in pixels of the image to generate. + num_inference_steps (`int`, *optional*): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*): + Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `Image`: The generated image. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + + >>> image = client.text_to_image("An astronaut riding a horse on the moon.") + >>> image.save("astronaut.png") + + >>> image = client.text_to_image( + ... "An astronaut riding a horse on the moon.", + ... negative_prompt="low resolution, blurry", + ... model="stabilityai/stable-diffusion-2-1", + ... 
) + >>> image.save("better_astronaut.png") + ``` + """ + payload = {"inputs": prompt} + parameters = { + "negative_prompt": negative_prompt, + "height": height, + "width": width, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + **kwargs, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value # type: ignore + response = self.post(json=payload, model=model, task="text-to-image") + return _bytes_to_image(response) + + def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes: + """ + Synthesize an audio of a voice pronouncing a given text. + + Args: + text (`str`): + The text to synthesize. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `bytes`: The generated audio. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from pathlib import Path + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + + >>> audio = client.text_to_speech("Hello world") + >>> Path("hello_world.flac").write_bytes(audio) + ``` + """ + return self.post(json={"inputs": text}, model=model, task="text-to-speech") + + def token_classification( + self, text: str, *, model: Optional[str] = None + ) -> List[TokenClassificationOutputElement]: + """ + Perform token classification on the given text. + Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. + + Args: + text (`str`): + A string to be classified. + model (`str`, *optional*): + The model to use for the token classification task. 
Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used. + Defaults to None. + + Returns: + `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica") + [ + TokenClassificationOutputElement( + entity_group='PER', + score=0.9971321225166321, + word='Sarah Jessica Parker', + start=11, + end=31, + ), + TokenClassificationOutputElement( + entity_group='PER', + score=0.9773476123809814, + word='Jessica', + start=52, + end=59, + ) + ] + ``` + """ + payload: Dict[str, Any] = {"inputs": text} + response = self.post( + json=payload, + model=model, + task="token-classification", + ) + return TokenClassificationOutputElement.parse_obj_as_list(response) + + def translation( + self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None + ) -> TranslationOutput: + """ + Convert text from one language to another. + + Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for + your specific use case. Source and target languages usually depend on the model. + However, it is possible to specify source and target languages for certain models. If you are working with one of these models, + you can use `src_lang` and `tgt_lang` arguments to pass the relevant information. + You can find this information in the model card. + + Args: + text (`str`): + A string to be translated. 
+ model (`str`, *optional*): + The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended translation model will be used. + Defaults to None. + src_lang (`str`, *optional*): + Source language of the translation task, i.e. input language. Cannot be passed without `tgt_lang`. + tgt_lang (`str`, *optional*): + Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`. + + Returns: + [`TranslationOutput`]: The generated translated text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + `ValueError`: + If only one of the `src_lang` and `tgt_lang` arguments are provided. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.translation("My name is Wolfgang and I live in Berlin") + 'Mein Name ist Wolfgang und ich lebe in Berlin.' 
+ >>> client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr") + TranslationOutput(translation_text='Je m\'appelle Wolfgang et je vis Γ Berlin.') + ``` + + Specifying languages: + ```py + >>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX") + "Mon nom est Sarah Jessica Parker mais vous pouvez m\'appeler Jessica" + ``` + """ + # Throw error if only one of `src_lang` and `tgt_lang` was given + if src_lang is not None and tgt_lang is None: + raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.") + + if src_lang is None and tgt_lang is not None: + raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.") + + # If both `src_lang` and `tgt_lang` are given, pass them to the request body + payload: Dict = {"inputs": text} + if src_lang and tgt_lang: + payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang} + response = self.post(json=payload, model=model, task="translation") + return TranslationOutput.parse_obj_as_list(response)[0] + + def visual_question_answering( + self, + image: ContentT, + question: str, + *, + model: Optional[str] = None, + ) -> List[VisualQuestionAnsweringOutputElement]: + """ + Answering open-ended questions based on an image. + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The input image for the context. It can be raw bytes, an image file, or a URL to an online image. + question (`str`): + Question to be answered. + model (`str`, *optional*): + The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used. + Defaults to None. 
+ + Returns: + `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability. + + Raises: + `InferenceTimeoutError`: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.visual_question_answering( + ... image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg", + ... question="What is the animal doing?" + ... ) + [ + VisualQuestionAnsweringOutputElement(score=0.778609573841095, answer='laying down'), + VisualQuestionAnsweringOutputElement(score=0.6957435607910156, answer='sitting'), + ] + ``` + """ + payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} + response = self.post(json=payload, model=model, task="visual-question-answering") + return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response) + + def zero_shot_classification( + self, + text: str, + labels: List[str], + *, + multi_label: bool = False, + hypothesis_template: Optional[str] = None, + model: Optional[str] = None, + ) -> List[ZeroShotClassificationOutputElement]: + """ + Provide as input a text and a set of candidate labels to classify the input text. + + Args: + text (`str`): + The input text to classify. + labels (`List[str]`): + List of strings. Each string is the verbalization of a possible label for the input text. + multi_label (`bool`): + Boolean. If True, the probability for each label is evaluated independently and multiple labels can have a probability close to 1 simultaneously or all probabilities can be close to 0. + If False, the labels are considered mutually exclusive and the probability over all labels always sums to 1. Defaults to False. 
+ hypothesis_template (`str`, *optional*): + A template sentence string with curly brackets to which the label strings are added. The label strings are added at the position of the curly brackets "{}". + Zero-shot classifiers are based on NLI models, which evaluate if a hypothesis is entailed in another text or not. + For example, with hypothesis_template="This text is about {}." and labels=["economics", "politics"], the system internally creates the two hypotheses "This text is about economics." and "This text is about politics.". + The model then evaluates for both hypotheses if they are entailed in the provided `text` or not. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example with `multi_label=False`: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> text = ( + ... "A new model offers an explanation for how the Galilean satellites formed around the solar system's" + ... "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling" + ... " mysteries when he went for a run up a hill in Nice, France." + ... 
) + >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"] + >>> client.zero_shot_classification(text, labels) + [ + ZeroShotClassificationOutputElement(label='scientific discovery', score=0.7961668968200684), + ZeroShotClassificationOutputElement(label='space & cosmos', score=0.18570658564567566), + ZeroShotClassificationOutputElement(label='microbiology', score=0.00730885099619627), + ZeroShotClassificationOutputElement(label='archeology', score=0.006258360575884581), + ZeroShotClassificationOutputElement(label='robots', score=0.004559356719255447), + ] + >>> client.zero_shot_classification(text, labels, multi_label=True) + [ + ZeroShotClassificationOutputElement(label='scientific discovery', score=0.9829297661781311), + ZeroShotClassificationOutputElement(label='space & cosmos', score=0.755190908908844), + ZeroShotClassificationOutputElement(label='microbiology', score=0.0005462635890580714), + ZeroShotClassificationOutputElement(label='archeology', score=0.00047131875180639327), + ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354), + ] + ``` + + Example with `multi_label=True` and a custom `hypothesis_template`: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.zero_shot_classification( + ... text="I really like our dinner and I'm very happy. I don't like the weather though.", + ... labels=["positive", "negative", "pessimistic", "optimistic"], + ... multi_label=True, + ... hypothesis_template="This text is {} towards the weather" + ... 
) + [ + ZeroShotClassificationOutputElement(label='negative', score=0.9231801629066467), + ZeroShotClassificationOutputElement(label='pessimistic', score=0.8760990500450134), + ZeroShotClassificationOutputElement(label='optimistic', score=0.0008674879791215062), + ZeroShotClassificationOutputElement(label='positive', score=0.0005250611575320363) + ] + ``` + """ + + parameters = {"candidate_labels": labels, "multi_label": multi_label} + if hypothesis_template is not None: + parameters["hypothesis_template"] = hypothesis_template + + response = self.post( + json={ + "inputs": text, + "parameters": parameters, + }, + task="zero-shot-classification", + model=model, + ) + output = _bytes_to_dict(response) + return [ + ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score}) + for label, score in zip(output["labels"], output["scores"]) + ] + + def zero_shot_image_classification( + self, image: ContentT, labels: List[str], *, model: Optional[str] = None + ) -> List[ZeroShotImageClassificationOutputElement]: + """ + Provide input image and text labels to predict text labels for the image. + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The input image to caption. It can be raw bytes, an image file, or a URL to an online image. + labels (`List[str]`): + List of string possible labels. There must be at least 2 labels. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. 
+ + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + + >>> client.zero_shot_image_classification( + ... "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg", + ... labels=["dog", "cat", "horse"], + ... ) + [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...] + ``` + """ + # Raise ValueError if input is less than 2 labels + if len(labels) < 2: + raise ValueError("You must specify at least 2 classes to compare.") + + response = self.post( + json={"image": _b64_encode(image), "parameters": {"candidate_labels": ",".join(labels)}}, + model=model, + task="zero-shot-image-classification", + ) + return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response) + + def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str: + model = model or self.model or self.base_url + + # If model is already a URL, ignore `task` and return directly + if model is not None and (model.startswith("http://") or model.startswith("https://")): + return model + + # # If no model but task is set => fetch the recommended one for this task + if model is None: + if task is None: + raise ValueError( + "You must specify at least a model (repo_id or URL) or a task, either when instantiating" + " `InferenceClient` or when making a request." + ) + model = self.get_recommended_model(task) + logger.info( + f"Using recommended model {model} for task {task}. Note that it is" + f" encouraged to explicitly set `model='{model}'` as the recommended" + " models list might get updated without prior notice." + ) + + # Compute InferenceAPI url + return ( + # Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks. 
+ f"{INFERENCE_ENDPOINT}/pipeline/{task}/{model}" + if task in ("feature-extraction", "sentence-similarity") + # Otherwise, we use the default endpoint + else f"{INFERENCE_ENDPOINT}/models/{model}" + ) + + @staticmethod + def get_recommended_model(task: str) -> str: + """ + Get the model Hugging Face recommends for the input task. + + Args: + task (`str`): + The Hugging Face task to get which model Hugging Face recommends. + All available tasks can be found [here](https://huggingface.co/tasks). + + Returns: + `str`: Name of the model recommended for the input task. + + Raises: + `ValueError`: If Hugging Face has no recommendation for the input task. + """ + model = _fetch_recommended_models().get(task) + if model is None: + raise ValueError( + f"Task {task} has no recommended model. Please specify a model" + " explicitly. Visit https://huggingface.co/tasks for more info." + ) + return model + + def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]: + """ + Get information about the deployed endpoint. + + This endpoint is only available on endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI). + Endpoints powered by `transformers` return an empty payload. + + Args: + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `Dict[str, Any]`: Information about the endpoint. 
+ + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") + >>> client.get_endpoint_info() + { + 'model_id': 'meta-llama/Meta-Llama-3-70B-Instruct', + 'model_sha': None, + 'model_dtype': 'torch.float16', + 'model_device_type': 'cuda', + 'model_pipeline_tag': None, + 'max_concurrent_requests': 128, + 'max_best_of': 2, + 'max_stop_sequences': 4, + 'max_input_length': 8191, + 'max_total_tokens': 8192, + 'waiting_served_ratio': 0.3, + 'max_batch_total_tokens': 1259392, + 'max_waiting_tokens': 20, + 'max_batch_size': None, + 'validation_workers': 32, + 'max_client_batch_size': 4, + 'version': '2.0.2', + 'sha': 'dccab72549635c7eb5ddb17f43f0b7cdff07c214', + 'docker_label': 'sha-dccab72' + } + ``` + """ + model = model or self.model + if model is None: + raise ValueError("Model id not provided.") + if model.startswith(("http://", "https://")): + url = model.rstrip("/") + "/info" + else: + url = f"{INFERENCE_ENDPOINT}/models/{model}/info" + + response = get_session().get(url, headers=self.headers) + hf_raise_for_status(response) + return response.json() + + def health_check(self, model: Optional[str] = None) -> bool: + """ + Check the health of the deployed endpoint. + + Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI). + For Inference API, please use [`InferenceClient.get_model_status`] instead. + + Args: + model (`str`, *optional*): + URL of the Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `bool`: True if everything is working fine. 
+ + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient("https://jzgu0buei5.us-east-1.aws.endpoints.huggingface.cloud") + >>> client.health_check() + True + ``` + """ + model = model or self.model + if model is None: + raise ValueError("Model id not provided.") + if not model.startswith(("http://", "https://")): + raise ValueError( + "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`." + ) + url = model.rstrip("/") + "/health" + + response = get_session().get(url, headers=self.headers) + return response.status_code == 200 + + def get_model_status(self, model: Optional[str] = None) -> ModelStatus: + """ + Get the status of a model hosted on the Inference API. + ++ + This endpoint is mostly useful when you already know which model you want to use and want to check its + availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`]. + + + + Args: + model (`str`, *optional*): + Identifier of the model for witch the status gonna be checked. If model is not provided, + the model associated with this instance of [`InferenceClient`] will be used. Only InferenceAPI service can be checked so the + identifier cannot be a URL. + + + Returns: + [`ModelStatus`]: An instance of ModelStatus dataclass, containing information, + about the state of the model: load, state, compute type and framework. 
+ + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct") + ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference') + ``` + """ + model = model or self.model + if model is None: + raise ValueError("Model id not provided.") + if model.startswith("https://"): + raise NotImplementedError("Model status is only available for Inference API endpoints.") + url = f"{INFERENCE_ENDPOINT}/status/{model}" + + response = get_session().get(url, headers=self.headers) + hf_raise_for_status(response) + response_data = response.json() + + if "error" in response_data: + raise ValueError(response_data["error"]) + + return ModelStatus( + loaded=response_data["loaded"], + state=response_data["state"], + compute_type=response_data["compute_type"], + framework=response_data["framework"], + ) + + @property + def chat(self) -> "ProxyClientChat": + return ProxyClientChat(self) + + +class _ProxyClient: + """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" + + def __init__(self, client: InferenceClient): + self._client = client + + +class ProxyClientChat(_ProxyClient): + """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" + + @property + def completions(self) -> "ProxyClientChatCompletions": + return ProxyClientChatCompletions(self._client) + + +class ProxyClientChatCompletions(_ProxyClient): + """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" + + @property + def create(self): + return self._client.chat_completion diff --git a/huggingface_hub/inference/_common.py b/huggingface_hub/inference/_common.py new file mode 100644 index 0000000000000000000000000000000000000000..a92d8fad4a475e05323885864f82519c81f7e846 --- /dev/null +++ b/huggingface_hub/inference/_common.py @@ -0,0 +1,440 @@ +# coding=utf-8 +# Copyright 
2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities used by both the sync and async inference clients.""" + +import base64 +import io +import json +import logging +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterable, + BinaryIO, + ContextManager, + Dict, + Generator, + Iterable, + List, + Literal, + NoReturn, + Optional, + Union, + overload, +) + +from requests import HTTPError + +from huggingface_hub.errors import ( + GenerationError, + IncompleteGenerationError, + OverloadedError, + TextGenerationError, + UnknownError, + ValidationError, +) + +from ..constants import ENDPOINT +from ..utils import ( + build_hf_headers, + get_session, + hf_raise_for_status, + is_aiohttp_available, + is_numpy_available, + is_pillow_available, +) +from ._generated.types import ( + ChatCompletionStreamOutput, + TextGenerationStreamOutput, +) + + +if TYPE_CHECKING: + from aiohttp import ClientResponse, ClientSession + from PIL.Image import Image + +# TYPES +UrlT = str +PathT = Union[str, Path] +BinaryT = Union[bytes, BinaryIO] +ContentT = Union[BinaryT, PathT, UrlT] + +# Use to set a Accept: image/png header +TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"} + +logger = logging.getLogger(__name__) + + +# Add dataclass for ModelStatus. We use this dataclass in get_model_status function. 
@dataclass
class ModelStatus:
    """
    This Dataclass represents the model status in the Hugging Face Inference API.

    Args:
        loaded (`bool`):
            If the model is currently loaded into Hugging Face's InferenceAPI. Models
            are loaded on-demand, leading to the user's first request taking longer.
            If a model is loaded, you can be assured that it is in a healthy state.
        state (`str`):
            The current state of the model. This can be 'Loaded', 'Loadable', 'TooBig'.
            If a model's state is 'Loadable', it's not too big and has a supported
            backend. Loadable models are automatically loaded when the user first
            requests inference on the endpoint. This means it is transparent for the
            user to load a model, except that the first call takes longer to complete.
        compute_type (`Dict`):
            Information about the compute resource the model is using or will use, such as 'gpu' type and number of
            replicas.
        framework (`str`):
            The name of the framework that the model was built with, such as 'transformers'
            or 'text-generation-inference'.
    """

    loaded: bool
    state: str
    compute_type: Dict
    framework: str


## IMPORT UTILS


def _import_aiohttp():
    """Make sure `aiohttp` is installed on the machine and return the module."""
    if not is_aiohttp_available():
        raise ImportError("Please install aiohttp to use `AsyncInferenceClient` (`pip install aiohttp`).")
    import aiohttp

    return aiohttp


def _import_numpy():
    """Make sure `numpy` is installed on the machine and return the module."""
    if not is_numpy_available():
        raise ImportError("Please install numpy to use deal with embeddings (`pip install numpy`).")
    import numpy

    return numpy


def _import_pil_image():
    """Make sure `PIL` is installed on the machine and return the `PIL.Image` module."""
    if not is_pillow_available():
        raise ImportError(
            "Please install Pillow to use deal with images (`pip install Pillow`). If you don't want the image to be"
            " post-processed, use `client.post(...)` and get the raw response from the server."
        )
    from PIL import Image

    return Image


## RECOMMENDED MODELS

# Will be globally fetched only once (see '_fetch_recommended_models')
_RECOMMENDED_MODELS: Optional[Dict[str, Optional[str]]] = None


def _fetch_recommended_models() -> Dict[str, Optional[str]]:
    """Return the `task -> recommended model` mapping, fetching it from `{ENDPOINT}/api/tasks` on first call.

    The result is cached in the module-level `_RECOMMENDED_MODELS` so the HTTP request is made at most once.
    """
    global _RECOMMENDED_MODELS
    if _RECOMMENDED_MODELS is None:
        response = get_session().get(f"{ENDPOINT}/api/tasks", headers=build_hf_headers())
        hf_raise_for_status(response)
        _RECOMMENDED_MODELS = {
            task: _first_or_none(details["widgetModels"]) for task, details in response.json().items()
        }
    return _RECOMMENDED_MODELS


def _first_or_none(items: List[Any]) -> Optional[Any]:
    """Return the first element of `items`, or None if the list is empty or its first element is falsy."""
    try:
        return items[0] or None
    except IndexError:
        return None


## ENCODING / DECODING UTILS


@overload
def _open_as_binary(
    content: ContentT,
) -> ContextManager[BinaryT]: ...  # means "if input is not None, output is not None"


@overload
def _open_as_binary(
    content: Literal[None],
) -> ContextManager[Literal[None]]: ...  # means "if input is None, output is None"


@contextmanager  # type: ignore
def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
    """Open `content` as a binary file, either from a URL, a local path, or raw bytes.

    Do nothing if `content` is None.

    Raises:
        `FileNotFoundError`: If `content` is a string that is neither a URL nor the path of an existing file.

    TODO: handle a PIL.Image as input
    TODO: handle base64 as input
    """
    # If content is a string => must be either a URL or a path
    if isinstance(content, str):
        if content.startswith(("https://", "http://")):
            logger.debug(f"Downloading content from {content}")
            yield get_session().get(content).content  # TODO: retrieve as stream and pipe to post request ?
            return
        content = Path(content)
        if not content.exists():
            raise FileNotFoundError(
                f"File not found at {content}. If `data` is a string, it must either be a URL or a path to a local"
                " file. To pass raw content, please encode it as bytes first."
            )

    # If content is a Path => open it
    if isinstance(content, Path):
        logger.debug(f"Opening content from {content}")
        with content.open("rb") as f:
            yield f
    else:
        # Otherwise: already a file-like object or None
        yield content


def _b64_encode(content: ContentT) -> str:
    """Encode a raw file (image, audio) into base64. Can be bytes, an opened file, a path or a URL."""
    with _open_as_binary(content) as data:
        data_as_bytes = data if isinstance(data, bytes) else data.read()
        return base64.b64encode(data_as_bytes).decode()


def _b64_to_image(encoded_image: str) -> "Image":
    """Parse a base64-encoded string into a PIL Image."""
    Image = _import_pil_image()
    return Image.open(io.BytesIO(base64.b64decode(encoded_image)))


def _bytes_to_list(content: bytes) -> List:
    """Parse bytes from a Response object into a Python list.

    Expects the response body to be JSON-encoded data.

    NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a
    dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
    """
    return json.loads(content.decode())


def _bytes_to_dict(content: bytes) -> Dict:
    """Parse bytes from a Response object into a Python dictionary.

    Expects the response body to be JSON-encoded data.

    NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a
    list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
    """
    return json.loads(content.decode())


def _bytes_to_image(content: bytes) -> "Image":
    """Parse bytes from a Response object into a PIL Image.

    Expects the response body to be raw bytes. To deal with b64 encoded images, use `_b64_to_image` instead.
    """
    Image = _import_pil_image()
    return Image.open(io.BytesIO(content))


## STREAMING UTILS


def _stream_text_generation_response(
    bytes_output_as_lines: Iterable[bytes], details: bool
) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]:
    """Used in `InferenceClient.text_generation`."""
    # Parse ServerSentEvents
    for byte_payload in bytes_output_as_lines:
        try:
            output = _format_text_generation_stream_output(byte_payload, details)
        except StopIteration:
            # "[DONE]" sentinel received => stop consuming the stream
            break
        if output is not None:
            yield output


async def _async_stream_text_generation_response(
    bytes_output_as_lines: AsyncIterable[bytes], details: bool
) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
    """Used in `AsyncInferenceClient.text_generation`."""
    # Parse ServerSentEvents
    async for byte_payload in bytes_output_as_lines:
        try:
            output = _format_text_generation_stream_output(byte_payload, details)
        except StopIteration:
            # "[DONE]" sentinel received => stop consuming the stream
            break
        if output is not None:
            yield output


def _parse_sse_json(byte_payload: bytes) -> Any:
    """Decode the JSON document carried by a `data:` Server-Sent-Events line.

    The `data:` prefix is removed by slicing. `str.lstrip("data:")` / `str.rstrip("/n")` must not be used here: they
    strip *sets of characters* (any of `d`, `a`, `t`, `:` / `/`, `n`), not a prefix or a newline.
    """
    payload = byte_payload.decode("utf-8")
    if payload.startswith("data:"):
        payload = payload[len("data:") :]
    return json.loads(payload.strip())


def _format_text_generation_stream_output(
    byte_payload: bytes, details: bool
) -> Optional[Union[str, TextGenerationStreamOutput]]:
    """Convert one Server-Sent-Events line into a text-generation stream item.

    Returns:
        None for a line that does not start with `data:`, the token text if `details` is False, otherwise the parsed
        `TextGenerationStreamOutput`.

    Raises:
        `StopIteration`: If the line is the `data: [DONE]` sentinel.
        The exception built by `_parse_text_generation_error` if the payload carries an `error` field.
    """
    if not byte_payload.startswith(b"data:"):
        return None  # empty line

    if byte_payload.strip() == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Decode payload
    json_payload = _parse_sse_json(byte_payload)

    # Either an error is being returned
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # Or parse token payload
    output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload)
    return output.token.text if not details else output


def _stream_chat_completion_response(
    bytes_lines: Iterable[bytes],
) -> Iterable[ChatCompletionStreamOutput]:
    """Used in `InferenceClient.chat_completion` if model is served with TGI."""
    for item in bytes_lines:
        try:
            output = _format_chat_completion_stream_output(item)
        except StopIteration:
            # "[DONE]" sentinel received => stop consuming the stream
            break
        if output is not None:
            yield output


async def _async_stream_chat_completion_response(
    bytes_lines: AsyncIterable[bytes],
) -> AsyncIterable[ChatCompletionStreamOutput]:
    """Used in `AsyncInferenceClient.chat_completion`."""
    async for item in bytes_lines:
        try:
            output = _format_chat_completion_stream_output(item)
        except StopIteration:
            # "[DONE]" sentinel received => stop consuming the stream
            break
        if output is not None:
            yield output


def _format_chat_completion_stream_output(
    byte_payload: bytes,
) -> Optional[ChatCompletionStreamOutput]:
    """Convert one Server-Sent-Events line into a `ChatCompletionStreamOutput`.

    Returns:
        None for a line that does not start with `data:`, otherwise the parsed stream output.

    Raises:
        `StopIteration`: If the line is the `data: [DONE]` sentinel.
        The exception built by `_parse_text_generation_error` if the payload carries an `error` field.
    """
    if not byte_payload.startswith(b"data:"):
        return None  # empty line

    if byte_payload.strip() == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")

    # Decode payload
    json_payload = _parse_sse_json(byte_payload)

    # Either an error is being returned
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    # Or parse token payload
    return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)


async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]:
    """Yield the stripped lines of `response.content`, then close `client`.

    The session is closed in a `finally` clause so that it is also released when reading the stream raises or when
    the generator is closed before being exhausted.
    """
    try:
        async for byte_payload in response.content:
            yield byte_payload.strip()
    finally:
        await client.close()


# "TGI servers" are servers running with the `text-generation-inference` backend.
# This backend is the go-to solution to run large language models at scale. However,
# for some smaller models (e.g. "gpt2") the default `transformers` + `api-inference`
# solution is still in use.
#
# Both approaches have very similar APIs, but not exactly the same. What we do first in
# the `text_generation` method is to assume the model is served via TGI. If we realize
# it's not the case (i.e. we receive an HTTP 400 Bad Request), we fallback to the
# default API with a warning message.
# When that's the case, we remember the unsupported
# attributes for this model in the `_UNSUPPORTED_TEXT_GENERATION_KWARGS` global variable.
#
# In addition, TGI servers have a built-in API route for chat-completion, which is not
# available on the default API. We use this route to provide a more consistent behavior
# when available.
#
# For more details, see https://github.com/huggingface/text-generation-inference and
# https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task.

_UNSUPPORTED_TEXT_GENERATION_KWARGS: Dict[Optional[str], List[str]] = {}


def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: List[str]) -> None:
    """Remember that `unsupported_kwargs` are not supported by `model`.

    Entries are accumulated across calls; a kwarg already recorded for `model` is not added a second time.
    """
    known_kwargs = _UNSUPPORTED_TEXT_GENERATION_KWARGS.setdefault(model, [])
    for kwarg in unsupported_kwargs:
        if kwarg not in known_kwargs:
            known_kwargs.append(kwarg)


def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]:
    """Return the kwargs recorded as unsupported for `model` (an empty list if none were recorded)."""
    return _UNSUPPORTED_TEXT_GENERATION_KWARGS.get(model, [])


# TEXT GENERATION ERRORS
# ----------------------
# Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation
# inference project (https://github.com/huggingface/text-generation-inference).
# ----------------------


def raise_text_generation_error(http_error: HTTPError) -> NoReturn:
    """
    Try to parse text-generation-inference error message and raise HTTPError in any case.

    Args:
        http_error (`HTTPError`):
            The HTTPError that has been raised.

    Raises:
        The exception built by `_parse_text_generation_error` (chained from `http_error`) if the payload contains an
        `error_type`, otherwise `http_error` itself.
    """
    # Try to parse a Text Generation Inference error

    try:
        # Hacky way to retrieve payload in case of aiohttp error
        payload = getattr(http_error, "response_error_payload", None) or http_error.response.json()
        error = payload.get("error")
        error_type = payload.get("error_type")
    except Exception:  # no payload
        raise http_error

    # If error_type => more information than `hf_raise_for_status`
    if error_type is not None:
        exception = _parse_text_generation_error(error, error_type)
        raise exception from http_error

    # Otherwise, fallback to default error
    raise http_error


def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
    """Map a text-generation-inference `error_type` to its exception class, instantiated with `error`.

    Any `error_type` other than 'generation', 'incomplete_generation', 'overloaded' or 'validation' yields an
    `UnknownError`.
    """
    if error_type == "generation":
        return GenerationError(error)  # type: ignore
    if error_type == "incomplete_generation":
        return IncompleteGenerationError(error)  # type: ignore
    if error_type == "overloaded":
        return OverloadedError(error)  # type: ignore
    if error_type == "validation":
        return ValidationError(error)  # type: ignore
    return UnknownError(error)  # type: ignore


# --- diff boundary (kept verbatim from the patch) ---
# diff --git a/huggingface_hub/inference/_generated/__init__.py b/huggingface_hub/inference/_generated/__init__.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
# diff --git a/huggingface_hub/inference/_generated/_async_client.py b/huggingface_hub/inference/_generated/_async_client.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..095cb376a64184f941e014ad924cedf0f87d9cf0
# --- /dev/null
# +++ b/huggingface_hub/inference/_generated/_async_client.py
# @@ -0,0 +1,2908 @@
# coding=utf-8
# Copyright 2023-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# WARNING +# This entire file has been adapted from the sync-client code in `src/huggingface_hub/inference/_client.py`. +# Any change in InferenceClient will be automatically reflected in AsyncInferenceClient. +# To re-generate the code, run `make style` or `python ./utils/generate_async_inference_client.py --update`. +# WARNING +import asyncio +import base64 +import logging +import re +import time +import warnings +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterable, + Dict, + List, + Literal, + Optional, + Set, + Union, + overload, +) + +from requests.structures import CaseInsensitiveDict + +from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS +from huggingface_hub.errors import InferenceTimeoutError +from huggingface_hub.inference._common import ( + TASKS_EXPECTING_IMAGES, + ContentT, + ModelStatus, + _async_stream_chat_completion_response, + _async_stream_text_generation_response, + _b64_encode, + _b64_to_image, + _bytes_to_dict, + _bytes_to_image, + _bytes_to_list, + _fetch_recommended_models, + _get_unsupported_text_generation_kwargs, + _import_numpy, + _open_as_binary, + _set_unsupported_text_generation_kwargs, + raise_text_generation_error, +) +from huggingface_hub.inference._generated.types import ( + AudioClassificationOutputElement, + AudioToAudioOutputElement, + AutomaticSpeechRecognitionOutput, + ChatCompletionInputGrammarType, + ChatCompletionInputTool, + ChatCompletionInputToolTypeClass, + ChatCompletionOutput, + ChatCompletionStreamOutput, + 
DocumentQuestionAnsweringOutputElement, + FillMaskOutputElement, + ImageClassificationOutputElement, + ImageSegmentationOutputElement, + ImageToTextOutput, + ObjectDetectionOutputElement, + QuestionAnsweringOutputElement, + SummarizationOutput, + TableQuestionAnsweringOutputElement, + TextClassificationOutputElement, + TextGenerationInputGrammarType, + TextGenerationOutput, + TextGenerationStreamOutput, + TokenClassificationOutputElement, + TranslationOutput, + VisualQuestionAnsweringOutputElement, + ZeroShotClassificationOutputElement, + ZeroShotImageClassificationOutputElement, +) +from huggingface_hub.utils import ( + build_hf_headers, +) +from huggingface_hub.utils._deprecation import _deprecate_positional_args + +from .._common import _async_yield_from, _import_aiohttp + + +if TYPE_CHECKING: + import numpy as np + from aiohttp import ClientResponse, ClientSession + from PIL.Image import Image + +logger = logging.getLogger(__name__) + + +MODEL_KWARGS_NOT_USED_REGEX = re.compile(r"The following `model_kwargs` are not used by the model: \[(.*?)\]") + + +class AsyncInferenceClient: + """ + Initialize a new Inference Client. + + [`InferenceClient`] aims to provide a unified experience to perform inference. The client can be used + seamlessly with either the (free) Inference API or self-hosted Inference Endpoints. + + Args: + model (`str`, `optional`): + The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct` + or a URL to a deployed Inference Endpoint. Defaults to None, in which case a recommended model is + automatically selected for the task. + Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2 + arguments are mutually exclusive. 
If using `base_url` for chat completion, the `/chat/completions` suffix + path will be appended to the base URL (see the [TGI Messages API](https://huggingface.co/docs/text-generation-inference/en/messages_api) + documentation for details). When passing a URL as `model`, the client will not append any suffix path to it. + token (`str` or `bool`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + Pass `token=False` if you don't want to send your token to the server. + Note: for better compatibility with OpenAI's client, `token` has been aliased as `api_key`. Those 2 + arguments are mutually exclusive and have the exact same behavior. + timeout (`float`, `optional`): + The maximum number of seconds to wait for a response from the server. Loading a new model in Inference + API can take up to several minutes. Defaults to None, meaning it will loop until the server is available. + headers (`Dict[str, str]`, `optional`): + Additional headers to send to the server. By default only the authorization and user-agent headers are sent. + Values in this dictionary will override the default values. + cookies (`Dict[str, str]`, `optional`): + Additional cookies to send to the server. + trust_env ('bool', 'optional'): + Trust environment settings for proxy configuration if the parameter is `True` (`False` by default). + proxies (`Any`, `optional`): + Proxies to use for the request. + base_url (`str`, `optional`): + Base URL to run inference. This is a duplicated argument from `model` to make [`InferenceClient`] + follow the same pattern as `openai.OpenAI` client. Cannot be used if `model` is set. Defaults to None. + api_key (`str`, `optional`): + Token to use for authentication. This is a duplicated argument from `token` to make [`InferenceClient`] + follow the same pattern as `openai.OpenAI` client. Cannot be used if `token` is set. Defaults to None. 
+ """ + + @_deprecate_positional_args(version="0.26") + def __init__( + self, + model: Optional[str] = None, + *, + token: Union[str, bool, None] = None, + timeout: Optional[float] = None, + headers: Optional[Dict[str, str]] = None, + cookies: Optional[Dict[str, str]] = None, + trust_env: bool = False, + proxies: Optional[Any] = None, + # OpenAI compatibility + base_url: Optional[str] = None, + api_key: Optional[str] = None, + ) -> None: + if model is not None and base_url is not None: + raise ValueError( + "Received both `model` and `base_url` arguments. Please provide only one of them." + " `base_url` is an alias for `model` to make the API compatible with OpenAI's client." + " If using `base_url` for chat completion, the `/chat/completions` suffix path will be appended to the base url." + " When passing a URL as `model`, the client will not append any suffix path to it." + ) + if token is not None and api_key is not None: + raise ValueError( + "Received both `token` and `api_key` arguments. Please provide only one of them." + " `api_key` is an alias for `token` to make the API compatible with OpenAI's client." + " It has the exact same behavior as `token`." 
+ ) + + self.model: Optional[str] = model + self.token: Union[str, bool, None] = token if token is not None else api_key + self.headers = CaseInsensitiveDict(build_hf_headers(token=self.token)) # 'authorization' + 'user-agent' + if headers is not None: + self.headers.update(headers) + self.cookies = cookies + self.timeout = timeout + self.trust_env = trust_env + self.proxies = proxies + + # OpenAI compatibility + self.base_url = base_url + + # Keep track of the sessions to close them properly + self._sessions: Dict["ClientSession", Set["ClientResponse"]] = dict() + + def __repr__(self): + return f"" + + @overload + async def post( # type: ignore[misc] + self, + *, + json: Optional[Union[str, Dict, List]] = None, + data: Optional[ContentT] = None, + model: Optional[str] = None, + task: Optional[str] = None, + stream: Literal[False] = ..., + ) -> bytes: ... + + @overload + async def post( # type: ignore[misc] + self, + *, + json: Optional[Union[str, Dict, List]] = None, + data: Optional[ContentT] = None, + model: Optional[str] = None, + task: Optional[str] = None, + stream: Literal[True] = ..., + ) -> AsyncIterable[bytes]: ... + + @overload + async def post( + self, + *, + json: Optional[Union[str, Dict, List]] = None, + data: Optional[ContentT] = None, + model: Optional[str] = None, + task: Optional[str] = None, + stream: bool = False, + ) -> Union[bytes, AsyncIterable[bytes]]: ... + + async def post( + self, + *, + json: Optional[Union[str, Dict, List]] = None, + data: Optional[ContentT] = None, + model: Optional[str] = None, + task: Optional[str] = None, + stream: bool = False, + ) -> Union[bytes, AsyncIterable[bytes]]: + """ + Make a POST request to the inference server. + + Args: + json (`Union[str, Dict, List]`, *optional*): + The JSON data to send in the request body, specific to each task. Defaults to None. + data (`Union[str, Path, bytes, BinaryIO]`, *optional*): + The content to send in the request body, specific to each task. 
+ It can be raw bytes, a pointer to an opened file, a local file path, + or a URL to an online resource (image, audio file,...). If both `json` and `data` are passed, + `data` will take precedence. At least `json` or `data` must be provided. Defaults to None. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. Will override the model defined at the instance level. Defaults to None. + task (`str`, *optional*): + The task to perform on the inference. All available tasks can be found + [here](https://huggingface.co/tasks). Used only to default to a recommended model if `model` is not + provided. At least `model` or `task` must be provided. Defaults to None. + stream (`bool`, *optional*): + Whether to iterate over streaming APIs. + + Returns: + bytes: The raw bytes returned by the server. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. 
+ """ + + aiohttp = _import_aiohttp() + + url = self._resolve_url(model, task) + + if data is not None and json is not None: + warnings.warn("Ignoring `json` as `data` is passed as binary.") + + # Set Accept header if relevant + headers = dict() + if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers: + headers["Accept"] = "image/png" + + t0 = time.time() + timeout = self.timeout + while True: + with _open_as_binary(data) as data_as_binary: + # Do not use context manager as we don't want to close the connection immediately when returning + # a stream + session = self._get_client_session(headers=headers) + + try: + response = await session.post(url, json=json, data=data_as_binary, proxy=self.proxies) + response_error_payload = None + if response.status != 200: + try: + response_error_payload = await response.json() # get payload before connection closed + except Exception: + pass + response.raise_for_status() + if stream: + return _async_yield_from(session, response) + else: + content = await response.read() + await session.close() + return content + except asyncio.TimeoutError as error: + await session.close() + # Convert any `TimeoutError` to a `InferenceTimeoutError` + raise InferenceTimeoutError(f"Inference call timed out: {url}") from error # type: ignore + except aiohttp.ClientResponseError as error: + error.response_error_payload = response_error_payload + await session.close() + if response.status == 422 and task is not None: + error.message += f". Make sure '{task}' task is supported by the model." + if response.status == 503: + # If Model is unavailable, either raise a TimeoutError... + if timeout is not None and time.time() - t0 > timeout: + raise InferenceTimeoutError( + f"Model not loaded on the server: {url}. 
Please retry with a higher timeout" + f" (current: {self.timeout}).", + request=error.request, + response=error.response, + ) from error + # ...or wait 1s and retry + logger.info(f"Waiting for model to be loaded on the server: {error}") + if "X-wait-for-model" not in headers and url.startswith(INFERENCE_ENDPOINT): + headers["X-wait-for-model"] = "1" + time.sleep(1) + if timeout is not None: + timeout = max(self.timeout - (time.time() - t0), 1) # type: ignore + continue + raise error + except Exception: + await session.close() + raise + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + def __del__(self): + if len(self._sessions) > 0: + warnings.warn( + "Deleting 'AsyncInferenceClient' client but some sessions are still open. " + "This can happen if you've stopped streaming data from the server before the stream was complete. " + "To close the client properly, you must call `await client.close()` " + "or use an async context (e.g. `async with AsyncInferenceClient(): ...`." + ) + + async def close(self): + """Close all open sessions. + + By default, 'aiohttp.ClientSession' objects are closed automatically when a call is completed. However, if you + are streaming data from the server and you stop before the stream is complete, you must call this method to + close the session properly. + + Another possibility is to use an async context (e.g. `async with AsyncInferenceClient(): ...`). + """ + await asyncio.gather(*[session.close() for session in self._sessions.keys()]) + + async def audio_classification( + self, + audio: ContentT, + *, + model: Optional[str] = None, + ) -> List[AudioClassificationOutputElement]: + """ + Perform audio classification on the provided audio content. + + Args: + audio (Union[str, Path, bytes, BinaryIO]): + The audio content to classify. It can be raw audio bytes, a local audio file, or a URL pointing to an + audio file. 
+ model (`str`, *optional*): + The model to use for audio classification. Can be a model ID hosted on the Hugging Face Hub + or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for + audio classification will be used. + + Returns: + `List[AudioClassificationOutputElement]`: List of [`AudioClassificationOutputElement`] items containing the predicted labels and their confidence. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.audio_classification("audio.flac") + [ + AudioClassificationOutputElement(score=0.4976358711719513, label='hap'), + AudioClassificationOutputElement(score=0.3677836060523987, label='neu'), + ... + ] + ``` + """ + response = await self.post(data=audio, model=model, task="audio-classification") + return AudioClassificationOutputElement.parse_obj_as_list(response) + + async def audio_to_audio( + self, + audio: ContentT, + *, + model: Optional[str] = None, + ) -> List[AudioToAudioOutputElement]: + """ + Performs multiple tasks related to audio-to-audio depending on the model (eg: speech enhancement, source separation). + + Args: + audio (Union[str, Path, bytes, BinaryIO]): + The audio content for the model. It can be raw audio bytes, a local audio file, or a URL pointing to an + audio file. + model (`str`, *optional*): + The model can be any model which takes an audio file and returns another audio file. Can be a model ID hosted on the Hugging Face Hub + or a URL to a deployed Inference Endpoint. If not provided, the default recommended model for + audio_to_audio will be used. 
+ + Returns: + `List[AudioToAudioOutputElement]`: A list of [`AudioToAudioOutputElement`] items containing audios label, content-type, and audio content in blob. + + Raises: + `InferenceTimeoutError`: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> audio_output = await client.audio_to_audio("audio.flac") + >>> async for i, item in enumerate(audio_output): + >>> with open(f"output_{i}.flac", "wb") as f: + f.write(item.blob) + ``` + """ + response = await self.post(data=audio, model=model, task="audio-to-audio") + audio_output = AudioToAudioOutputElement.parse_obj_as_list(response) + for item in audio_output: + item.blob = base64.b64decode(item.blob) + return audio_output + + async def automatic_speech_recognition( + self, + audio: ContentT, + *, + model: Optional[str] = None, + ) -> AutomaticSpeechRecognitionOutput: + """ + Perform automatic speech recognition (ASR or audio-to-text) on the given audio content. + + Args: + audio (Union[str, Path, bytes, BinaryIO]): + The content to transcribe. It can be raw audio bytes, local audio file, or a URL to an audio file. + model (`str`, *optional*): + The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended model for ASR will be used. + + Returns: + [`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. 
+ + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.automatic_speech_recognition("hello_world.flac").text + "hello world" + ``` + """ + response = await self.post(data=audio, model=model, task="automatic-speech-recognition") + return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response) + + @overload + async def chat_completion( # type: ignore + self, + messages: List[Dict[str, str]], + *, + model: Optional[str] = None, + stream: Literal[False] = False, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + ) -> ChatCompletionOutput: ... 
+ + @overload + async def chat_completion( # type: ignore + self, + messages: List[Dict[str, str]], + *, + model: Optional[str] = None, + stream: Literal[True] = True, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + ) -> AsyncIterable[ChatCompletionStreamOutput]: ... + + @overload + async def chat_completion( + self, + messages: List[Dict[str, str]], + *, + model: Optional[str] = None, + stream: bool = False, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: ... 
+ + async def chat_completion( + self, + messages: List[Dict[str, str]], + *, + model: Optional[str] = None, + stream: bool = False, + # Parameters from ChatCompletionInput (handled manually) + frequency_penalty: Optional[float] = None, + logit_bias: Optional[List[float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[ChatCompletionInputGrammarType] = None, + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None, + tool_prompt: Optional[str] = None, + tools: Optional[List[ChatCompletionInputTool]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + ) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]: + """ + A method for completing conversations using a specified language model. + + + + The `client.chat_completion` method is aliased as `client.chat.completions.create` for compatibility with OpenAI's client. + Inputs and outputs are strictly the same and using either syntax will yield the same results. + Check out the [Inference guide](https://huggingface.co/docs/huggingface_hub/guides/inference#openai-compatibility) + for more details about OpenAI's compatibility. + + + + Args: + messages (List[Union[`SystemMessage`, `UserMessage`, `AssistantMessage`]]): + Conversation history consisting of roles and content pairs. + model (`str`, *optional*): + The model to use for chat-completion. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. If not provided, the default recommended model for chat-based text-generation will be used. + See https://huggingface.co/tasks/text-generation for more details. + + If `model` is a model ID, it is passed to the server as the `model` parameter. 
If you want to define a + custom URL while setting `model` in the request payload, you must set `base_url` when initializing [`InferenceClient`]. + frequency_penalty (`float`, *optional*): + Penalizes new tokens based on their existing frequency + in the text so far. Range: [-2.0, 2.0]. Defaults to 0.0. + logit_bias (`List[float]`, *optional*): + Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens + (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, + the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, + but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should + result in a ban or exclusive selection of the relevant token. Defaults to None. + logprobs (`bool`, *optional*): + Whether to return log probabilities of the output tokens or not. If true, returns the log + probabilities of each output token returned in the content of message. + max_tokens (`int`, *optional*): + Maximum number of tokens allowed in the response. Defaults to 20. + n (`int`, *optional*): + UNUSED. + presence_penalty (`float`, *optional*): + Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the + text so far, increasing the model's likelihood to talk about new topics. + response_format ([`ChatCompletionInputGrammarType`], *optional*): + Grammar constraints. Can be either a JSONSchema or a regex. + seed (Optional[`int`], *optional*): + Seed for reproducible control flow. Defaults to None. + stop (Optional[`str`], *optional*): + Up to four strings which trigger the end of the response. + Defaults to None. + stream (`bool`, *optional*): + Enable realtime streaming of responses. Defaults to False. + temperature (`float`, *optional*): + Controls randomness of the generations. Lower values ensure + less random completions. 
Range: [0, 2]. Defaults to 1.0. + top_logprobs (`int`, *optional*): + An integer between 0 and 5 specifying the number of most likely tokens to return at each token + position, each with an associated log probability. logprobs must be set to true if this parameter is + used. + top_p (`float`, *optional*): + Fraction of the most likely next words to sample from. + Must be between 0 and 1. Defaults to 1.0. + tool_choice ([`ChatCompletionInputToolTypeClass`] or `str`, *optional*): + The tool to use for the completion. Defaults to "auto". + tool_prompt (`str`, *optional*): + A prompt to be appended before the tools. + tools (List of [`ChatCompletionInputTool`], *optional*): + A list of tools the model may call. Currently, only functions are supported as a tool. Use this to + provide a list of functions the model may generate JSON inputs for. + + Returns: + [`ChatCompletionOutput`] or Iterable of [`ChatCompletionStreamOutput`]: + Generated text returned from the server: + - if `stream=False`, the generated text is returned as a [`ChatCompletionOutput`] (default). + - if `stream=True`, the generated text is returned token by token as a sequence of [`ChatCompletionStreamOutput`]. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. 
+ + Example: + + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> messages = [{"role": "user", "content": "What is the capital of France?"}] + >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct") + >>> await client.chat_completion(messages, max_tokens=100) + ChatCompletionOutput( + choices=[ + ChatCompletionOutputComplete( + finish_reason='eos_token', + index=0, + message=ChatCompletionOutputMessage( + role='assistant', + content='The capital of France is Paris.', + name=None, + tool_calls=None + ), + logprobs=None + ) + ], + created=1719907176, + id='', + model='meta-llama/Meta-Llama-3-8B-Instruct', + object='text_completion', + system_fingerprint='2.0.4-sha-f426a33', + usage=ChatCompletionOutputUsage( + completion_tokens=8, + prompt_tokens=17, + total_tokens=25 + ) + ) + ``` + + Example (stream=True): + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> messages = [{"role": "user", "content": "What is the capital of France?"}] + >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-8B-Instruct") + >>> async for token in await client.chat_completion(messages, max_tokens=10, stream=True): + ... print(token) + ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content='The', role='assistant'), index=0, finish_reason=None)], created=1710498504) + ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' capital', role='assistant'), index=0, finish_reason=None)], created=1710498504) + (...) 
+ ChatCompletionStreamOutput(choices=[ChatCompletionStreamOutputChoice(delta=ChatCompletionStreamOutputDelta(content=' may', role='assistant'), index=0, finish_reason=None)], created=1710498504) + ``` + + Example using OpenAI's syntax: + ```py + # Must be run in an async context + # instead of `from openai import OpenAI` + from huggingface_hub import AsyncInferenceClient + + # instead of `client = OpenAI(...)` + client = AsyncInferenceClient( + base_url=..., + api_key=..., + ) + + output = await client.chat.completions.create( + model="meta-llama/Meta-Llama-3-8B-Instruct", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Count to 10"}, + ], + stream=True, + max_tokens=1024, + ) + + for chunk in output: + print(chunk.choices[0].delta.content) + ``` + + Example using tools: + ```py + # Must be run in an async context + >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") + >>> messages = [ + ... { + ... "role": "system", + ... "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.", + ... }, + ... { + ... "role": "user", + ... "content": "What's the weather like the next 3 days in San Francisco, CA?", + ... }, + ... ] + >>> tools = [ + ... { + ... "type": "function", + ... "function": { + ... "name": "get_current_weather", + ... "description": "Get the current weather", + ... "parameters": { + ... "type": "object", + ... "properties": { + ... "location": { + ... "type": "string", + ... "description": "The city and state, e.g. San Francisco, CA", + ... }, + ... "format": { + ... "type": "string", + ... "enum": ["celsius", "fahrenheit"], + ... "description": "The temperature unit to use. Infer this from the users location.", + ... }, + ... }, + ... "required": ["location", "format"], + ... }, + ... }, + ... }, + ... { + ... "type": "function", + ... "function": { + ... "name": "get_n_day_weather_forecast", + ... 
"description": "Get an N-day weather forecast", + ... "parameters": { + ... "type": "object", + ... "properties": { + ... "location": { + ... "type": "string", + ... "description": "The city and state, e.g. San Francisco, CA", + ... }, + ... "format": { + ... "type": "string", + ... "enum": ["celsius", "fahrenheit"], + ... "description": "The temperature unit to use. Infer this from the users location.", + ... }, + ... "num_days": { + ... "type": "integer", + ... "description": "The number of days to forecast", + ... }, + ... }, + ... "required": ["location", "format", "num_days"], + ... }, + ... }, + ... }, + ... ] + + >>> response = await client.chat_completion( + ... model="meta-llama/Meta-Llama-3-70B-Instruct", + ... messages=messages, + ... tools=tools, + ... tool_choice="auto", + ... max_tokens=500, + ... ) + >>> response.choices[0].message.tool_calls[0].function + ChatCompletionOutputFunctionDefinition( + arguments={ + 'location': 'San Francisco, CA', + 'format': 'fahrenheit', + 'num_days': 3 + }, + name='get_n_day_weather_forecast', + description=None + ) + ``` + + Example using response_format: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") + >>> messages = [ + ... { + ... "role": "user", + ... "content": "I saw a puppy a cat and a raccoon during my bike ride in the park. What did I saw and when?", + ... }, + ... ] + >>> response_format = { + ... "type": "json", + ... "value": { + ... "properties": { + ... "location": {"type": "string"}, + ... "activity": {"type": "string"}, + ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5}, + ... "animals": {"type": "array", "items": {"type": "string"}}, + ... }, + ... "required": ["location", "activity", "animals_seen", "animals"], + ... }, + ... } + >>> response = await client.chat_completion( + ... messages=messages, + ... response_format=response_format, + ... 
max_tokens=500, + ) + >>> response.choices[0].message.content + '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}' + ``` + """ + model_url = self._resolve_chat_completion_url(model) + + # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing. + # If it's a ID on the Hub => use it. Otherwise, we use a random string. + model_id = model or self.model or "tgi" + if model_id.startswith(("http://", "https://")): + model_id = "tgi" # dummy value + + payload = dict( + model=model_id, + messages=messages, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + temperature=temperature, + tool_choice=tool_choice, + tool_prompt=tool_prompt, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + stream=stream, + ) + payload = {key: value for key, value in payload.items() if value is not None} + data = await self.post(model=model_url, json=payload, stream=stream) + + if stream: + return _async_stream_chat_completion_response(data) # type: ignore[arg-type] + + return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type] + + def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str: + # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently. + # `self.base_url` and `self.model` takes precedence over 'model' argument only in `chat_completion`. 
+ model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation") + + # Resolve URL if it's a model ID + model_url = ( + model_id_or_url + if model_id_or_url.startswith(("http://", "https://")) + else self._resolve_url(model_id_or_url, task="text-generation") + ) + + # Strip trailing / + model_url = model_url.rstrip("/") + + # Append /chat/completions if not already present + if model_url.endswith("/v1"): + model_url += "/chat/completions" + + # Append /v1/chat/completions if not already present + if not model_url.endswith("/chat/completions"): + model_url += "/v1/chat/completions" + + return model_url + + async def document_question_answering( + self, + image: ContentT, + question: str, + *, + model: Optional[str] = None, + ) -> List[DocumentQuestionAnsweringOutputElement]: + """ + Answer questions on document images. + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The input image for the context. It can be raw bytes, an image file, or a URL to an online image. + question (`str`): + Question to be answered. + model (`str`, *optional*): + The model to use for the document question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended document question answering model will be used. + Defaults to None. + + Returns: + `List[DocumentQuestionAnsweringOutputElement]`: a list of [`DocumentQuestionAnsweringOutputElement`] items containing the predicted label, associated probability, word ids, and page number. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. 
+ + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.document_question_answering(image="https://huggingface.co/spaces/impira/docquery/resolve/2359223c1837a7587402bda0f2643382a6eefeab/invoice.png", question="What is the invoice number?") + [DocumentQuestionAnsweringOutputElement(score=0.42515629529953003, answer='us-001', start=16, end=16)] + ``` + """ + payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} + response = await self.post(json=payload, model=model, task="document-question-answering") + return DocumentQuestionAnsweringOutputElement.parse_obj_as_list(response) + + async def feature_extraction( + self, + text: str, + *, + normalize: Optional[bool] = None, + prompt_name: Optional[str] = None, + truncate: Optional[bool] = None, + truncation_direction: Optional[Literal["Left", "Right"]] = None, + model: Optional[str] = None, + ) -> "np.ndarray": + """ + Generate embeddings for a given text. + + Args: + text (`str`): + The text to embed. + model (`str`, *optional*): + The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used. + Defaults to None. + normalize (`bool`, *optional*): + Whether to normalize the embeddings or not. Defaults to None. + Only available on server powered by Text-Embedding-Inference. + prompt_name (`str`, *optional*): + The name of the prompt that should be used by for encoding. If not set, no prompt will be applied. + Must be a key in the `Sentence Transformers` configuration `prompts` dictionary. + For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ",...}, + then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?" 
+ because the prompt text will be prepended before any text to encode. + truncate (`bool`, *optional*): + Whether to truncate the embeddings or not. Defaults to None. + Only available on server powered by Text-Embedding-Inference. + truncation_direction (`Literal["Left", "Right"]`, *optional*): + Which side of the input should be truncated when `truncate=True` is passed. + + Returns: + `np.ndarray`: The embedding representing the input text as a float32 numpy array. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.feature_extraction("Hi, who are you?") + array([[ 2.424802 , 2.93384 , 1.1750331 , ..., 1.240499, -0.13776633, -0.7889173 ], + [-0.42943227, -0.6364878 , -1.693462 , ..., 0.41978157, -2.4336355 , 0.6162071 ], + ..., + [ 0.28552425, -0.928395 , -1.2077185 , ..., 0.76810825, -2.1069427 , 0.6236161 ]], dtype=float32) + ``` + """ + payload: Dict = {"inputs": text} + if normalize is not None: + payload["normalize"] = normalize + if prompt_name is not None: + payload["prompt_name"] = prompt_name + if truncate is not None: + payload["truncate"] = truncate + if truncation_direction is not None: + payload["truncation_direction"] = truncation_direction + response = await self.post(json=payload, model=model, task="feature-extraction") + np = _import_numpy() + return np.array(_bytes_to_dict(response), dtype="float32") + + async def fill_mask(self, text: str, *, model: Optional[str] = None) -> List[FillMaskOutputElement]: + """ + Fill in a hole with a missing word (token to be precise). + + Args: + text (`str`): + a string to be filled from, must contain the [MASK] token (check model card for exact name of the mask). 
+ model (`str`, *optional*): + The model to use for the fill mask task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended fill mask model will be used. + Defaults to None. + + Returns: + `List[FillMaskOutputElement]`: a list of [`FillMaskOutputElement`] items containing the predicted label, associated + probability, token reference, and completed text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.fill_mask("The goal of life is.") + [ + FillMaskOutputElement(score=0.06897063553333282, token=11098, token_str=' happiness', sequence='The goal of life is happiness.'), + FillMaskOutputElement(score=0.06554922461509705, token=45075, token_str=' immortality', sequence='The goal of life is immortality.') + ] + ``` + """ + response = await self.post(json={"inputs": text}, model=model, task="fill-mask") + return FillMaskOutputElement.parse_obj_as_list(response) + + async def image_classification( + self, + image: ContentT, + *, + model: Optional[str] = None, + ) -> List[ImageClassificationOutputElement]: + """ + Perform image classification on the given image using the specified model. + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The image to classify. It can be raw bytes, an image file, or a URL to an online image. + model (`str`, *optional*): + The model to use for image classification. Can be a model ID hosted on the Hugging Face Hub or a URL to a + deployed Inference Endpoint. If not provided, the default recommended model for image classification will be used. 
+ + Returns: + `List[ImageClassificationOutputElement]`: a list of [`ImageClassificationOutputElement`] items containing the predicted label and associated probability. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") + [ImageClassificationOutputElement(score=0.9779096841812134, label='Blenheim spaniel'), ...] + ``` + """ + response = await self.post(data=image, model=model, task="image-classification") + return ImageClassificationOutputElement.parse_obj_as_list(response) + + async def image_segmentation( + self, + image: ContentT, + *, + model: Optional[str] = None, + ) -> List[ImageSegmentationOutputElement]: + """ + Perform image segmentation on the given image using the specified model. + + + + You must have `PIL` installed if you want to work with images (`pip install Pillow`). + + + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The image to segment. It can be raw bytes, an image file, or a URL to an online image. + model (`str`, *optional*): + The model to use for image segmentation. Can be a model ID hosted on the Hugging Face Hub or a URL to a + deployed Inference Endpoint. If not provided, the default recommended model for image segmentation will be used. + + Returns: + `List[ImageSegmentationOutputElement]`: A list of [`ImageSegmentationOutputElement`] items containing the segmented masks and associated attributes. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. 
+ `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.image_segmentation("cat.jpg"): + [ImageSegmentationOutputElement(score=0.989008, label='LABEL_184', mask=), ...] + ``` + """ + response = await self.post(data=image, model=model, task="image-segmentation") + output = ImageSegmentationOutputElement.parse_obj_as_list(response) + for item in output: + item.mask = _b64_to_image(item.mask) + return output + + async def image_to_image( + self, + image: ContentT, + prompt: Optional[str] = None, + *, + negative_prompt: Optional[str] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + model: Optional[str] = None, + **kwargs, + ) -> "Image": + """ + Perform image-to-image translation using a specified model. + + + + You must have `PIL` installed if you want to work with images (`pip install Pillow`). + + + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The input image for translation. It can be raw bytes, an image file, or a URL to an online image. + prompt (`str`, *optional*): + The text prompt to guide the image generation. + negative_prompt (`str`, *optional*): + A negative prompt to guide the translation process. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*): + Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. 
+ model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `Image`: The translated image. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> image = await client.image_to_image("cat.jpg", prompt="turn the cat into a tiger") + >>> image.save("tiger.jpg") + ``` + """ + parameters = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "height": height, + "width": width, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + **kwargs, + } + if all(parameter is None for parameter in parameters.values()): + # Either only an image to send => send as raw bytes + data = image + payload: Optional[Dict[str, Any]] = None + else: + # Or an image + some parameters => use base64 encoding + data = None + payload = {"inputs": _b64_encode(image)} + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value + + response = await self.post(json=payload, data=data, model=model, task="image-to-image") + return _bytes_to_image(response) + + async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> ImageToTextOutput: + """ + Takes an input image and return text. + + Models can have very different outputs depending on your use case (image captioning, optical character recognition + (OCR), Pix2Struct, etc). Please have a look to the model card to learn more about a model's specificities. + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The input image to caption. 
It can be raw bytes, an image file, or a URL to an online image.. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + [`ImageToTextOutput`]: The generated text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.image_to_text("cat.jpg") + 'a cat standing in a grassy field ' + >>> await client.image_to_text("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg") + 'a dog laying on the grass next to a flower pot ' + ``` + """ + response = await self.post(data=image, model=model, task="image-to-text") + output = ImageToTextOutput.parse_obj(response) + return output[0] if isinstance(output, list) else output + + async def list_deployed_models( + self, frameworks: Union[None, str, Literal["all"], List[str]] = None + ) -> Dict[str, List[str]]: + """ + List models deployed on the Serverless Inference API service. + + This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that + are supported and account for 95% of the hosted models. However, if you want a complete list of models you can + specify `frameworks="all"` as input. Alternatively, if you know before-hand which framework you are interested + in, you can also restrict to search to this one (e.g. `frameworks="text-generation-inference"`). The more + frameworks are checked, the more time it will take. 
+ ++ + This endpoint method does not return a live list of all models available for the Serverless Inference API service. + It searches over a cached list of models that were recently available and the list may not be up to date. + If you want to know the live status of a specific model, use [`~InferenceClient.get_model_status`]. + + + ++ + This endpoint method is mostly useful for discoverability. If you already know which model you want to use and want to + check its availability, you can directly use [`~InferenceClient.get_model_status`]. + + + + Args: + frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*): + The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to + "all", all available frameworks will be tested. It is also possible to provide a single framework or a + custom set of frameworks to check. + + Returns: + `Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs. + + Example: + ```py + # Must be run in an async contextthon + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + + # Discover zero-shot-classification models currently deployed + >>> models = await client.list_deployed_models() + >>> models["zero-shot-classification"] + ['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...] 
+ + # List from only 1 framework + >>> await client.list_deployed_models("text-generation-inference") + {'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...} + ``` + """ + # Resolve which frameworks to check + if frameworks is None: + frameworks = MAIN_INFERENCE_API_FRAMEWORKS + elif frameworks == "all": + frameworks = ALL_INFERENCE_API_FRAMEWORKS + elif isinstance(frameworks, str): + frameworks = [frameworks] + frameworks = list(set(frameworks)) + + # Fetch them iteratively + models_by_task: Dict[str, List[str]] = {} + + def _unpack_response(framework: str, items: List[Dict]) -> None: + for model in items: + if framework == "sentence-transformers": + # Model running with the `sentence-transformers` framework can work with both tasks even if not + # branded as such in the API response + models_by_task.setdefault("feature-extraction", []).append(model["model_id"]) + models_by_task.setdefault("sentence-similarity", []).append(model["model_id"]) + else: + models_by_task.setdefault(model["task"], []).append(model["model_id"]) + + async def _fetch_framework(framework: str) -> None: + async with self._get_client_session() as client: + response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}", proxy=self.proxies) + response.raise_for_status() + _unpack_response(framework, await response.json()) + + import asyncio + + await asyncio.gather(*[_fetch_framework(framework) for framework in frameworks]) + + # Sort alphabetically for discoverability and return + for task, models in models_by_task.items(): + models_by_task[task] = sorted(set(models), key=lambda x: x.lower()) + return models_by_task + + async def object_detection( + self, + image: ContentT, + *, + model: Optional[str] = None, + ) -> List[ObjectDetectionOutputElement]: + """ + Perform object detection on the given image using the specified model. + ++ + You must have `PIL` installed if you want to work with images (`pip install Pillow`). 
+ + + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The image to detect objects on. It can be raw bytes, an image file, or a URL to an online image. + model (`str`, *optional*): + The model to use for object detection. Can be a model ID hosted on the Hugging Face Hub or a URL to a + deployed Inference Endpoint. If not provided, the default recommended model for object detection (DETR) will be used. + + Returns: + `List[ObjectDetectionOutputElement]`: A list of [`ObjectDetectionOutputElement`] items containing the bounding boxes and associated attributes. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + `ValueError`: + If the request output is not a List. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.object_detection("people.jpg"): + [ObjectDetectionOutputElement(score=0.9486683011054993, label='person', box=ObjectDetectionBoundingBox(xmin=59, ymin=39, xmax=420, ymax=510)), ...] + ``` + """ + # detect objects + response = await self.post(data=image, model=model, task="object-detection") + return ObjectDetectionOutputElement.parse_obj_as_list(response) + + async def question_answering( + self, question: str, context: str, *, model: Optional[str] = None + ) -> QuestionAnsweringOutputElement: + """ + Retrieve the answer to a question from a given text. + + Args: + question (`str`): + Question to be answered. + context (`str`): + The context of the question. + model (`str`): + The model to use for the question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. + + Returns: + [`QuestionAnsweringOutputElement`]: an question answering output containing the score, start index, end index, and answer. 
+ + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.question_answering(question="What's my name?", context="My name is Clara and I live in Berkeley.") + QuestionAnsweringOutputElement(score=0.9326562285423279, start=11, end=16, answer='Clara') + ``` + """ + + payload: Dict[str, Any] = {"question": question, "context": context} + response = await self.post( + json=payload, + model=model, + task="question-answering", + ) + return QuestionAnsweringOutputElement.parse_obj_as_instance(response) + + async def sentence_similarity( + self, sentence: str, other_sentences: List[str], *, model: Optional[str] = None + ) -> List[float]: + """ + Compute the semantic similarity between a sentence and a list of other sentences by comparing their embeddings. + + Args: + sentence (`str`): + The main sentence to compare to others. + other_sentences (`List[str]`): + The list of sentences to compare to. + model (`str`, *optional*): + The model to use for the conversational task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended conversational model will be used. + Defaults to None. + + Returns: + `List[float]`: The embedding representing the input text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.sentence_similarity( + ... "Machine learning is so easy.", + ... 
other_sentences=[ + ... "Deep learning is so straightforward.", + ... "This is so difficult, like rocket science.", + ... "I can't believe how much I struggled with this.", + ... ], + ... ) + [0.7785726189613342, 0.45876261591911316, 0.2906220555305481] + ``` + """ + response = await self.post( + json={"inputs": {"source_sentence": sentence, "sentences": other_sentences}}, + model=model, + task="sentence-similarity", + ) + return _bytes_to_list(response) + + async def summarization( + self, + text: str, + *, + parameters: Optional[Dict[str, Any]] = None, + model: Optional[str] = None, + ) -> SummarizationOutput: + """ + Generate a summary of a given text using a specified model. + + Args: + text (`str`): + The input text to summarize. + parameters (`Dict[str, Any]`, *optional*): + Additional parameters for summarization. Check out this [page](https://huggingface.co/docs/api-inference/detailed_parameters#summarization-task) + for more details. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + [`SummarizationOutput`]: The generated summary text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. 
+ + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.summarization("The Eiffel tower...") + SummarizationOutput(generated_text="The Eiffel tower is one of the most famous landmarks in the world....") + ``` + """ + payload: Dict[str, Any] = {"inputs": text} + if parameters is not None: + payload["parameters"] = parameters + response = await self.post(json=payload, model=model, task="summarization") + return SummarizationOutput.parse_obj_as_list(response)[0] + + async def table_question_answering( + self, table: Dict[str, Any], query: str, *, model: Optional[str] = None + ) -> TableQuestionAnsweringOutputElement: + """ + Retrieve the answer to a question from information given in a table. + + Args: + table (`Dict[str, Any]`): + A table of data represented as a dict of lists where entries are headers and the lists are all the + values, all lists must have the same size. + query (`str`): + The query in plain text that you want to ask the table. + model (`str`, *optional*): + The model to use for the table-question-answering task. Can be a model ID hosted on the Hugging Face + Hub or a URL to a deployed Inference Endpoint. + + Returns: + [`TableQuestionAnsweringOutputElement`]: a table question answering output containing the answer, coordinates, cells and the aggregator used. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> query = "How many stars does the transformers repository have?"
+ >>> table = {"Repository": ["Transformers", "Datasets", "Tokenizers"], "Stars": ["36542", "4512", "3934"]} + >>> await client.table_question_answering(table, query, model="google/tapas-base-finetuned-wtq") + TableQuestionAnsweringOutputElement(answer='36542', coordinates=[[0, 1]], cells=['36542'], aggregator='AVERAGE') + ``` + """ + response = await self.post( + json={ + "query": query, + "table": table, + }, + model=model, + task="table-question-answering", + ) + return TableQuestionAnsweringOutputElement.parse_obj_as_instance(response) + + async def tabular_classification(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[str]: + """ + Classifying a target category (a group) based on a set of attributes. + + Args: + table (`Dict[str, Any]`): + Set of attributes to classify. + model (`str`, *optional*): + The model to use for the tabular classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended tabular classification model will be used. + Defaults to None. + + Returns: + `List`: a list of labels, one per row in the initial table. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> table = { + ... "fixed_acidity": ["7.4", "7.8", "10.3"], + ... "volatile_acidity": ["0.7", "0.88", "0.32"], + ... "citric_acid": ["0", "0", "0.45"], + ... "residual_sugar": ["1.9", "2.6", "6.4"], + ... "chlorides": ["0.076", "0.098", "0.073"], + ... "free_sulfur_dioxide": ["11", "25", "5"], + ... "total_sulfur_dioxide": ["34", "67", "13"], + ... "density": ["0.9978", "0.9968", "0.9976"], + ... "pH": ["3.51", "3.2", "3.23"], + ... 
"sulphates": ["0.56", "0.68", "0.82"], + ... "alcohol": ["9.4", "9.8", "12.6"], + ... } + >>> await client.tabular_classification(table=table, model="julien-c/wine-quality") + ["5", "5", "5"] + ``` + """ + response = await self.post(json={"table": table}, model=model, task="tabular-classification") + return _bytes_to_list(response) + + async def tabular_regression(self, table: Dict[str, Any], *, model: Optional[str] = None) -> List[float]: + """ + Predicting a numerical target value given a set of attributes/features in a table. + + Args: + table (`Dict[str, Any]`): + Set of attributes stored in a table. The attributes used to predict the target can be both numerical and categorical. + model (`str`, *optional*): + The model to use for the tabular regression task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended tabular regression model will be used. + Defaults to None. + + Returns: + `List`: a list of predicted numerical target values. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> table = { + ... "Height": ["11.52", "12.48", "12.3778"], + ... "Length1": ["23.2", "24", "23.9"], + ... "Length2": ["25.4", "26.3", "26.5"], + ... "Length3": ["30", "31.2", "31.1"], + ... "Species": ["Bream", "Bream", "Bream"], + ... "Width": ["4.02", "4.3056", "4.6961"], + ... 
} + >>> await client.tabular_regression(table, model="scikit-learn/Fish-Weight") + [110, 120, 130] + ``` + """ + response = await self.post(json={"table": table}, model=model, task="tabular-regression") + return _bytes_to_list(response) + + async def text_classification( + self, text: str, *, model: Optional[str] = None + ) -> List[TextClassificationOutputElement]: + """ + Perform text classification (e.g. sentiment-analysis) on the given text. + + Args: + text (`str`): + A string to be classified. + model (`str`, *optional*): + The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used. + Defaults to None. + + Returns: + `List[TextClassificationOutputElement]`: a list of [`TextClassificationOutputElement`] items containing the predicted label and associated probability. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. 
+ + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.text_classification("I like you") + [ + TextClassificationOutputElement(label='POSITIVE', score=0.9998695850372314), + TextClassificationOutputElement(label='NEGATIVE', score=0.0001304351753788069), + ] + ``` + """ + response = await self.post(json={"inputs": text}, model=model, task="text-classification") + return TextClassificationOutputElement.parse_obj_as_list(response)[0] # type: ignore [return-value] + + @overload + async def text_generation( # type: ignore + self, + prompt: str, + *, + details: Literal[False] = ..., + stream: Literal[False] = ..., + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = False, # Manual default value + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = False, # Manual default value + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> str: ... 
# Typing-only overload: `details=True` together with `stream=False` types the
# awaited result as a single `TextGenerationOutput`.
# The body is `...`; this stub only declares a signature for type checkers.
@overload
async def text_generation(  # type: ignore
    self,
    prompt: str,
    *,
    details: Literal[True] = ...,
    stream: Literal[False] = ...,
    model: Optional[str] = None,
    # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
    adapter_id: Optional[str] = None,
    best_of: Optional[int] = None,
    decoder_input_details: Optional[bool] = None,
    do_sample: Optional[bool] = False,  # Manual default value
    frequency_penalty: Optional[float] = None,
    grammar: Optional[TextGenerationInputGrammarType] = None,
    max_new_tokens: Optional[int] = None,
    repetition_penalty: Optional[float] = None,
    return_full_text: Optional[bool] = False,  # Manual default value
    seed: Optional[int] = None,
    stop: Optional[List[str]] = None,
    stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_n_tokens: Optional[int] = None,
    top_p: Optional[float] = None,
    truncate: Optional[int] = None,
    typical_p: Optional[float] = None,
    watermark: Optional[bool] = None,
) -> TextGenerationOutput: ...
# Typing-only overload: `details=False` together with `stream=True` types the
# awaited result as an `AsyncIterable[str]` (one item per streamed chunk of text).
# The body is `...`; this stub only declares a signature for type checkers.
@overload
async def text_generation(  # type: ignore
    self,
    prompt: str,
    *,
    details: Literal[False] = ...,
    stream: Literal[True] = ...,
    model: Optional[str] = None,
    # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
    adapter_id: Optional[str] = None,
    best_of: Optional[int] = None,
    decoder_input_details: Optional[bool] = None,
    do_sample: Optional[bool] = False,  # Manual default value
    frequency_penalty: Optional[float] = None,
    grammar: Optional[TextGenerationInputGrammarType] = None,
    max_new_tokens: Optional[int] = None,
    repetition_penalty: Optional[float] = None,
    return_full_text: Optional[bool] = False,  # Manual default value
    seed: Optional[int] = None,
    stop: Optional[List[str]] = None,
    stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_n_tokens: Optional[int] = None,
    top_p: Optional[float] = None,
    truncate: Optional[int] = None,
    typical_p: Optional[float] = None,
    watermark: Optional[bool] = None,
) -> AsyncIterable[str]: ...
# Typing-only overload: `details=True` together with `stream=True` types the
# awaited result as an `AsyncIterable[TextGenerationStreamOutput]`.
# The body is `...`; this stub only declares a signature for type checkers.
@overload
async def text_generation(  # type: ignore
    self,
    prompt: str,
    *,
    details: Literal[True] = ...,
    stream: Literal[True] = ...,
    model: Optional[str] = None,
    # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
    adapter_id: Optional[str] = None,
    best_of: Optional[int] = None,
    decoder_input_details: Optional[bool] = None,
    do_sample: Optional[bool] = False,  # Manual default value
    frequency_penalty: Optional[float] = None,
    grammar: Optional[TextGenerationInputGrammarType] = None,
    max_new_tokens: Optional[int] = None,
    repetition_penalty: Optional[float] = None,
    return_full_text: Optional[bool] = False,  # Manual default value
    seed: Optional[int] = None,
    stop: Optional[List[str]] = None,
    stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_n_tokens: Optional[int] = None,
    top_p: Optional[float] = None,
    truncate: Optional[int] = None,
    typical_p: Optional[float] = None,
    watermark: Optional[bool] = None,
) -> AsyncIterable[TextGenerationStreamOutput]: ...
# Typing-only overload: `details=True` with `stream` typed as a plain `bool`
# (not a literal), so the result is typed as the union of the streaming and
# non-streaming detailed outputs.
# The body is `...`; this stub only declares a signature for type checkers.
@overload
async def text_generation(
    self,
    prompt: str,
    *,
    details: Literal[True] = ...,
    stream: bool = ...,
    model: Optional[str] = None,
    # Parameters from `TextGenerationInputGenerateParameters` (maintained manually)
    adapter_id: Optional[str] = None,
    best_of: Optional[int] = None,
    decoder_input_details: Optional[bool] = None,
    do_sample: Optional[bool] = False,  # Manual default value
    frequency_penalty: Optional[float] = None,
    grammar: Optional[TextGenerationInputGrammarType] = None,
    max_new_tokens: Optional[int] = None,
    repetition_penalty: Optional[float] = None,
    return_full_text: Optional[bool] = False,  # Manual default value
    seed: Optional[int] = None,
    stop: Optional[List[str]] = None,
    stop_sequences: Optional[List[str]] = None,  # Deprecated, use `stop` instead
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_n_tokens: Optional[int] = None,
    top_p: Optional[float] = None,
    truncate: Optional[int] = None,
    typical_p: Optional[float] = None,
    watermark: Optional[bool] = None,
) -> Union[TextGenerationOutput, AsyncIterable[TextGenerationStreamOutput]]: ...
+ + async def text_generation( + self, + prompt: str, + *, + details: bool = False, + stream: bool = False, + model: Optional[str] = None, + # Parameters from `TextGenerationInputGenerateParameters` (maintained manually) + adapter_id: Optional[str] = None, + best_of: Optional[int] = None, + decoder_input_details: Optional[bool] = None, + do_sample: Optional[bool] = False, # Manual default value + frequency_penalty: Optional[float] = None, + grammar: Optional[TextGenerationInputGrammarType] = None, + max_new_tokens: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: Optional[bool] = False, # Manual default value + seed: Optional[int] = None, + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_n_tokens: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: Optional[bool] = None, + ) -> Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]: + """ + Given a prompt, generate the following text. + + API endpoint is supposed to run with the `text-generation-inference` backend (TGI). This backend is the + go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the + default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but + not exactly the same. This method is compatible with both approaches but some parameters are only available for + `text-generation-inference`. If some parameters are ignored, a warning message is triggered but the process + continues correctly. + + To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference. 
+ ++ + If you want to generate a response from chat messages, you should use the [`InferenceClient.chat_completion`] method. + It accepts a list of messages instead of a single text prompt and handles the chat templating for you. + + + + Args: + prompt (`str`): + Input text. + details (`bool`, *optional*): + By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens, + probabilities, seed, finish reason, etc.). Only available for models running with the + `text-generation-inference` backend. + stream (`bool`, *optional*): + By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of + tokens to be returned. Only available for models running with the `text-generation-inference` + backend. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + adapter_id (`str`, *optional*): + Lora adapter id. + best_of (`int`, *optional*): + Generate best_of sequences and return the one with the highest token logprobs. + decoder_input_details (`bool`, *optional*): + Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken + into account. Defaults to `False`. + do_sample (`bool`, *optional*): + Activate logits sampling + frequency_penalty (`float`, *optional*): + Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in + the text so far, decreasing the model's likelihood to repeat the same line verbatim. + grammar ([`TextGenerationInputGrammarType`], *optional*): + Grammar constraints. Can be either a JSONSchema or a regex. + max_new_tokens (`int`, *optional*): + Maximum number of generated tokens + repetition_penalty (`float`, *optional*): + The parameter for repetition penalty. 1.0 means no penalty.
See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + return_full_text (`bool`, *optional*): + Whether to prepend the prompt to the generated text + seed (`int`, *optional*): + Random sampling seed + stop (`List[str]`, *optional*): + Stop generating tokens if a member of `stop` is generated. + stop_sequences (`List[str]`, *optional*): + Deprecated argument. Use `stop` instead. + temperature (`float`, *optional*): + The value used to modulate the logits distribution. + top_n_tokens (`int`, *optional*): + Return information about the `top_n_tokens` most likely tokens at each generation step, instead of + just the sampled token. + top_k (`int`, *optional*): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p (`float`, *optional*): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. + truncate (`int`, *optional*): + Truncate input tokens to the given size.
+ typical_p (`float`, *optional*): + Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + watermark (`bool`, *optional*): + Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + + Returns: + `Union[str, TextGenerationOutput, AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]`: + Generated text returned from the server: + - if `stream=False` and `details=False`, the generated text is returned as a `str` (default) + - if `stream=True` and `details=False`, the generated text is returned token by token as an `AsyncIterable[str]` + - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.TextGenerationOutput`] + - if `details=True` and `stream=True`, the generated text is returned token by token as an async iterable of [`~huggingface_hub.TextGenerationStreamOutput`] + + Raises: + `ValidationError`: + If input values are not valid. No HTTP call is made to the server. + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + + # Case 1: generate text + >>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12) + '100% open source and built to be easy to use.' + + # Case 2: iterate over the generated tokens. Useful for large generation. + >>> async for token in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True): + ... print(token) + 100 + % + open + source + and + built + to + be + easy + to + use + . + + # Case 3: get more details about the generation process.
+ >>> await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True) + TextGenerationOutput( + generated_text='100% open source and built to be easy to use.', + details=TextGenerationDetails( + finish_reason='length', + generated_tokens=12, + seed=None, + prefill=[ + TextGenerationPrefillOutputToken(id=487, text='The', logprob=None), + TextGenerationPrefillOutputToken(id=53789, text=' hugging', logprob=-13.171875), + (...) + TextGenerationPrefillOutputToken(id=204, text=' ', logprob=-7.0390625) + ], + tokens=[ + TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), + TokenElement(id=16, text='%', logprob=-0.0463562, special=False), + (...) + TokenElement(id=25, text='.', logprob=-0.5703125, special=False) + ], + best_of_sequences=None + ) + ) + + # Case 4: iterate over the generated tokens with more details. + # Last object is more complete, containing the full generated text and the finish reason. + >>> async for details in await client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True): + ... print(details) + ... 
+ TextGenerationStreamOutput(token=TokenElement(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None) + TextGenerationStreamOutput(token=TokenElement( + id=25, + text='.', + logprob=-0.5703125, + special=False), + generated_text='100% open source and built to be easy to use.', + details=TextGenerationStreamOutputStreamDetails(finish_reason='length', generated_tokens=12, seed=None) + ) + + # Case 5: generate constrained output using grammar + >>> response = await client.text_generation( + ... 
prompt="I saw a puppy a cat and a raccoon during my bike ride in the park", + ... model="HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1", + ... max_new_tokens=100, + ... repetition_penalty=1.3, + ... grammar={ + ... "type": "json", + ... "value": { + ... "properties": { + ... "location": {"type": "string"}, + ... "activity": {"type": "string"}, + ... "animals_seen": {"type": "integer", "minimum": 1, "maximum": 5}, + ... "animals": {"type": "array", "items": {"type": "string"}}, + ... }, + ... "required": ["location", "activity", "animals_seen", "animals"], + ... }, + ... }, + ... ) + >>> json.loads(response) + { + "activity": "bike riding", + "animals": ["puppy", "cat", "raccoon"], + "animals_seen": 3, + "location": "park" + } + ``` + """ + if decoder_input_details and not details: + warnings.warn( + "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that" + " the output from the server will be truncated." + ) + decoder_input_details = False + + if stop_sequences is not None: + warnings.warn( + "`stop_sequences` is a deprecated argument for `text_generation` task" + " and will be removed in version '0.28.0'. 
Use `stop` instead.", + FutureWarning, + ) + if stop is None: + stop = stop_sequences # use deprecated arg if provided + + # Build payload + parameters = { + "adapter_id": adapter_id, + "best_of": best_of, + "decoder_input_details": decoder_input_details, + "details": details, + "do_sample": do_sample, + "frequency_penalty": frequency_penalty, + "grammar": grammar, + "max_new_tokens": max_new_tokens, + "repetition_penalty": repetition_penalty, + "return_full_text": return_full_text, + "seed": seed, + "stop": stop if stop is not None else [], + "temperature": temperature, + "top_k": top_k, + "top_n_tokens": top_n_tokens, + "top_p": top_p, + "truncate": truncate, + "typical_p": typical_p, + "watermark": watermark, + } + parameters = {k: v for k, v in parameters.items() if v is not None} + payload = { + "inputs": prompt, + "parameters": parameters, + "stream": stream, + } + + # Remove some parameters if not a TGI server + unsupported_kwargs = _get_unsupported_text_generation_kwargs(model) + if len(unsupported_kwargs) > 0: + # The server does not support some parameters + # => means it is not a TGI server + # => remove unsupported parameters and warn the user + + ignored_parameters = [] + for key in unsupported_kwargs: + if parameters.get(key): + ignored_parameters.append(key) + parameters.pop(key, None) + if len(ignored_parameters) > 0: + warnings.warn( + "API endpoint/model for text-generation is not served via TGI. Ignoring following parameters:" + f" {', '.join(ignored_parameters)}.", + UserWarning, + ) + if details: + warnings.warn( + "API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will" + " be ignored meaning only the generated text will be returned.", + UserWarning, + ) + details = False + if stream: + raise ValueError( + "API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream." + " Please pass `stream=False` as input." 
+ ) + + # Handle errors separately for more precise error messages + try: + bytes_output = await self.post(json=payload, model=model, task="text-generation", stream=stream) # type: ignore + except _import_aiohttp().ClientResponseError as e: + match = MODEL_KWARGS_NOT_USED_REGEX.search(e.response_error_payload["error"]) + if e.status == 400 and match: + unused_params = [kwarg.strip("' ") for kwarg in match.group(1).split(",")] + _set_unsupported_text_generation_kwargs(model, unused_params) + return await self.text_generation( # type: ignore + prompt=prompt, + details=details, + stream=stream, + model=model, + adapter_id=adapter_id, + best_of=best_of, + decoder_input_details=decoder_input_details, + do_sample=do_sample, + frequency_penalty=frequency_penalty, + grammar=grammar, + max_new_tokens=max_new_tokens, + repetition_penalty=repetition_penalty, + return_full_text=return_full_text, + seed=seed, + stop=stop, + temperature=temperature, + top_k=top_k, + top_n_tokens=top_n_tokens, + top_p=top_p, + truncate=truncate, + typical_p=typical_p, + watermark=watermark, + ) + raise_text_generation_error(e) + + # Parse output + if stream: + return _async_stream_text_generation_response(bytes_output, details) # type: ignore + + data = _bytes_to_dict(bytes_output) # type: ignore[arg-type] + + # Data can be a single element (dict) or an iterable of dicts where we select the first element of. + if isinstance(data, list): + data = data[0] + + return TextGenerationOutput.parse_obj_as_instance(data) if details else data["generated_text"] + + async def text_to_image( + self, + prompt: str, + *, + negative_prompt: Optional[str] = None, + height: Optional[float] = None, + width: Optional[float] = None, + num_inference_steps: Optional[float] = None, + guidance_scale: Optional[float] = None, + model: Optional[str] = None, + **kwargs, + ) -> "Image": + """ + Generate an image based on a given text using a specified model. 
+ ++ + You must have `PIL` installed if you want to work with images (`pip install Pillow`). + + + + Args: + prompt (`str`): + The prompt to generate an image from. + negative_prompt (`str`, *optional*): + An optional negative prompt for the image generation. + height (`float`, *optional*): + The height in pixels of the image to generate. + width (`float`, *optional*): + The width in pixels of the image to generate. + num_inference_steps (`int`, *optional*): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*): + Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `Image`: The generated image. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + + >>> image = await client.text_to_image("An astronaut riding a horse on the moon.") + >>> image.save("astronaut.png") + + >>> image = await client.text_to_image( + ... "An astronaut riding a horse on the moon.", + ... negative_prompt="low resolution, blurry", + ... model="stabilityai/stable-diffusion-2-1", + ... 
) + >>> image.save("better_astronaut.png") + ``` + """ + payload = {"inputs": prompt} + parameters = { + "negative_prompt": negative_prompt, + "height": height, + "width": width, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + **kwargs, + } + for key, value in parameters.items(): + if value is not None: + payload.setdefault("parameters", {})[key] = value # type: ignore + response = await self.post(json=payload, model=model, task="text-to-image") + return _bytes_to_image(response) + + async def text_to_speech(self, text: str, *, model: Optional[str] = None) -> bytes: + """ + Synthesize an audio of a voice pronouncing a given text. + + Args: + text (`str`): + The text to synthesize. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `bytes`: The generated audio. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from pathlib import Path + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + + >>> audio = await client.text_to_speech("Hello world") + >>> Path("hello_world.flac").write_bytes(audio) + ``` + """ + return await self.post(json={"inputs": text}, model=model, task="text-to-speech") + + async def token_classification( + self, text: str, *, model: Optional[str] = None + ) -> List[TokenClassificationOutputElement]: + """ + Perform token classification on the given text. + Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. + + Args: + text (`str`): + A string to be classified. 
+ model (`str`, *optional*): + The model to use for the token classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended token classification model will be used. + Defaults to None. + + Returns: + `List[TokenClassificationOutputElement]`: List of [`TokenClassificationOutputElement`] items containing the entity group, confidence score, word, start and end index. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.token_classification("My name is Sarah Jessica Parker but you can call me Jessica") + [ + TokenClassificationOutputElement( + entity_group='PER', + score=0.9971321225166321, + word='Sarah Jessica Parker', + start=11, + end=31, + ), + TokenClassificationOutputElement( + entity_group='PER', + score=0.9773476123809814, + word='Jessica', + start=52, + end=59, + ) + ] + ``` + """ + payload: Dict[str, Any] = {"inputs": text} + response = await self.post( + json=payload, + model=model, + task="token-classification", + ) + return TokenClassificationOutputElement.parse_obj_as_list(response) + + async def translation( + self, text: str, *, model: Optional[str] = None, src_lang: Optional[str] = None, tgt_lang: Optional[str] = None + ) -> TranslationOutput: + """ + Convert text from one language to another. + + Check out https://huggingface.co/tasks/translation for more information on how to choose the best model for + your specific use case. Source and target languages usually depend on the model. + However, it is possible to specify source and target languages for certain models. 
If you are working with one of these models, + you can use `src_lang` and `tgt_lang` arguments to pass the relevant information. + You can find this information in the model card. + + Args: + text (`str`): + A string to be translated. + model (`str`, *optional*): + The model to use for the translation task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended translation model will be used. + Defaults to None. + src_lang (`str`, *optional*): + Source language of the translation task, i.e. input language. Cannot be passed without `tgt_lang`. + tgt_lang (`str`, *optional*): + Target language of the translation task, i.e. output language. Cannot be passed without `src_lang`. + + Returns: + [`TranslationOutput`]: The generated translated text. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + `ValueError`: + If only one of the `src_lang` and `tgt_lang` arguments are provided. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.translation("My name is Wolfgang and I live in Berlin") + 'Mein Name ist Wolfgang und ich lebe in Berlin.' 
+ >>> await client.translation("My name is Wolfgang and I live in Berlin", model="Helsinki-NLP/opus-mt-en-fr") + TranslationOutput(translation_text='Je m\'appelle Wolfgang et je vis Γ Berlin.') + ``` + + Specifying languages: + ```py + >>> client.translation("My name is Sarah Jessica Parker but you can call me Jessica", model="facebook/mbart-large-50-many-to-many-mmt", src_lang="en_XX", tgt_lang="fr_XX") + "Mon nom est Sarah Jessica Parker mais vous pouvez m\'appeler Jessica" + ``` + """ + # Throw error if only one of `src_lang` and `tgt_lang` was given + if src_lang is not None and tgt_lang is None: + raise ValueError("You cannot specify `src_lang` without specifying `tgt_lang`.") + + if src_lang is None and tgt_lang is not None: + raise ValueError("You cannot specify `tgt_lang` without specifying `src_lang`.") + + # If both `src_lang` and `tgt_lang` are given, pass them to the request body + payload: Dict = {"inputs": text} + if src_lang and tgt_lang: + payload["parameters"] = {"src_lang": src_lang, "tgt_lang": tgt_lang} + response = await self.post(json=payload, model=model, task="translation") + return TranslationOutput.parse_obj_as_list(response)[0] + + async def visual_question_answering( + self, + image: ContentT, + question: str, + *, + model: Optional[str] = None, + ) -> List[VisualQuestionAnsweringOutputElement]: + """ + Answering open-ended questions based on an image. + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The input image for the context. It can be raw bytes, an image file, or a URL to an online image. + question (`str`): + Question to be answered. + model (`str`, *optional*): + The model to use for the visual question answering task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended visual question answering model will be used. + Defaults to None. 
+ + Returns: + `List[VisualQuestionAnsweringOutputElement]`: a list of [`VisualQuestionAnsweringOutputElement`] items containing the predicted label and associated probability. + + Raises: + `InferenceTimeoutError`: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.visual_question_answering( + ... image="https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg", + ... question="What is the animal doing?" + ... ) + [ + VisualQuestionAnsweringOutputElement(score=0.778609573841095, answer='laying down'), + VisualQuestionAnsweringOutputElement(score=0.6957435607910156, answer='sitting'), + ] + ``` + """ + payload: Dict[str, Any] = {"question": question, "image": _b64_encode(image)} + response = await self.post(json=payload, model=model, task="visual-question-answering") + return VisualQuestionAnsweringOutputElement.parse_obj_as_list(response) + + async def zero_shot_classification( + self, + text: str, + labels: List[str], + *, + multi_label: bool = False, + hypothesis_template: Optional[str] = None, + model: Optional[str] = None, + ) -> List[ZeroShotClassificationOutputElement]: + """ + Provide as input a text and a set of candidate labels to classify the input text. + + Args: + text (`str`): + The input text to classify. + labels (`List[str]`): + List of strings. Each string is the verbalization of a possible label for the input text. + multi_label (`bool`): + Boolean. If True, the probability for each label is evaluated independently and multiple labels can have a probability close to 1 simultaneously or all probabilities can be close to 0. + If False, the labels are considered mutually exclusive and the probability over all labels always sums to 1. Defaults to False. 
+ hypothesis_template (`str`, *optional*): + A template sentence string with curly brackets to which the label strings are added. The label strings are added at the position of the curly brackets "{}". + Zero-shot classifiers are based on NLI models, which evaluate if a hypothesis is entailed in another text or not. + For example, with hypothesis_template="This text is about {}." and labels=["economics", "politics"], the system internally creates the two hypotheses "This text is about economics." and "This text is about politics.". + The model then evaluates for both hypotheses if they are entailed in the provided `text` or not. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `List[ZeroShotClassificationOutputElement]`: List of [`ZeroShotClassificationOutputElement`] items containing the predicted labels and their confidence. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example with `multi_label=False`: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> text = ( + ... "A new model offers an explanation for how the Galilean satellites formed around the solar system's" + ... "largest world. Konstantin Batygin did not set out to solve one of the solar system's most puzzling" + ... " mysteries when he went for a run up a hill in Nice, France." + ... 
) + >>> labels = ["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology"] + >>> await client.zero_shot_classification(text, labels) + [ + ZeroShotClassificationOutputElement(label='scientific discovery', score=0.7961668968200684), + ZeroShotClassificationOutputElement(label='space & cosmos', score=0.18570658564567566), + ZeroShotClassificationOutputElement(label='microbiology', score=0.00730885099619627), + ZeroShotClassificationOutputElement(label='archeology', score=0.006258360575884581), + ZeroShotClassificationOutputElement(label='robots', score=0.004559356719255447), + ] + >>> await client.zero_shot_classification(text, labels, multi_label=True) + [ + ZeroShotClassificationOutputElement(label='scientific discovery', score=0.9829297661781311), + ZeroShotClassificationOutputElement(label='space & cosmos', score=0.755190908908844), + ZeroShotClassificationOutputElement(label='microbiology', score=0.0005462635890580714), + ZeroShotClassificationOutputElement(label='archeology', score=0.00047131875180639327), + ZeroShotClassificationOutputElement(label='robots', score=0.00030448526376858354), + ] + ``` + + Example with `multi_label=True` and a custom `hypothesis_template`: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.zero_shot_classification( + ... text="I really like our dinner and I'm very happy. I don't like the weather though.", + ... labels=["positive", "negative", "pessimistic", "optimistic"], + ... multi_label=True, + ... hypothesis_template="This text is {} towards the weather" + ... 
) + [ + ZeroShotClassificationOutputElement(label='negative', score=0.9231801629066467), + ZeroShotClassificationOutputElement(label='pessimistic', score=0.8760990500450134), + ZeroShotClassificationOutputElement(label='optimistic', score=0.0008674879791215062), + ZeroShotClassificationOutputElement(label='positive', score=0.0005250611575320363) + ] + ``` + """ + + parameters = {"candidate_labels": labels, "multi_label": multi_label} + if hypothesis_template is not None: + parameters["hypothesis_template"] = hypothesis_template + + response = await self.post( + json={ + "inputs": text, + "parameters": parameters, + }, + task="zero-shot-classification", + model=model, + ) + output = _bytes_to_dict(response) + return [ + ZeroShotClassificationOutputElement.parse_obj_as_instance({"label": label, "score": score}) + for label, score in zip(output["labels"], output["scores"]) + ] + + async def zero_shot_image_classification( + self, image: ContentT, labels: List[str], *, model: Optional[str] = None + ) -> List[ZeroShotImageClassificationOutputElement]: + """ + Provide input image and text labels to predict text labels for the image. + + Args: + image (`Union[str, Path, bytes, BinaryIO]`): + The input image to caption. It can be raw bytes, an image file, or a URL to an online image. + labels (`List[str]`): + List of string possible labels. There must be at least 2 labels. + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `List[ZeroShotImageClassificationOutputElement]`: List of [`ZeroShotImageClassificationOutputElement`] items containing the predicted labels and their confidence. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. 
+ `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + + >>> await client.zero_shot_image_classification( + ... "https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg", + ... labels=["dog", "cat", "horse"], + ... ) + [ZeroShotImageClassificationOutputElement(label='dog', score=0.956),...] + ``` + """ + # Raise ValueError if input is less than 2 labels + if len(labels) < 2: + raise ValueError("You must specify at least 2 classes to compare.") + + response = await self.post( + json={"image": _b64_encode(image), "parameters": {"candidate_labels": ",".join(labels)}}, + model=model, + task="zero-shot-image-classification", + ) + return ZeroShotImageClassificationOutputElement.parse_obj_as_list(response) + + def _get_client_session(self, headers: Optional[Dict] = None) -> "ClientSession": + aiohttp = _import_aiohttp() + client_headers = self.headers.copy() + if headers is not None: + client_headers.update(headers) + + # Return a new aiohttp ClientSession with correct settings. + session = aiohttp.ClientSession( + headers=client_headers, + cookies=self.cookies, + timeout=aiohttp.ClientTimeout(self.timeout), + trust_env=self.trust_env, + ) + + # Keep track of sessions to close them later + self._sessions[session] = set() + + # Override the `._request` method to register responses to be closed + session._wrapped_request = session._request + + async def _request(method, url, **kwargs): + response = await session._wrapped_request(method, url, **kwargs) + self._sessions[session].add(response) + return response + + session._request = _request + + # Override the 'close' method to + # 1. close ongoing responses + # 2. 
deregister the session when closed + session._close = session.close + + async def close_session(): + for response in self._sessions[session]: + response.close() + await session._close() + self._sessions.pop(session, None) + + session.close = close_session + return session + + def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str: + model = model or self.model or self.base_url + + # If model is already a URL, ignore `task` and return directly + if model is not None and (model.startswith("http://") or model.startswith("https://")): + return model + + # # If no model but task is set => fetch the recommended one for this task + if model is None: + if task is None: + raise ValueError( + "You must specify at least a model (repo_id or URL) or a task, either when instantiating" + " `InferenceClient` or when making a request." + ) + model = self.get_recommended_model(task) + logger.info( + f"Using recommended model {model} for task {task}. Note that it is" + f" encouraged to explicitly set `model='{model}'` as the recommended" + " models list might get updated without prior notice." + ) + + # Compute InferenceAPI url + return ( + # Feature-extraction and sentence-similarity are the only cases where we handle models with several tasks. + f"{INFERENCE_ENDPOINT}/pipeline/{task}/{model}" + if task in ("feature-extraction", "sentence-similarity") + # Otherwise, we use the default endpoint + else f"{INFERENCE_ENDPOINT}/models/{model}" + ) + + @staticmethod + def get_recommended_model(task: str) -> str: + """ + Get the model Hugging Face recommends for the input task. + + Args: + task (`str`): + The Hugging Face task to get which model Hugging Face recommends. + All available tasks can be found [here](https://huggingface.co/tasks). + + Returns: + `str`: Name of the model recommended for the input task. + + Raises: + `ValueError`: If Hugging Face has no recommendation for the input task. 
+ """ + model = _fetch_recommended_models().get(task) + if model is None: + raise ValueError( + f"Task {task} has no recommended model. Please specify a model" + " explicitly. Visit https://huggingface.co/tasks for more info." + ) + return model + + async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]: + """ + Get information about the deployed endpoint. + + This endpoint is only available on endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI). + Endpoints powered by `transformers` return an empty payload. + + Args: + model (`str`, *optional*): + The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed + Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `Dict[str, Any]`: Information about the endpoint. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient("meta-llama/Meta-Llama-3-70B-Instruct") + >>> await client.get_endpoint_info() + { + 'model_id': 'meta-llama/Meta-Llama-3-70B-Instruct', + 'model_sha': None, + 'model_dtype': 'torch.float16', + 'model_device_type': 'cuda', + 'model_pipeline_tag': None, + 'max_concurrent_requests': 128, + 'max_best_of': 2, + 'max_stop_sequences': 4, + 'max_input_length': 8191, + 'max_total_tokens': 8192, + 'waiting_served_ratio': 0.3, + 'max_batch_total_tokens': 1259392, + 'max_waiting_tokens': 20, + 'max_batch_size': None, + 'validation_workers': 32, + 'max_client_batch_size': 4, + 'version': '2.0.2', + 'sha': 'dccab72549635c7eb5ddb17f43f0b7cdff07c214', + 'docker_label': 'sha-dccab72' + } + ``` + """ + model = model or self.model + if model is None: + raise ValueError("Model id not provided.") + if model.startswith(("http://", "https://")): + url = model.rstrip("/") + "/info" + else: + url = f"{INFERENCE_ENDPOINT}/models/{model}/info" + + async with 
self._get_client_session() as client: + response = await client.get(url, proxy=self.proxies) + response.raise_for_status() + return await response.json() + + async def health_check(self, model: Optional[str] = None) -> bool: + """ + Check the health of the deployed endpoint. + + Health check is only available with Inference Endpoints powered by Text-Generation-Inference (TGI) or Text-Embedding-Inference (TEI). + For Inference API, please use [`InferenceClient.get_model_status`] instead. + + Args: + model (`str`, *optional*): + URL of the Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None. + + Returns: + `bool`: True if everything is working fine. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient("https://jzgu0buei5.us-east-1.aws.endpoints.huggingface.cloud") + >>> await client.health_check() + True + ``` + """ + model = model or self.model + if model is None: + raise ValueError("Model id not provided.") + if not model.startswith(("http://", "https://")): + raise ValueError( + "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`." + ) + url = model.rstrip("/") + "/health" + + async with self._get_client_session() as client: + response = await client.get(url, proxy=self.proxies) + return response.status == 200 + + async def get_model_status(self, model: Optional[str] = None) -> ModelStatus: + """ + Get the status of a model hosted on the Inference API. + ++ + This endpoint is mostly useful when you already know which model you want to use and want to check its + availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`]. + + + + Args: + model (`str`, *optional*): + Identifier of the model for witch the status gonna be checked. 
If model is not provided, + the model associated with this instance of [`InferenceClient`] will be used. Only InferenceAPI service can be checked so the + identifier cannot be a URL. + + + Returns: + [`ModelStatus`]: An instance of ModelStatus dataclass, containing information, + about the state of the model: load, state, compute type and framework. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> await client.get_model_status("meta-llama/Meta-Llama-3-8B-Instruct") + ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference') + ``` + """ + model = model or self.model + if model is None: + raise ValueError("Model id not provided.") + if model.startswith("https://"): + raise NotImplementedError("Model status is only available for Inference API endpoints.") + url = f"{INFERENCE_ENDPOINT}/status/{model}" + + async with self._get_client_session() as client: + response = await client.get(url, proxy=self.proxies) + response.raise_for_status() + response_data = await response.json() + + if "error" in response_data: + raise ValueError(response_data["error"]) + + return ModelStatus( + loaded=response_data["loaded"], + state=response_data["state"], + compute_type=response_data["compute_type"], + framework=response_data["framework"], + ) + + @property + def chat(self) -> "ProxyClientChat": + return ProxyClientChat(self) + + +class _ProxyClient: + """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" + + def __init__(self, client: AsyncInferenceClient): + self._client = client + + +class ProxyClientChat(_ProxyClient): + """Proxy class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" + + @property + def completions(self) -> "ProxyClientChatCompletions": + return ProxyClientChatCompletions(self._client) + + +class ProxyClientChatCompletions(_ProxyClient): + """Proxy 
class to be able to call `client.chat.completion.create(...)` as OpenAI client.""" + + @property + def create(self): + return self._client.chat_completion diff --git a/huggingface_hub/inference/_generated/types/__init__.py b/huggingface_hub/inference/_generated/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db2793be23de19e0c8ca2f7a72cf6ded6d090471 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/__init__.py @@ -0,0 +1,135 @@ +# This file is auto-generated by `utils/generate_inference_types.py`. +# Do not modify it manually. +# +# ruff: noqa: F401 + +from .audio_classification import ( + AudioClassificationInput, + AudioClassificationOutputElement, + AudioClassificationParameters, +) +from .audio_to_audio import AudioToAudioInput, AudioToAudioOutputElement +from .automatic_speech_recognition import ( + AutomaticSpeechRecognitionGenerationParameters, + AutomaticSpeechRecognitionInput, + AutomaticSpeechRecognitionOutput, + AutomaticSpeechRecognitionOutputChunk, + AutomaticSpeechRecognitionParameters, +) +from .base import BaseInferenceType +from .chat_completion import ( + ChatCompletionInput, + ChatCompletionInputFunctionDefinition, + ChatCompletionInputFunctionName, + ChatCompletionInputGrammarType, + ChatCompletionInputMessage, + ChatCompletionInputMessageChunk, + ChatCompletionInputTool, + ChatCompletionInputToolTypeClass, + ChatCompletionInputURL, + ChatCompletionOutput, + ChatCompletionOutputComplete, + ChatCompletionOutputFunctionDefinition, + ChatCompletionOutputLogprob, + ChatCompletionOutputLogprobs, + ChatCompletionOutputMessage, + ChatCompletionOutputToolCall, + ChatCompletionOutputTopLogprob, + ChatCompletionOutputUsage, + ChatCompletionStreamOutput, + ChatCompletionStreamOutputChoice, + ChatCompletionStreamOutputDelta, + ChatCompletionStreamOutputDeltaToolCall, + ChatCompletionStreamOutputFunction, + ChatCompletionStreamOutputLogprob, + ChatCompletionStreamOutputLogprobs, + 
ChatCompletionStreamOutputTopLogprob, +) +from .depth_estimation import DepthEstimationInput, DepthEstimationOutput +from .document_question_answering import ( + DocumentQuestionAnsweringInput, + DocumentQuestionAnsweringInputData, + DocumentQuestionAnsweringOutputElement, + DocumentQuestionAnsweringParameters, +) +from .feature_extraction import FeatureExtractionInput +from .fill_mask import FillMaskInput, FillMaskOutputElement, FillMaskParameters +from .image_classification import ( + ImageClassificationInput, + ImageClassificationOutputElement, + ImageClassificationParameters, +) +from .image_segmentation import ImageSegmentationInput, ImageSegmentationOutputElement, ImageSegmentationParameters +from .image_to_image import ImageToImageInput, ImageToImageOutput, ImageToImageParameters, ImageToImageTargetSize +from .image_to_text import ImageToTextGenerationParameters, ImageToTextInput, ImageToTextOutput, ImageToTextParameters +from .object_detection import ( + ObjectDetectionBoundingBox, + ObjectDetectionInput, + ObjectDetectionOutputElement, + ObjectDetectionParameters, +) +from .question_answering import ( + QuestionAnsweringInput, + QuestionAnsweringInputData, + QuestionAnsweringOutputElement, + QuestionAnsweringParameters, +) +from .sentence_similarity import SentenceSimilarityInput, SentenceSimilarityInputData +from .summarization import SummarizationGenerationParameters, SummarizationInput, SummarizationOutput +from .table_question_answering import ( + TableQuestionAnsweringInput, + TableQuestionAnsweringInputData, + TableQuestionAnsweringOutputElement, +) +from .text2text_generation import Text2TextGenerationInput, Text2TextGenerationOutput, Text2TextGenerationParameters +from .text_classification import TextClassificationInput, TextClassificationOutputElement, TextClassificationParameters +from .text_generation import ( + TextGenerationInput, + TextGenerationInputGenerateParameters, + TextGenerationInputGrammarType, + TextGenerationOutput, + 
TextGenerationOutputBestOfSequence, + TextGenerationOutputDetails, + TextGenerationOutputPrefillToken, + TextGenerationOutputToken, + TextGenerationStreamOutput, + TextGenerationStreamOutputStreamDetails, + TextGenerationStreamOutputToken, +) +from .text_to_audio import TextToAudioGenerationParameters, TextToAudioInput, TextToAudioOutput, TextToAudioParameters +from .text_to_image import TextToImageInput, TextToImageOutput, TextToImageParameters, TextToImageTargetSize +from .token_classification import ( + TokenClassificationInput, + TokenClassificationOutputElement, + TokenClassificationParameters, +) +from .translation import TranslationGenerationParameters, TranslationInput, TranslationOutput +from .video_classification import ( + VideoClassificationInput, + VideoClassificationOutputElement, + VideoClassificationParameters, +) +from .visual_question_answering import ( + VisualQuestionAnsweringInput, + VisualQuestionAnsweringInputData, + VisualQuestionAnsweringOutputElement, + VisualQuestionAnsweringParameters, +) +from .zero_shot_classification import ( + ZeroShotClassificationInput, + ZeroShotClassificationInputData, + ZeroShotClassificationOutputElement, + ZeroShotClassificationParameters, +) +from .zero_shot_image_classification import ( + ZeroShotImageClassificationInput, + ZeroShotImageClassificationInputData, + ZeroShotImageClassificationOutputElement, + ZeroShotImageClassificationParameters, +) +from .zero_shot_object_detection import ( + ZeroShotObjectDetectionBoundingBox, + ZeroShotObjectDetectionInput, + ZeroShotObjectDetectionInputData, + ZeroShotObjectDetectionOutputElement, +) diff --git a/huggingface_hub/inference/_generated/types/audio_classification.py b/huggingface_hub/inference/_generated/types/audio_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..914ba44960b5edca2f182bd1c3f15e9f01bce3b9 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/audio_classification.py @@ -0,0 +1,43 @@ +# Inference 
code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from dataclasses import dataclass +from typing import Any, Literal, Optional + +from .base import BaseInferenceType + + +ClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] + + +@dataclass +class AudioClassificationParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Audio Classification + """ + + function_to_apply: Optional["ClassificationOutputTransform"] = None + top_k: Optional[int] = None + """When specified, limits the output to the top K most probable classes.""" + + +@dataclass +class AudioClassificationInput(BaseInferenceType): + """Inputs for Audio Classification inference""" + + inputs: Any + """The input audio data""" + parameters: Optional[AudioClassificationParameters] = None + """Additional inference parameters""" + + +@dataclass +class AudioClassificationOutputElement(BaseInferenceType): + """Outputs for Audio Classification inference""" + + label: str + """The predicted class label.""" + score: float + """The corresponding probability.""" diff --git a/huggingface_hub/inference/_generated/types/audio_to_audio.py b/huggingface_hub/inference/_generated/types/audio_to_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..4f473ed106c7d168784ae8e96db18f46237d065e --- /dev/null +++ b/huggingface_hub/inference/_generated/types/audio_to_audio.py @@ -0,0 +1,31 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Any + +from .base import BaseInferenceType + + +@dataclass +class AudioToAudioInput(BaseInferenceType): + """Inputs for Audio to Audio inference""" + + inputs: Any + """The input audio data""" + + +@dataclass +class AudioToAudioOutputElement(BaseInferenceType): + """Outputs of inference for the Audio To Audio task + A generated audio file with its label. + """ + + blob: Any + """The generated audio file.""" + content_type: str + """The content type of audio file.""" + label: str + """The label of the audio file.""" diff --git a/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py b/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py new file mode 100644 index 0000000000000000000000000000000000000000..24a5238ab6b33ea13df79a1ea197b4f07b39c1ec --- /dev/null +++ b/huggingface_hub/inference/_generated/types/automatic_speech_recognition.py @@ -0,0 +1,116 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Any, List, Literal, Optional, Union + +from .base import BaseInferenceType + + +EarlyStoppingEnum = Literal["never"] + + +@dataclass +class AutomaticSpeechRecognitionGenerationParameters(BaseInferenceType): + """Parametrization of the text generation process + Ad-hoc parametrization of the text generation process + """ + + do_sample: Optional[bool] = None + """Whether to use sampling instead of greedy decoding when generating new tokens.""" + early_stopping: Optional[Union[bool, "EarlyStoppingEnum"]] = None + """Controls the stopping condition for beam-based methods.""" + epsilon_cutoff: Optional[float] = None + """If set to float strictly between 0 and 1, only tokens with a conditional probability + greater than epsilon_cutoff will be sampled. In the paper, suggested values range from + 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language + Model Desmoothing](https://hf.co/papers/2210.15191) for more details. + """ + eta_cutoff: Optional[float] = None + """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to + float strictly between 0 and 1, a token is only considered if it is greater than either + eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter + term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In + the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. + See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) + for more details. + """ + max_length: Optional[int] = None + """The maximum length (in tokens) of the generated text, including the input.""" + max_new_tokens: Optional[int] = None + """The maximum number of tokens to generate. 
Takes precedence over maxLength.""" + min_length: Optional[int] = None + """The minimum length (in tokens) of the generated text, including the input.""" + min_new_tokens: Optional[int] = None + """The minimum number of tokens to generate. Takes precedence over maxLength.""" + num_beam_groups: Optional[int] = None + """Number of groups to divide num_beams into in order to ensure diversity among different + groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. + """ + num_beams: Optional[int] = None + """Number of beams to use for beam search.""" + penalty_alpha: Optional[float] = None + """The value balances the model confidence and the degeneration penalty in contrastive + search decoding. + """ + temperature: Optional[float] = None + """The value used to modulate the next token probabilities.""" + top_k: Optional[int] = None + """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" + top_p: Optional[float] = None + """If set to float < 1, only the smallest set of most probable tokens with probabilities + that add up to top_p or higher are kept for generation. + """ + typical_p: Optional[float] = None + """Local typicality measures how similar the conditional probability of predicting a target + token next is to the expected conditional probability of predicting a random token next, + given the partial text already generated. If set to float < 1, the smallest set of the + most locally typical tokens with probabilities that add up to typical_p or higher are + kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details. 
+ """ + use_cache: Optional[bool] = None + """Whether the model should use the past last key/values attentions to speed up decoding""" + + +@dataclass +class AutomaticSpeechRecognitionParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Automatic Speech Recognition + """ + + generate: Optional[AutomaticSpeechRecognitionGenerationParameters] = None + """Parametrization of the text generation process""" + return_timestamps: Optional[bool] = None + """Whether to output corresponding timestamps with the generated text""" + + +@dataclass +class AutomaticSpeechRecognitionInput(BaseInferenceType): + """Inputs for Automatic Speech Recognition inference""" + + inputs: Any + """The input audio data""" + parameters: Optional[AutomaticSpeechRecognitionParameters] = None + """Additional inference parameters""" + + +@dataclass +class AutomaticSpeechRecognitionOutputChunk(BaseInferenceType): + text: str + """A chunk of text identified by the model""" + timestamps: List[float] + """The start and end timestamps corresponding with the text""" + + +@dataclass +class AutomaticSpeechRecognitionOutput(BaseInferenceType): + """Outputs of inference for the Automatic Speech Recognition task""" + + text: str + """The recognized text.""" + chunks: Optional[List[AutomaticSpeechRecognitionOutputChunk]] = None + """When returnTimestamps is enabled, chunks contains a list of audio chunks identified by + the model. + """ diff --git a/huggingface_hub/inference/_generated/types/base.py b/huggingface_hub/inference/_generated/types/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e57b9e8c1e6c677b5b0ea6367e8db58212092014 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/base.py @@ -0,0 +1,140 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains a base class for all inference types.""" + +import inspect +import json +from dataclasses import asdict, dataclass +from typing import Any, Dict, List, Type, TypeVar, Union, get_args + + +T = TypeVar("T", bound="BaseInferenceType") + + +@dataclass +class BaseInferenceType(dict): + """Base class for all inference types. + + Object is a dataclass and a dict for backward compatibility but plan is to remove the dict part in the future. + + Handle parsing from dict, list and json strings in a permissive way to ensure future-compatibility (e.g. all fields + are made optional, and non-expected fields are added as dict attributes). + """ + + @classmethod + def parse_obj_as_list(cls: Type[T], data: Union[bytes, str, List, Dict]) -> List[T]: + """Alias to parse server response and return a single instance. + + See `parse_obj` for more details. + """ + output = cls.parse_obj(data) + if not isinstance(output, list): + raise ValueError(f"Invalid input data for {cls}. Expected a list, but got {type(output)}.") + return output + + @classmethod + def parse_obj_as_instance(cls: Type[T], data: Union[bytes, str, List, Dict]) -> T: + """Alias to parse server response and return a single instance. + + See `parse_obj` for more details. + """ + output = cls.parse_obj(data) + if isinstance(output, list): + raise ValueError(f"Invalid input data for {cls}. 
Expected a single instance, but got a list.") + return output + + @classmethod + def parse_obj(cls: Type[T], data: Union[bytes, str, List, Dict]) -> Union[List[T], T]: + """Parse server response as a dataclass or list of dataclasses. + + To enable future-compatibility, we want to handle cases where the server return more fields than expected. + In such cases, we don't want to raise an error but still create the dataclass object. Remaining fields are + added as dict attributes. + """ + # Parse server response (from bytes) + if isinstance(data, bytes): + data = data.decode() + if isinstance(data, str): + data = json.loads(data) + + # If a list, parse each item individually + if isinstance(data, List): + return [cls.parse_obj(d) for d in data] # type: ignore [misc] + + # At this point, we expect a dict + if not isinstance(data, dict): + raise ValueError(f"Invalid data type: {type(data)}") + + init_values = {} + other_values = {} + for key, value in data.items(): + key = normalize_key(key) + if key in cls.__dataclass_fields__ and cls.__dataclass_fields__[key].init: + if isinstance(value, dict) or isinstance(value, list): + field_type = cls.__dataclass_fields__[key].type + + # if `field_type` is a `BaseInferenceType`, parse it + if inspect.isclass(field_type) and issubclass(field_type, BaseInferenceType): + value = field_type.parse_obj(value) + + # otherwise, recursively parse nested dataclasses (if possible) + # `get_args` returns handle Union and Optional for us + else: + expected_types = get_args(field_type) + for expected_type in expected_types: + if getattr(expected_type, "_name", None) == "List": + expected_type = get_args(expected_type)[ + 0 + ] # assume same type for all items in the list + if inspect.isclass(expected_type) and issubclass(expected_type, BaseInferenceType): + value = expected_type.parse_obj(value) + break + init_values[key] = value + else: + other_values[key] = value + + # Make all missing fields default to None + # => ensure that dataclass 
initialization will never fail even if the server does not return all fields. + for key in cls.__dataclass_fields__: + if key not in init_values: + init_values[key] = None + + # Initialize dataclass with expected values + item = cls(**init_values) + + # Add remaining fields as dict attributes + item.update(other_values) + return item + + def __post_init__(self): + self.update(asdict(self)) + + def __setitem__(self, __key: Any, __value: Any) -> None: + # Hacky way to keep dataclass values in sync when dict is updated + super().__setitem__(__key, __value) + if __key in self.__dataclass_fields__ and getattr(self, __key, None) != __value: + self.__setattr__(__key, __value) + return + + def __setattr__(self, __name: str, __value: Any) -> None: + # Hacky way to keep dict values is sync when dataclass is updated + super().__setattr__(__name, __value) + if self.get(__name) != __value: + self[__name] = __value + return + + +def normalize_key(key: str) -> str: + # e.g "content-type" -> "content_type", "Accept" -> "accept" + return key.replace("-", "_").replace(" ", "_").lower() diff --git a/huggingface_hub/inference/_generated/types/chat_completion.py b/huggingface_hub/inference/_generated/types/chat_completion.py new file mode 100644 index 0000000000000000000000000000000000000000..fa6e373140cab6bc312817ddf7751ce2ef69629c --- /dev/null +++ b/huggingface_hub/inference/_generated/types/chat_completion.py @@ -0,0 +1,280 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Any, List, Literal, Optional, Union + +from .base import BaseInferenceType + + +@dataclass +class ChatCompletionInputURL(BaseInferenceType): + url: str + + +ChatCompletionInputMessageChunkType = Literal["text", "image_url"] + + +@dataclass +class ChatCompletionInputMessageChunk(BaseInferenceType): + type: "ChatCompletionInputMessageChunkType" + image_url: Optional[ChatCompletionInputURL] = None + text: Optional[str] = None + + +@dataclass +class ChatCompletionInputMessage(BaseInferenceType): + content: Union[List[ChatCompletionInputMessageChunk], str] + role: str + name: Optional[str] = None + + +ChatCompletionInputGrammarTypeType = Literal["json", "regex"] + + +@dataclass +class ChatCompletionInputGrammarType(BaseInferenceType): + type: "ChatCompletionInputGrammarTypeType" + value: Any + """A string that represents a [JSON Schema](https://json-schema.org/). + JSON Schema is a declarative language that allows to annotate JSON documents + with types and descriptions. + """ + + +@dataclass +class ChatCompletionInputFunctionName(BaseInferenceType): + name: str + + +@dataclass +class ChatCompletionInputToolTypeClass(BaseInferenceType): + function: Optional[ChatCompletionInputFunctionName] = None + + +@dataclass +class ChatCompletionInputFunctionDefinition(BaseInferenceType): + arguments: Any + name: str + description: Optional[str] = None + + +@dataclass +class ChatCompletionInputTool(BaseInferenceType): + function: ChatCompletionInputFunctionDefinition + type: str + + +@dataclass +class ChatCompletionInput(BaseInferenceType): + """Chat Completion Input. + Auto-generated from TGI specs. + For more details, check out + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. + """ + + messages: List[ChatCompletionInputMessage] + """A list of messages comprising the conversation so far.""" + frequency_penalty: Optional[float] = None + """Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing + frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + """ + logit_bias: Optional[List[float]] = None + """UNUSED + Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON + object that maps tokens + (specified by their token ID in the tokenizer) to an associated bias value from -100 to + 100. Mathematically, + the bias is added to the logits generated by the model prior to sampling. The exact + effect will vary per model, + but values between -1 and 1 should decrease or increase likelihood of selection; values + like -100 or 100 should + result in a ban or exclusive selection of the relevant token. + """ + logprobs: Optional[bool] = None + """Whether to return log probabilities of the output tokens or not. If true, returns the log + probabilities of each + output token returned in the content of message. + """ + max_tokens: Optional[int] = None + """The maximum number of tokens that can be generated in the chat completion.""" + model: Optional[str] = None + """[UNUSED] ID of the model to use. See the model endpoint compatibility table for details + on which models work with the Chat API. + """ + n: Optional[int] = None + """UNUSED + How many chat completion choices to generate for each input message. Note that you will + be charged based on the + number of generated tokens across all of the choices. Keep n as 1 to minimize costs. + """ + presence_penalty: Optional[float] = None + """Number between -2.0 and 2.0. 
Positive values penalize new tokens based on whether they + appear in the text so far, + increasing the model's likelihood to talk about new topics + """ + response_format: Optional[ChatCompletionInputGrammarType] = None + seed: Optional[int] = None + stop: Optional[List[str]] = None + """Up to 4 sequences where the API will stop generating further tokens.""" + stream: Optional[bool] = None + temperature: Optional[float] = None + """What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the + output more random, while + lower values like 0.2 will make it more focused and deterministic. + We generally recommend altering this or `top_p` but not both. + """ + tool_choice: Optional[Union[ChatCompletionInputToolTypeClass, str]] = None + tool_prompt: Optional[str] = None + """A prompt to be appended before the tools""" + tools: Optional[List[ChatCompletionInputTool]] = None + """A list of tools the model may call. Currently, only functions are supported as a tool. + Use this to provide a list of + functions the model may generate JSON inputs for. + """ + top_logprobs: Optional[int] = None + """An integer between 0 and 5 specifying the number of most likely tokens to return at each + token position, each with + an associated log probability. logprobs must be set to true if this parameter is used. + """ + top_p: Optional[float] = None + """An alternative to sampling with temperature, called nucleus sampling, where the model + considers the results of the + tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + probability mass are considered. 
+ """ + + +@dataclass +class ChatCompletionOutputTopLogprob(BaseInferenceType): + logprob: float + token: str + + +@dataclass +class ChatCompletionOutputLogprob(BaseInferenceType): + logprob: float + token: str + top_logprobs: List[ChatCompletionOutputTopLogprob] + + +@dataclass +class ChatCompletionOutputLogprobs(BaseInferenceType): + content: List[ChatCompletionOutputLogprob] + + +@dataclass +class ChatCompletionOutputFunctionDefinition(BaseInferenceType): + arguments: Any + name: str + description: Optional[str] = None + + +@dataclass +class ChatCompletionOutputToolCall(BaseInferenceType): + function: ChatCompletionOutputFunctionDefinition + id: str + type: str + + +@dataclass +class ChatCompletionOutputMessage(BaseInferenceType): + role: str + content: Optional[str] = None + tool_calls: Optional[List[ChatCompletionOutputToolCall]] = None + + +@dataclass +class ChatCompletionOutputComplete(BaseInferenceType): + finish_reason: str + index: int + message: ChatCompletionOutputMessage + logprobs: Optional[ChatCompletionOutputLogprobs] = None + + +@dataclass +class ChatCompletionOutputUsage(BaseInferenceType): + completion_tokens: int + prompt_tokens: int + total_tokens: int + + +@dataclass +class ChatCompletionOutput(BaseInferenceType): + """Chat Completion Output. + Auto-generated from TGI specs. + For more details, check out + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. 
+ """ + + choices: List[ChatCompletionOutputComplete] + created: int + id: str + model: str + system_fingerprint: str + usage: ChatCompletionOutputUsage + + +@dataclass +class ChatCompletionStreamOutputFunction(BaseInferenceType): + arguments: str + name: Optional[str] = None + + +@dataclass +class ChatCompletionStreamOutputDeltaToolCall(BaseInferenceType): + function: ChatCompletionStreamOutputFunction + id: str + index: int + type: str + + +@dataclass +class ChatCompletionStreamOutputDelta(BaseInferenceType): + role: str + content: Optional[str] = None + tool_calls: Optional[ChatCompletionStreamOutputDeltaToolCall] = None + + +@dataclass +class ChatCompletionStreamOutputTopLogprob(BaseInferenceType): + logprob: float + token: str + + +@dataclass +class ChatCompletionStreamOutputLogprob(BaseInferenceType): + logprob: float + token: str + top_logprobs: List[ChatCompletionStreamOutputTopLogprob] + + +@dataclass +class ChatCompletionStreamOutputLogprobs(BaseInferenceType): + content: List[ChatCompletionStreamOutputLogprob] + + +@dataclass +class ChatCompletionStreamOutputChoice(BaseInferenceType): + delta: ChatCompletionStreamOutputDelta + index: int + finish_reason: Optional[str] = None + logprobs: Optional[ChatCompletionStreamOutputLogprobs] = None + + +@dataclass +class ChatCompletionStreamOutput(BaseInferenceType): + """Chat Completion Stream Output. + Auto-generated from TGI specs. + For more details, check out + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. 
+ """ + + choices: List[ChatCompletionStreamOutputChoice] + created: int + id: str + model: str + system_fingerprint: str diff --git a/huggingface_hub/inference/_generated/types/depth_estimation.py b/huggingface_hub/inference/_generated/types/depth_estimation.py new file mode 100644 index 0000000000000000000000000000000000000000..fbaa5feeadff9721ba543cb77121b98c17e3ee8c --- /dev/null +++ b/huggingface_hub/inference/_generated/types/depth_estimation.py @@ -0,0 +1,29 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from .base import BaseInferenceType + + +@dataclass +class DepthEstimationInput(BaseInferenceType): + """Inputs for Depth Estimation inference""" + + inputs: Any + """The input image data""" + parameters: Optional[Dict[str, Any]] = None + """Additional inference parameters""" + + +@dataclass +class DepthEstimationOutput(BaseInferenceType): + """Outputs of inference for the Depth Estimation task""" + + depth: Any + """The predicted depth as an image""" + predicted_depth: Any + """The predicted depth as a tensor""" diff --git a/huggingface_hub/inference/_generated/types/document_question_answering.py b/huggingface_hub/inference/_generated/types/document_question_answering.py new file mode 100644 index 0000000000000000000000000000000000000000..c68be4bde00a98fbce46a2ef6a93bb549d4d920b --- /dev/null +++ b/huggingface_hub/inference/_generated/types/document_question_answering.py @@ -0,0 +1,85 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. 
+# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from dataclasses import dataclass +from typing import Any, List, Optional, Union + +from .base import BaseInferenceType + + +@dataclass +class DocumentQuestionAnsweringInputData(BaseInferenceType): + """One (document, question) pair to answer""" + + image: Any + """The image on which the question is asked""" + question: str + """A question to ask of the document""" + + +@dataclass +class DocumentQuestionAnsweringParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Document Question Answering + """ + + doc_stride: Optional[int] = None + """If the words in the document are too long to fit with the question for the model, it will + be split in several chunks with some overlap. This argument controls the size of that + overlap. + """ + handle_impossible_answer: Optional[bool] = None + """Whether to accept impossible as an answer""" + lang: Optional[str] = None + """Language to use while running OCR. Defaults to english.""" + max_answer_len: Optional[int] = None + """The maximum length of predicted answers (e.g., only answers with a shorter length are + considered). + """ + max_question_len: Optional[int] = None + """The maximum length of the question after tokenization. It will be truncated if needed.""" + max_seq_len: Optional[int] = None + """The maximum length of the total sentence (context + question) in tokens of each chunk + passed to the model. The context will be split in several chunks (using doc_stride as + overlap) if needed. + """ + top_k: Optional[int] = None + """The number of answers to return (will be chosen by order of likelihood). Can return less + than top_k answers if there are not enough options available within the context. 
+ """ + word_boxes: Optional[List[Union[List[float], str]]] = None + """A list of words and bounding boxes (normalized 0->1000). If provided, the inference will + skip the OCR step and use the provided bounding boxes instead. + """ + + +@dataclass +class DocumentQuestionAnsweringInput(BaseInferenceType): + """Inputs for Document Question Answering inference""" + + inputs: DocumentQuestionAnsweringInputData + """One (document, question) pair to answer""" + parameters: Optional[DocumentQuestionAnsweringParameters] = None + """Additional inference parameters""" + + +@dataclass +class DocumentQuestionAnsweringOutputElement(BaseInferenceType): + """Outputs of inference for the Document Question Answering task""" + + answer: str + """The answer to the question.""" + end: int + """The end word index of the answer (in the OCRβd version of the input or provided word + boxes). + """ + score: float + """The probability associated to the answer.""" + start: int + """The start word index of the answer (in the OCRβd version of the input or provided word + boxes). + """ + words: List[int] + """The index of each word/box pair that is in the answer""" diff --git a/huggingface_hub/inference/_generated/types/feature_extraction.py b/huggingface_hub/inference/_generated/types/feature_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..e706269de1187056f98aad582498976597019f18 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/feature_extraction.py @@ -0,0 +1,37 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Literal, Optional + +from .base import BaseInferenceType + + +FeatureExtractionInputTruncationDirection = Literal["Left", "Right"] + + +@dataclass +class FeatureExtractionInput(BaseInferenceType): + """Feature Extraction Input. + Auto-generated from TEI specs. + For more details, check out + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tei-import.ts. + """ + + inputs: str + """The text to embed.""" + normalize: Optional[bool] = None + prompt_name: Optional[str] = None + """The name of the prompt that should be used by for encoding. If not set, no prompt + will be applied. + Must be a key in the `Sentence Transformers` configuration `prompts` dictionary. + For example if ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", + ...}, + then the sentence "What is the capital of France?" will be encoded as + "query: What is the capital of France?" because the prompt text will be prepended before + any text to encode. + """ + truncate: Optional[bool] = None + truncation_direction: Optional["FeatureExtractionInputTruncationDirection"] = None diff --git a/huggingface_hub/inference/_generated/types/fill_mask.py b/huggingface_hub/inference/_generated/types/fill_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..e1fddf96fbb7c76c8ffee0c170c6554c8b4e2bf8 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/fill_mask.py @@ -0,0 +1,50 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Any, List, Optional + +from .base import BaseInferenceType + + +@dataclass +class FillMaskParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Fill Mask + """ + + targets: Optional[List[str]] = None + """When passed, the model will limit the scores to the passed targets instead of looking up + in the whole vocabulary. If the provided targets are not in the model vocab, they will be + tokenized and the first resulting token will be used (with a warning, and that might be + slower). + """ + top_k: Optional[int] = None + """When passed, overrides the number of predictions to return.""" + + +@dataclass +class FillMaskInput(BaseInferenceType): + """Inputs for Fill Mask inference""" + + inputs: str + """The text with masked tokens""" + parameters: Optional[FillMaskParameters] = None + """Additional inference parameters""" + + +@dataclass +class FillMaskOutputElement(BaseInferenceType): + """Outputs of inference for the Fill Mask task""" + + score: float + """The corresponding probability""" + sequence: str + """The corresponding input with the mask token prediction.""" + token: int + """The predicted token id (to replace the masked one).""" + token_str: Any + fill_mask_output_token_str: Optional[str] = None + """The predicted token (to replace the masked one).""" diff --git a/huggingface_hub/inference/_generated/types/image_classification.py b/huggingface_hub/inference/_generated/types/image_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..fd52db005a0be62e7f063c0a16569a1fc2b273da --- /dev/null +++ b/huggingface_hub/inference/_generated/types/image_classification.py @@ -0,0 +1,43 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. 
+# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from dataclasses import dataclass +from typing import Any, Literal, Optional + +from .base import BaseInferenceType + + +ClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] + + +@dataclass +class ImageClassificationParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Image Classification + """ + + function_to_apply: Optional["ClassificationOutputTransform"] = None + top_k: Optional[int] = None + """When specified, limits the output to the top K most probable classes.""" + + +@dataclass +class ImageClassificationInput(BaseInferenceType): + """Inputs for Image Classification inference""" + + inputs: Any + """The input image data""" + parameters: Optional[ImageClassificationParameters] = None + """Additional inference parameters""" + + +@dataclass +class ImageClassificationOutputElement(BaseInferenceType): + """Outputs of inference for the Image Classification task""" + + label: str + """The predicted class label.""" + score: float + """The corresponding probability.""" diff --git a/huggingface_hub/inference/_generated/types/image_segmentation.py b/huggingface_hub/inference/_generated/types/image_segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..67dd7c28b3cddd21d495ada70b7689a098accfd6 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/image_segmentation.py @@ -0,0 +1,52 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Any, Literal, Optional + +from .base import BaseInferenceType + + +ImageSegmentationSubtask = Literal["instance", "panoptic", "semantic"] + + +@dataclass +class ImageSegmentationParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Image Segmentation + """ + + mask_threshold: Optional[float] = None + """Threshold to use when turning the predicted masks into binary values.""" + overlap_mask_area_threshold: Optional[float] = None + """Mask overlap threshold to eliminate small, disconnected segments.""" + subtask: Optional["ImageSegmentationSubtask"] = None + """Segmentation task to be performed, depending on model capabilities.""" + threshold: Optional[float] = None + """Probability threshold to filter out predicted masks.""" + + +@dataclass +class ImageSegmentationInput(BaseInferenceType): + """Inputs for Image Segmentation inference""" + + inputs: Any + """The input image data""" + parameters: Optional[ImageSegmentationParameters] = None + """Additional inference parameters""" + + +@dataclass +class ImageSegmentationOutputElement(BaseInferenceType): + """Outputs of inference for the Image Segmentation task + A predicted mask / segment + """ + + label: str + """The label of the predicted segment""" + mask: Any + """The corresponding mask as a black-and-white image""" + score: Optional[float] = None + """The score or confidence degreee the model has""" diff --git a/huggingface_hub/inference/_generated/types/image_to_image.py b/huggingface_hub/inference/_generated/types/image_to_image.py new file mode 100644 index 0000000000000000000000000000000000000000..8c208ede6f7f2fb73b5dd059fe71bc8d2c4ca140 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/image_to_image.py @@ -0,0 +1,55 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. 
+# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from dataclasses import dataclass +from typing import Any, List, Optional + +from .base import BaseInferenceType + + +@dataclass +class ImageToImageTargetSize(BaseInferenceType): + """The size in pixel of the output image""" + + height: int + width: int + + +@dataclass +class ImageToImageParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Image To Image + """ + + guidance_scale: Optional[float] = None + """For diffusion models. A higher guidance scale value encourages the model to generate + images closely linked to the text prompt at the expense of lower image quality. + """ + negative_prompt: Optional[List[str]] = None + """One or several prompt to guide what NOT to include in image generation.""" + num_inference_steps: Optional[int] = None + """For diffusion models. The number of denoising steps. More denoising steps usually lead to + a higher quality image at the expense of slower inference. 
+ """ + target_size: Optional[ImageToImageTargetSize] = None + """The size in pixel of the output image""" + + +@dataclass +class ImageToImageInput(BaseInferenceType): + """Inputs for Image To Image inference""" + + inputs: Any + """The input image data""" + parameters: Optional[ImageToImageParameters] = None + """Additional inference parameters""" + + +@dataclass +class ImageToImageOutput(BaseInferenceType): + """Outputs of inference for the Image To Image task""" + + image: Any + """The output image""" diff --git a/huggingface_hub/inference/_generated/types/image_to_text.py b/huggingface_hub/inference/_generated/types/image_to_text.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebb9a9bc667bdb0d2afd7bb8e482fc18f6634d7 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/image_to_text.py @@ -0,0 +1,105 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from dataclasses import dataclass +from typing import Any, Literal, Optional, Union + +from .base import BaseInferenceType + + +EarlyStoppingEnum = Literal["never"] + + +@dataclass +class ImageToTextGenerationParameters(BaseInferenceType): + """Parametrization of the text generation process + Ad-hoc parametrization of the text generation process + """ + + do_sample: Optional[bool] = None + """Whether to use sampling instead of greedy decoding when generating new tokens.""" + early_stopping: Optional[Union[bool, "EarlyStoppingEnum"]] = None + """Controls the stopping condition for beam-based methods.""" + epsilon_cutoff: Optional[float] = None + """If set to float strictly between 0 and 1, only tokens with a conditional probability + greater than epsilon_cutoff will be sampled. 
In the paper, suggested values range from + 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language + Model Desmoothing](https://hf.co/papers/2210.15191) for more details. + """ + eta_cutoff: Optional[float] = None + """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to + float strictly between 0 and 1, a token is only considered if it is greater than either + eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter + term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In + the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. + See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) + for more details. + """ + max_length: Optional[int] = None + """The maximum length (in tokens) of the generated text, including the input.""" + max_new_tokens: Optional[int] = None + """The maximum number of tokens to generate. Takes precedence over maxLength.""" + min_length: Optional[int] = None + """The minimum length (in tokens) of the generated text, including the input.""" + min_new_tokens: Optional[int] = None + """The minimum number of tokens to generate. Takes precedence over maxLength.""" + num_beam_groups: Optional[int] = None + """Number of groups to divide num_beams into in order to ensure diversity among different + groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. + """ + num_beams: Optional[int] = None + """Number of beams to use for beam search.""" + penalty_alpha: Optional[float] = None + """The value balances the model confidence and the degeneration penalty in contrastive + search decoding. 
+ """ + temperature: Optional[float] = None + """The value used to modulate the next token probabilities.""" + top_k: Optional[int] = None + """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" + top_p: Optional[float] = None + """If set to float < 1, only the smallest set of most probable tokens with probabilities + that add up to top_p or higher are kept for generation. + """ + typical_p: Optional[float] = None + """Local typicality measures how similar the conditional probability of predicting a target + token next is to the expected conditional probability of predicting a random token next, + given the partial text already generated. If set to float < 1, the smallest set of the + most locally typical tokens with probabilities that add up to typical_p or higher are + kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details. + """ + use_cache: Optional[bool] = None + """Whether the model should use the past last key/values attentions to speed up decoding""" + + +@dataclass +class ImageToTextParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Image To Text + """ + + generate: Optional[ImageToTextGenerationParameters] = None + """Parametrization of the text generation process""" + max_new_tokens: Optional[int] = None + """The amount of maximum tokens to generate.""" + + +@dataclass +class ImageToTextInput(BaseInferenceType): + """Inputs for Image To Text inference""" + + inputs: Any + """The input image data""" + parameters: Optional[ImageToTextParameters] = None + """Additional inference parameters""" + + +@dataclass +class ImageToTextOutput(BaseInferenceType): + """Outputs of inference for the Image To Text task""" + + generated_text: Any + image_to_text_output_generated_text: Optional[str] = None + """The generated text.""" diff --git a/huggingface_hub/inference/_generated/types/object_detection.py 
b/huggingface_hub/inference/_generated/types/object_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..42b03a841b793fd4cb301bf51695bd35054a6af2 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/object_detection.py @@ -0,0 +1,55 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from dataclasses import dataclass +from typing import Any, Optional + +from .base import BaseInferenceType + + +@dataclass +class ObjectDetectionParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Object Detection + """ + + threshold: Optional[float] = None + """The probability necessary to make a prediction.""" + + +@dataclass +class ObjectDetectionInput(BaseInferenceType): + """Inputs for Object Detection inference""" + + inputs: Any + """The input image data""" + parameters: Optional[ObjectDetectionParameters] = None + """Additional inference parameters""" + + +@dataclass +class ObjectDetectionBoundingBox(BaseInferenceType): + """The predicted bounding box. Coordinates are relative to the top left corner of the input + image. + """ + + xmax: int + xmin: int + ymax: int + ymin: int + + +@dataclass +class ObjectDetectionOutputElement(BaseInferenceType): + """Outputs of inference for the Object Detection task""" + + box: ObjectDetectionBoundingBox + """The predicted bounding box. Coordinates are relative to the top left corner of the input + image. 
+ """ + label: str + """The predicted label for the bounding box""" + score: float + """The associated score / probability""" diff --git a/huggingface_hub/inference/_generated/types/question_answering.py b/huggingface_hub/inference/_generated/types/question_answering.py new file mode 100644 index 0000000000000000000000000000000000000000..3810fc594af5cf0712cb0cb0db077383220b175a --- /dev/null +++ b/huggingface_hub/inference/_generated/types/question_answering.py @@ -0,0 +1,77 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from dataclasses import dataclass +from typing import Optional + +from .base import BaseInferenceType + + +@dataclass +class QuestionAnsweringInputData(BaseInferenceType): + """One (context, question) pair to answer""" + + context: str + """The context to be used for answering the question""" + question: str + """The question to be answered""" + + +@dataclass +class QuestionAnsweringParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Question Answering + """ + + align_to_words: Optional[bool] = None + """Attempts to align the answer to real words. Improves quality on space separated + languages. Might hurt on non-space-separated languages (like Japanese or Chinese) + """ + doc_stride: Optional[int] = None + """If the context is too long to fit with the question for the model, it will be split in + several chunks with some overlap. This argument controls the size of that overlap. + """ + handle_impossible_answer: Optional[bool] = None + """Whether to accept impossible as an answer.""" + max_answer_len: Optional[int] = None + """The maximum length of predicted answers (e.g., only answers with a shorter length are + considered). 
+ """ + max_question_len: Optional[int] = None + """The maximum length of the question after tokenization. It will be truncated if needed.""" + max_seq_len: Optional[int] = None + """The maximum length of the total sentence (context + question) in tokens of each chunk + passed to the model. The context will be split in several chunks (using docStride as + overlap) if needed. + """ + top_k: Optional[int] = None + """The number of answers to return (will be chosen by order of likelihood). Note that we + return less than topk answers if there are not enough options available within the + context. + """ + + +@dataclass +class QuestionAnsweringInput(BaseInferenceType): + """Inputs for Question Answering inference""" + + inputs: QuestionAnsweringInputData + """One (context, question) pair to answer""" + parameters: Optional[QuestionAnsweringParameters] = None + """Additional inference parameters""" + + +@dataclass +class QuestionAnsweringOutputElement(BaseInferenceType): + """Outputs of inference for the Question Answering task""" + + answer: str + """The answer to the question.""" + end: int + """The character position in the input where the answer ends.""" + score: float + """The probability associated to the answer.""" + start: int + """The character position in the input where the answer begins.""" diff --git a/huggingface_hub/inference/_generated/types/sentence_similarity.py b/huggingface_hub/inference/_generated/types/sentence_similarity.py new file mode 100644 index 0000000000000000000000000000000000000000..944bfccbf76e8c322dbf95a286746c6e1e25a55b --- /dev/null +++ b/huggingface_hub/inference/_generated/types/sentence_similarity.py @@ -0,0 +1,28 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from .base import BaseInferenceType + + +@dataclass +class SentenceSimilarityInputData(BaseInferenceType): + sentences: List[str] + """A list of strings which will be compared against the source_sentence.""" + source_sentence: str + """The string that you wish to compare the other strings with. This can be a phrase, + sentence, or longer passage, depending on the model being used. + """ + + +@dataclass +class SentenceSimilarityInput(BaseInferenceType): + """Inputs for Sentence similarity inference""" + + inputs: SentenceSimilarityInputData + parameters: Optional[Dict[str, Any]] = None + """Additional inference parameters""" diff --git a/huggingface_hub/inference/_generated/types/summarization.py b/huggingface_hub/inference/_generated/types/summarization.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a00e53264bd9a53a24d2ee7b12f428c068a117 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/summarization.py @@ -0,0 +1,46 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Any, Dict, Literal, Optional + +from .base import BaseInferenceType + + +SummarizationGenerationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"] + + +@dataclass +class SummarizationGenerationParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Text2text Generation + """ + + clean_up_tokenization_spaces: Optional[bool] = None + """Whether to clean up the potential extra spaces in the text output.""" + generate_parameters: Optional[Dict[str, Any]] = None + """Additional parametrization of the text generation algorithm""" + truncation: Optional["SummarizationGenerationTruncationStrategy"] = None + """The truncation strategy to use""" + + +@dataclass +class SummarizationInput(BaseInferenceType): + """Inputs for Summarization inference + Inputs for Text2text Generation inference + """ + + inputs: str + """The input text data""" + parameters: Optional[SummarizationGenerationParameters] = None + """Additional inference parameters""" + + +@dataclass +class SummarizationOutput(BaseInferenceType): + """Outputs of inference for the Summarization task""" + + summary_text: str + """The summarized text.""" diff --git a/huggingface_hub/inference/_generated/types/table_question_answering.py b/huggingface_hub/inference/_generated/types/table_question_answering.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb9fff641fd4ed2d8e797e59ae7b5f21f94c838 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/table_question_answering.py @@ -0,0 +1,45 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from .base import BaseInferenceType + + +@dataclass +class TableQuestionAnsweringInputData(BaseInferenceType): + """One (table, question) pair to answer""" + + question: str + """The question to be answered about the table""" + table: Dict[str, List[str]] + """The table to serve as context for the questions""" + + +@dataclass +class TableQuestionAnsweringInput(BaseInferenceType): + """Inputs for Table Question Answering inference""" + + inputs: TableQuestionAnsweringInputData + """One (table, question) pair to answer""" + parameters: Optional[Dict[str, Any]] = None + """Additional inference parameters""" + + +@dataclass +class TableQuestionAnsweringOutputElement(BaseInferenceType): + """Outputs of inference for the Table Question Answering task""" + + answer: str + """The answer of the question given the table. If there is an aggregator, the answer will be + preceded by `AGGREGATOR >`. + """ + cells: List[str] + """List of strings made up of the answer cell values.""" + coordinates: List[List[int]] + """Coordinates of the cells of the answers.""" + aggregator: Optional[str] = None + """If the model has an aggregator, this returns the aggregator.""" diff --git a/huggingface_hub/inference/_generated/types/text2text_generation.py b/huggingface_hub/inference/_generated/types/text2text_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..955494c5ef6b86e12b3927dfd90e44a5db25c2e6 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/text2text_generation.py @@ -0,0 +1,45 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Any, Dict, Literal, Optional + +from .base import BaseInferenceType + + +Text2TextGenerationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"] + + +@dataclass +class Text2TextGenerationParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Text2text Generation + """ + + clean_up_tokenization_spaces: Optional[bool] = None + """Whether to clean up the potential extra spaces in the text output.""" + generate_parameters: Optional[Dict[str, Any]] = None + """Additional parametrization of the text generation algorithm""" + truncation: Optional["Text2TextGenerationTruncationStrategy"] = None + """The truncation strategy to use""" + + +@dataclass +class Text2TextGenerationInput(BaseInferenceType): + """Inputs for Text2text Generation inference""" + + inputs: str + """The input text data""" + parameters: Optional[Text2TextGenerationParameters] = None + """Additional inference parameters""" + + +@dataclass +class Text2TextGenerationOutput(BaseInferenceType): + """Outputs of inference for the Text2text Generation task""" + + generated_text: Any + text2_text_generation_output_generated_text: Optional[str] = None + """The generated text.""" diff --git a/huggingface_hub/inference/_generated/types/text_classification.py b/huggingface_hub/inference/_generated/types/text_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..bf61a4eebcf367b4ab15e8970bfac8e1d8f8458d --- /dev/null +++ b/huggingface_hub/inference/_generated/types/text_classification.py @@ -0,0 +1,43 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Literal, Optional + +from .base import BaseInferenceType + + +ClassificationOutputTransform = Literal["sigmoid", "softmax", "none"] + + +@dataclass +class TextClassificationParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Text Classification + """ + + function_to_apply: Optional["ClassificationOutputTransform"] = None + top_k: Optional[int] = None + """When specified, limits the output to the top K most probable classes.""" + + +@dataclass +class TextClassificationInput(BaseInferenceType): + """Inputs for Text Classification inference""" + + inputs: str + """The text to classify""" + parameters: Optional[TextClassificationParameters] = None + """Additional inference parameters""" + + +@dataclass +class TextClassificationOutputElement(BaseInferenceType): + """Outputs of inference for the Text Classification task""" + + label: str + """The predicted class label.""" + score: float + """The corresponding probability.""" diff --git a/huggingface_hub/inference/_generated/types/text_generation.py b/huggingface_hub/inference/_generated/types/text_generation.py new file mode 100644 index 0000000000000000000000000000000000000000..27c70c7e2b201fea407abe166feb7c3cc4a28fff --- /dev/null +++ b/huggingface_hub/inference/_generated/types/text_generation.py @@ -0,0 +1,168 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Any, List, Literal, Optional + +from .base import BaseInferenceType + + +TypeEnum = Literal["json", "regex"] + + +@dataclass +class TextGenerationInputGrammarType(BaseInferenceType): + type: "TypeEnum" + value: Any + """A string that represents a [JSON Schema](https://json-schema.org/). + JSON Schema is a declarative language that allows to annotate JSON documents + with types and descriptions. + """ + + +@dataclass +class TextGenerationInputGenerateParameters(BaseInferenceType): + adapter_id: Optional[str] = None + """Lora adapter id""" + best_of: Optional[int] = None + """Generate best_of sequences and return the one if the highest token logprobs.""" + decoder_input_details: Optional[bool] = None + """Whether to return decoder input token logprobs and ids.""" + details: Optional[bool] = None + """Whether to return generation details.""" + do_sample: Optional[bool] = None + """Activate logits sampling.""" + frequency_penalty: Optional[float] = None + """The parameter for frequency penalty. 1.0 means no penalty + Penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat the same line verbatim. + """ + grammar: Optional[TextGenerationInputGrammarType] = None + max_new_tokens: Optional[int] = None + """Maximum number of tokens to generate.""" + repetition_penalty: Optional[float] = None + """The parameter for repetition penalty. 1.0 means no penalty. + See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. 
+ """ + return_full_text: Optional[bool] = None + """Whether to prepend the prompt to the generated text""" + seed: Optional[int] = None + """Random sampling seed.""" + stop: Optional[List[str]] = None + """Stop generating tokens if a member of `stop` is generated.""" + temperature: Optional[float] = None + """The value used to module the logits distribution.""" + top_k: Optional[int] = None + """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" + top_n_tokens: Optional[int] = None + """The number of highest probability vocabulary tokens to keep for top-n-filtering.""" + top_p: Optional[float] = None + """Top-p value for nucleus sampling.""" + truncate: Optional[int] = None + """Truncate inputs tokens to the given size.""" + typical_p: Optional[float] = None + """Typical Decoding mass + See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) + for more information. + """ + watermark: Optional[bool] = None + """Watermarking with [A Watermark for Large Language + Models](https://arxiv.org/abs/2301.10226). + """ + + +@dataclass +class TextGenerationInput(BaseInferenceType): + """Text Generation Input. + Auto-generated from TGI specs. + For more details, check out + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. 
+ """ + + inputs: str + parameters: Optional[TextGenerationInputGenerateParameters] = None + stream: Optional[bool] = None + + +TextGenerationOutputFinishReason = Literal["length", "eos_token", "stop_sequence"] + + +@dataclass +class TextGenerationOutputPrefillToken(BaseInferenceType): + id: int + logprob: float + text: str + + +@dataclass +class TextGenerationOutputToken(BaseInferenceType): + id: int + logprob: float + special: bool + text: str + + +@dataclass +class TextGenerationOutputBestOfSequence(BaseInferenceType): + finish_reason: "TextGenerationOutputFinishReason" + generated_text: str + generated_tokens: int + prefill: List[TextGenerationOutputPrefillToken] + tokens: List[TextGenerationOutputToken] + seed: Optional[int] = None + top_tokens: Optional[List[List[TextGenerationOutputToken]]] = None + + +@dataclass +class TextGenerationOutputDetails(BaseInferenceType): + finish_reason: "TextGenerationOutputFinishReason" + generated_tokens: int + prefill: List[TextGenerationOutputPrefillToken] + tokens: List[TextGenerationOutputToken] + best_of_sequences: Optional[List[TextGenerationOutputBestOfSequence]] = None + seed: Optional[int] = None + top_tokens: Optional[List[List[TextGenerationOutputToken]]] = None + + +@dataclass +class TextGenerationOutput(BaseInferenceType): + """Text Generation Output. + Auto-generated from TGI specs. + For more details, check out + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. 
+ """ + + generated_text: str + details: Optional[TextGenerationOutputDetails] = None + + +@dataclass +class TextGenerationStreamOutputStreamDetails(BaseInferenceType): + finish_reason: "TextGenerationOutputFinishReason" + generated_tokens: int + seed: Optional[int] = None + + +@dataclass +class TextGenerationStreamOutputToken(BaseInferenceType): + id: int + logprob: float + special: bool + text: str + + +@dataclass +class TextGenerationStreamOutput(BaseInferenceType): + """Text Generation Stream Output. + Auto-generated from TGI specs. + For more details, check out + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-tgi-import.ts. + """ + + index: int + token: TextGenerationStreamOutputToken + details: Optional[TextGenerationStreamOutputStreamDetails] = None + generated_text: Optional[str] = None + top_tokens: Optional[List[TextGenerationStreamOutputToken]] = None diff --git a/huggingface_hub/inference/_generated/types/text_to_audio.py b/huggingface_hub/inference/_generated/types/text_to_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..dd8369de4b26cf8ef38cf8cfbafdc1a8bb12d552 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/text_to_audio.py @@ -0,0 +1,105 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Any, Literal, Optional, Union + +from .base import BaseInferenceType + + +EarlyStoppingEnum = Literal["never"] + + +@dataclass +class TextToAudioGenerationParameters(BaseInferenceType): + """Parametrization of the text generation process + Ad-hoc parametrization of the text generation process + """ + + do_sample: Optional[bool] = None + """Whether to use sampling instead of greedy decoding when generating new tokens.""" + early_stopping: Optional[Union[bool, "EarlyStoppingEnum"]] = None + """Controls the stopping condition for beam-based methods.""" + epsilon_cutoff: Optional[float] = None + """If set to float strictly between 0 and 1, only tokens with a conditional probability + greater than epsilon_cutoff will be sampled. In the paper, suggested values range from + 3e-4 to 9e-4, depending on the size of the model. See [Truncation Sampling as Language + Model Desmoothing](https://hf.co/papers/2210.15191) for more details. + """ + eta_cutoff: Optional[float] = None + """Eta sampling is a hybrid of locally typical sampling and epsilon sampling. If set to + float strictly between 0 and 1, a token is only considered if it is greater than either + eta_cutoff or sqrt(eta_cutoff) * exp(-entropy(softmax(next_token_logits))). The latter + term is intuitively the expected next token probability, scaled by sqrt(eta_cutoff). In + the paper, suggested values range from 3e-4 to 2e-3, depending on the size of the model. + See [Truncation Sampling as Language Model Desmoothing](https://hf.co/papers/2210.15191) + for more details. + """ + max_length: Optional[int] = None + """The maximum length (in tokens) of the generated text, including the input.""" + max_new_tokens: Optional[int] = None + """The maximum number of tokens to generate. 
Takes precedence over maxLength.""" + min_length: Optional[int] = None + """The minimum length (in tokens) of the generated text, including the input.""" + min_new_tokens: Optional[int] = None + """The minimum number of tokens to generate. Takes precedence over maxLength.""" + num_beam_groups: Optional[int] = None + """Number of groups to divide num_beams into in order to ensure diversity among different + groups of beams. See [this paper](https://hf.co/papers/1610.02424) for more details. + """ + num_beams: Optional[int] = None + """Number of beams to use for beam search.""" + penalty_alpha: Optional[float] = None + """The value balances the model confidence and the degeneration penalty in contrastive + search decoding. + """ + temperature: Optional[float] = None + """The value used to modulate the next token probabilities.""" + top_k: Optional[int] = None + """The number of highest probability vocabulary tokens to keep for top-k-filtering.""" + top_p: Optional[float] = None + """If set to float < 1, only the smallest set of most probable tokens with probabilities + that add up to top_p or higher are kept for generation. + """ + typical_p: Optional[float] = None + """Local typicality measures how similar the conditional probability of predicting a target + token next is to the expected conditional probability of predicting a random token next, + given the partial text already generated. If set to float < 1, the smallest set of the + most locally typical tokens with probabilities that add up to typical_p or higher are + kept for generation. See [this paper](https://hf.co/papers/2202.00666) for more details. 
+ """ + use_cache: Optional[bool] = None + """Whether the model should use the past last key/values attentions to speed up decoding""" + + +@dataclass +class TextToAudioParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Text To Audio + """ + + generate: Optional[TextToAudioGenerationParameters] = None + """Parametrization of the text generation process""" + + +@dataclass +class TextToAudioInput(BaseInferenceType): + """Inputs for Text To Audio inference""" + + inputs: str + """The input text data""" + parameters: Optional[TextToAudioParameters] = None + """Additional inference parameters""" + + +@dataclass +class TextToAudioOutput(BaseInferenceType): + """Outputs of inference for the Text To Audio task""" + + audio: Any + """The generated audio waveform.""" + sampling_rate: Any + text_to_audio_output_sampling_rate: Optional[float] = None + """The sampling rate of the generated audio waveform.""" diff --git a/huggingface_hub/inference/_generated/types/text_to_image.py b/huggingface_hub/inference/_generated/types/text_to_image.py new file mode 100644 index 0000000000000000000000000000000000000000..40e53ab016d3a6f2098d26eadab9cf51805c31b1 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/text_to_image.py @@ -0,0 +1,57 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Any, List, Optional + +from .base import BaseInferenceType + + +@dataclass +class TextToImageTargetSize(BaseInferenceType): + """The size in pixel of the output image""" + + height: int + width: int + + +@dataclass +class TextToImageParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Text To Image + """ + + guidance_scale: Optional[float] = None + """For diffusion models. A higher guidance scale value encourages the model to generate + images closely linked to the text prompt at the expense of lower image quality. + """ + negative_prompt: Optional[List[str]] = None + """One or several prompt to guide what NOT to include in image generation.""" + num_inference_steps: Optional[int] = None + """For diffusion models. The number of denoising steps. More denoising steps usually lead to + a higher quality image at the expense of slower inference. + """ + scheduler: Optional[str] = None + """For diffusion models. 
Override the scheduler with a compatible one""" + target_size: Optional[TextToImageTargetSize] = None + """The size in pixel of the output image""" + + +@dataclass +class TextToImageInput(BaseInferenceType): + """Inputs for Text To Image inference""" + + inputs: str + """The input text data (sometimes called "prompt\"""" + parameters: Optional[TextToImageParameters] = None + """Additional inference parameters""" + + +@dataclass +class TextToImageOutput(BaseInferenceType): + """Outputs of inference for the Text To Image task""" + + image: Any + """The generated image""" diff --git a/huggingface_hub/inference/_generated/types/token_classification.py b/huggingface_hub/inference/_generated/types/token_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..2d60ea27eedbfe28096435c84e4002c0d9a64bc6 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/token_classification.py @@ -0,0 +1,53 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. +# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. 
+from dataclasses import dataclass +from typing import Any, List, Literal, Optional + +from .base import BaseInferenceType + + +TokenClassificationAggregationStrategy = Literal["none", "simple", "first", "average", "max"] + + +@dataclass +class TokenClassificationParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Token Classification + """ + + aggregation_strategy: Optional["TokenClassificationAggregationStrategy"] = None + """The strategy used to fuse tokens based on model predictions""" + ignore_labels: Optional[List[str]] = None + """A list of labels to ignore""" + stride: Optional[int] = None + """The number of overlapping tokens between chunks when splitting the input text.""" + + +@dataclass +class TokenClassificationInput(BaseInferenceType): + """Inputs for Token Classification inference""" + + inputs: str + """The input text data""" + parameters: Optional[TokenClassificationParameters] = None + """Additional inference parameters""" + + +@dataclass +class TokenClassificationOutputElement(BaseInferenceType): + """Outputs of inference for the Token Classification task""" + + label: Any + score: float + """The associated score / probability""" + end: Optional[int] = None + """The character position in the input where this group ends.""" + entity_group: Optional[str] = None + """The predicted label for that group of tokens""" + start: Optional[int] = None + """The character position in the input where this group begins.""" + word: Optional[str] = None + """The corresponding text""" diff --git a/huggingface_hub/inference/_generated/types/translation.py b/huggingface_hub/inference/_generated/types/translation.py new file mode 100644 index 0000000000000000000000000000000000000000..e06ad2b72d35dcf814b110112cd882cb4b4cc616 --- /dev/null +++ b/huggingface_hub/inference/_generated/types/translation.py @@ -0,0 +1,46 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. 
+# +# See: +# - script: https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/scripts/inference-codegen.ts +# - specs: https://github.com/huggingface/huggingface.js/tree/main/packages/tasks/src/tasks. +from dataclasses import dataclass +from typing import Any, Dict, Literal, Optional + +from .base import BaseInferenceType + + +TranslationGenerationTruncationStrategy = Literal["do_not_truncate", "longest_first", "only_first", "only_second"] + + +@dataclass +class TranslationGenerationParameters(BaseInferenceType): + """Additional inference parameters + Additional inference parameters for Text2text Generation + """ + + clean_up_tokenization_spaces: Optional[bool] = None + """Whether to clean up the potential extra spaces in the text output.""" + generate_parameters: Optional[Dict[str, Any]] = None + """Additional parametrization of the text generation algorithm""" + truncation: Optional["TranslationGenerationTruncationStrategy"] = None + """The truncation strategy to use""" + + +@dataclass +class TranslationInput(BaseInferenceType): + """Inputs for Translation inference + Inputs for Text2text Generation inference + """ + + inputs: str + """The input text data""" + parameters: Optional[TranslationGenerationParameters] = None + """Additional inference parameters""" + + +@dataclass +class TranslationOutput(BaseInferenceType): + """Outputs of inference for the Translation task""" + + translation_text: str + """The translated text.""" diff --git a/huggingface_hub/inference/_generated/types/video_classification.py b/huggingface_hub/inference/_generated/types/video_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5a9d55a81fab6fc71e1226e7776a7a68ee688f --- /dev/null +++ b/huggingface_hub/inference/_generated/types/video_classification.py @@ -0,0 +1,47 @@ +# Inference code generated from the JSON schema spec in @huggingface/tasks. 
# Closed set of values accepted by `VideoClassificationParameters.function_to_apply`.
ClassificationOutputTransform = Literal["sigmoid", "softmax", "none"]


@dataclass
class VideoClassificationParameters(BaseInferenceType):
    """Additional inference parameters
    Additional inference parameters for Video Classification
    """

    # Every parameter is optional and defaults to None.
    frame_sampling_rate: Optional[int] = None
    """The sampling rate used to select frames from the video."""
    # `function_to_apply` has no description in the generated spec; judging by the
    # alias name it presumably selects the transform applied to the model output
    # before scores are returned — TODO confirm against the schema.
    function_to_apply: Optional["ClassificationOutputTransform"] = None
    num_frames: Optional[int] = None
    """The number of sampled frames to consider for classification."""
    top_k: Optional[int] = None
    """When specified, limits the output to the top K most probable classes."""


@dataclass
class VideoClassificationInput(BaseInferenceType):
    """Inputs for Video Classification inference"""

    # `inputs` is typed `Any`: the accepted video representation is not fixed here.
    inputs: Any
    """The input video data"""
    parameters: Optional[VideoClassificationParameters] = None
    """Additional inference parameters"""


@dataclass
class VideoClassificationOutputElement(BaseInferenceType):
    """Outputs of inference for the Video Classification task"""

    # Both fields are required.
    label: str
    """The predicted class label."""
    score: float
    """The corresponding probability."""
@dataclass
class VisualQuestionAnsweringInputData(BaseInferenceType):
    """One (image, question) pair to answer"""

    # Both fields are required and typed `Any`: the accepted representations
    # are not fixed by this type.
    image: Any
    """The image."""
    question: Any
    """The question to answer based on the image."""


@dataclass
class VisualQuestionAnsweringParameters(BaseInferenceType):
    """Additional inference parameters
    Additional inference parameters for Visual Question Answering
    """

    # Optional; defaults to None.
    top_k: Optional[int] = None
    """The number of answers to return (will be chosen by order of likelihood). Note that we
    return less than topk answers if there are not enough options available within the
    context.
    """


@dataclass
class VisualQuestionAnsweringInput(BaseInferenceType):
    """Inputs for Visual Question Answering inference"""

    # `inputs` is required; `parameters` may be omitted.
    inputs: VisualQuestionAnsweringInputData
    """One (image, question) pair to answer"""
    parameters: Optional[VisualQuestionAnsweringParameters] = None
    """Additional inference parameters"""


@dataclass
class VisualQuestionAnsweringOutputElement(BaseInferenceType):
    """Outputs of inference for the Visual Question Answering task"""

    # Required fields (no default) must precede the optional `answer`.
    # `label` has no description in the generated spec and is typed `Any` —
    # its relationship to `answer` is not visible here; TODO confirm against the schema.
    label: Any
    score: float
    """The associated score / probability"""
    answer: Optional[str] = None
    """The answer to the question"""
@dataclass
class ZeroShotClassificationInputData(BaseInferenceType):
    """The input text data, with candidate labels"""

    # Both fields are required.
    candidate_labels: List[str]
    """The set of possible class labels to classify the text into."""
    text: str
    """The text to classify"""


@dataclass
class ZeroShotClassificationParameters(BaseInferenceType):
    """Additional inference parameters
    Additional inference parameters for Zero Shot Classification
    """

    # Every parameter is optional and defaults to None.
    # NOTE(review): the description says "candidateLabels" (spec spelling); the
    # Python field it refers to is `candidate_labels` on the input data class.
    hypothesis_template: Optional[str] = None
    """The sentence used in conjunction with candidateLabels to attempt the text classification
    by replacing the placeholder with the candidate labels.
    """
    multi_label: Optional[bool] = None
    """Whether multiple candidate labels can be true. If false, the scores are normalized such
    that the sum of the label likelihoods for each sequence is 1. If true, the labels are
    considered independent and probabilities are normalized for each candidate.
    """


@dataclass
class ZeroShotClassificationInput(BaseInferenceType):
    """Inputs for Zero Shot Classification inference"""

    # `inputs` is required; `parameters` may be omitted.
    inputs: ZeroShotClassificationInputData
    """The input text data, with candidate labels"""
    parameters: Optional[ZeroShotClassificationParameters] = None
    """Additional inference parameters"""


@dataclass
class ZeroShotClassificationOutputElement(BaseInferenceType):
    """Outputs of inference for the Zero Shot Classification task"""

    # Both fields are required.
    label: str
    """The predicted class label."""
    score: float
    """The corresponding probability."""
@dataclass
class ZeroShotImageClassificationInputData(BaseInferenceType):
    """The input image data, with candidate labels"""

    # Both fields are required; `image` is typed `Any`, so the accepted image
    # representation is not fixed by this type.
    candidate_labels: List[str]
    """The candidate labels for this image"""
    image: Any
    """The image data to classify"""


@dataclass
class ZeroShotImageClassificationParameters(BaseInferenceType):
    """Additional inference parameters
    Additional inference parameters for Zero Shot Image Classification
    """

    # Optional; defaults to None.
    # NOTE(review): the description says "text classification" although this is the
    # image task — looks copied from the zero-shot text schema; confirm in the spec.
    hypothesis_template: Optional[str] = None
    """The sentence used in conjunction with candidateLabels to attempt the text classification
    by replacing the placeholder with the candidate labels.
    """


@dataclass
class ZeroShotImageClassificationInput(BaseInferenceType):
    """Inputs for Zero Shot Image Classification inference"""

    # `inputs` is required; `parameters` may be omitted.
    inputs: ZeroShotImageClassificationInputData
    """The input image data, with candidate labels"""
    parameters: Optional[ZeroShotImageClassificationParameters] = None
    """Additional inference parameters"""


@dataclass
class ZeroShotImageClassificationOutputElement(BaseInferenceType):
    """Outputs of inference for the Zero Shot Image Classification task"""

    # Both fields are required.
    label: str
    """The predicted class label."""
    score: float
    """The corresponding probability."""
@dataclass
class ZeroShotObjectDetectionInputData(BaseInferenceType):
    """The input image data, with candidate labels"""

    # Both fields are required; `image` is typed `Any`, so the accepted image
    # representation is not fixed by this type.
    candidate_labels: List[str]
    """The candidate labels for this image"""
    image: Any
    """The image data to generate bounding boxes from"""


@dataclass
class ZeroShotObjectDetectionInput(BaseInferenceType):
    """Inputs for Zero Shot Object Detection inference"""

    # Unlike the other tasks in this package, `parameters` is a free-form dict
    # rather than a dedicated dataclass.
    inputs: ZeroShotObjectDetectionInputData
    """The input image data, with candidate labels"""
    parameters: Optional[Dict[str, Any]] = None
    """Additional inference parameters"""


@dataclass
class ZeroShotObjectDetectionBoundingBox(BaseInferenceType):
    """The predicted bounding box. Coordinates are relative to the top left corner of the input
    image.
    """

    # All four bounds are required integers, declared in the generated order
    # (max before min on each axis). Units are presumably pixels — TODO confirm.
    xmax: int
    xmin: int
    ymax: int
    ymin: int
def _import_minijinja():
    """Return the `minijinja` module, importing it lazily.

    Raises:
        `ImportError`: If `is_minijinja_available()` reports the package is missing.
    """
    if is_minijinja_available():
        # Imported here (not at module top) so that this module stays importable
        # when the optional dependency is absent.
        import minijinja  # noqa: F401

        return minijinja

    raise ImportError("Cannot render template. Please install minijinja using `pip install minijinja`.")
@lru_cache  # TODO: lru_cache for raised exceptions
def _fetch_and_compile_template(*, model_id: str, token: Union[str, bool, None]) -> Callable:
    """Fetch and compile a model's chat template.

    Method is cached to avoid fetching the same model's config multiple times.
    Only successful results are cached: a call that raises is retried on the next call.

    Args:
        model_id (`str`):
            The model id.
        token (`str` or `bool`, *optional*):
            Hugging Face token. Will default to the locally saved token if not provided.

    Returns:
        `Callable`: A callable taking keyword arguments (e.g. `messages`, `add_generation_prompt`)
        and returning the rendered chat prompt. String-valued (or `AddedToken`-valued) entries of
        the tokenizer config whose key contains `"token"` are injected as template variables;
        a keyword argument with the same name overrides the injected value.

    Raises:
        `TemplateError`: If the model or its config cannot be fetched, if the config has no usable
        `tokenizer_config.chat_template`, or if the template fails to compile.
    """
    from huggingface_hub.hf_api import HfApi

    minijinja = _import_minijinja()

    # 1. fetch config from API
    try:
        config = HfApi(token=token).model_info(model_id).config
    except RepositoryNotFoundError as e:
        raise TemplateError(f"Cannot render chat template: model '{model_id}' not found.") from e
    except HfHubHTTPError as e:
        raise TemplateError(f"Error while trying to fetch chat template for model '{model_id}': {e}") from e

    # 2. check config validity
    if config is None:
        raise TemplateError(f"Config not found for model '{model_id}'.")
    tokenizer_config = config.get("tokenizer_config")
    if tokenizer_config is None:
        raise TemplateError(f"Tokenizer config not found for model '{model_id}'.")
    if tokenizer_config.get("chat_template") is None:
        raise TemplateError(f"Chat template not found in tokenizer_config for model '{model_id}'.")
    chat_template = tokenizer_config["chat_template"]
    if not isinstance(chat_template, str):
        raise TemplateError(f"Chat template must be a string, not '{type(chat_template)}' (model: {model_id}).")

    # Collect special tokens (e.g. `bos_token`, `eos_token`) so the template can reference them.
    special_tokens: Dict[str, Optional[str]] = {}
    for key, value in tokenizer_config.items():
        if "token" in key:
            if isinstance(value, str):
                special_tokens[key] = value
            elif isinstance(value, dict) and value.get("__type") == "AddedToken":
                special_tokens[key] = value.get("content")

    # 3. compile template and return
    env = minijinja.Environment()
    try:
        env.add_template("chat_template", chat_template)
    except minijinja.TemplateError as e:
        raise TemplateError(f"Error while trying to compile chat template for model '{model_id}': {e}") from e

    def _render(**kwargs) -> str:
        # Merge into a single mapping before unpacking: passing `**kwargs` and
        # `**special_tokens` separately raises `TypeError` (duplicate keyword) as soon
        # as a caller supplies e.g. `bos_token=...`. Caller-supplied values win.
        return env.render_template("chat_template", **{**special_tokens, **kwargs})

    return _render
class InferenceApi:
    """Client to configure requests and make calls to the HuggingFace Inference API.

    Example:

    ```python
    >>> from huggingface_hub.inference_api import InferenceApi

    >>> # Mask-fill example
    >>> inference = InferenceApi("bert-base-uncased")
    >>> inference(inputs="The goal of life is [MASK].")
    [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}]

    >>> # Question Answering example
    >>> inference = InferenceApi("deepset/roberta-base-squad2")
    >>> inputs = {
    ...     "question": "What's my name?",
    ...     "context": "My name is Clara and I live in Berkeley.",
    ... }
    >>> inference(inputs)
    {'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'}

    >>> # Zero-shot example
    >>> inference = InferenceApi("typeform/distilbert-base-uncased-mnli")
    >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"
    >>> params = {"candidate_labels": ["refund", "legal", "faq"]}
    >>> inference(inputs, params)
    {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]}

    >>> # Overriding configured task
    >>> inference = InferenceApi("bert-base-uncased", task="feature-extraction")

    >>> # Text-to-image
    >>> inference = InferenceApi("stabilityai/stable-diffusion-2-1")
    >>> inference("cat")
    <PIL.PngImagePlugin.PngImageFile image mode=RGB size=768x768 at ...>

    >>> # Return as raw response to parse the output yourself
    >>> inference = InferenceApi("mio/amadeus")
    >>> response = inference("hello world", raw_response=True)
    >>> response.headers
    {"Content-Type": "audio/flac", ...}
    >>> response.content # raw bytes from server
    b'(...)'
    ```
    """

    @validate_hf_hub_args
    @_deprecate_method(
        version="1.0",
        message=(
            "`InferenceApi` client is deprecated in favor of the more feature-complete `InferenceClient`. Check out"
            " this guide to learn how to convert your script to use it:"
            " https://huggingface.co/docs/huggingface_hub/guides/inference#legacy-inferenceapi-client."
        ),
    )
    def __init__(
        self,
        repo_id: str,
        task: Optional[str] = None,
        token: Optional[str] = None,
        gpu: bool = False,
    ):
        """Inits headers and API call information.

        Args:
            repo_id (``str``):
                Id of repository (e.g. `user/bert-base-uncased`).
            task (``str``, `optional`, defaults ``None``):
                Whether to force a task instead of using task specified in the
                repository.
            token (`str`, `optional`):
                The API token to use as HTTP bearer authorization. This is not
                the authentication token. You can find the token in
                https://huggingface.co/settings/token. Alternatively, you can
                find both your organizations and personal API tokens using
                `HfApi().whoami(token)`.
            gpu (`bool`, `optional`, defaults `False`):
                Whether to use GPU instead of CPU for inference(requires Startup
                plan at least).

        Raises:
            `ValueError`: If neither the repository nor the caller specifies a task, or if
            the forced `task` is not one of `ALL_TASKS`.
        """
        self.options = {"wait_for_model": True, "use_gpu": gpu}
        self.headers = build_hf_headers(token=token)

        # Configure task
        model_info = HfApi(token=token).model_info(repo_id=repo_id)
        if not model_info.pipeline_tag and not task:
            raise ValueError(
                "Task not specified in the repository. Please add it to the model card"
                " using pipeline_tag"
                " (https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined)"
            )

        if task and task != model_info.pipeline_tag:
            # The caller overrides the repository's task: validate it and warn.
            if task not in ALL_TASKS:
                raise ValueError(f"Invalid task {task}. Make sure it's valid.")

            logger.warning(
                "You're using a different task than the one specified in the"
                " repository. Be sure to know what you're doing :)"
            )
            self.task = task
        else:
            # Unreachable with a None tag (guarded above); the assert narrows the type.
            assert model_info.pipeline_tag is not None, "Pipeline tag cannot be None"
            self.task = model_info.pipeline_tag

        self.api_url = f"{constants.INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}"

    def __repr__(self):
        # Do not add headers to repr to avoid leaking token.
        return f"InferenceAPI(api_url='{self.api_url}', task='{self.task}', options={self.options})"

    def __call__(
        self,
        inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
        params: Optional[Dict] = None,
        data: Optional[bytes] = None,
        raw_response: bool = False,
    ) -> Any:
        """Make a call to the Inference API.

        Args:
            inputs (`str` or `Dict` or `List[str]` or `List[List[str]]`, *optional*):
                Inputs for the prediction.
            params (`Dict`, *optional*):
                Additional parameters for the models. Will be sent as `parameters` in the
                payload.
            data (`bytes`, *optional*):
                Bytes content of the request. In this case, leave `inputs` and `params` empty.
            raw_response (`bool`, defaults to `False`):
                If `True`, the raw `Response` object is returned. You can parse its content
                as preferred. By default, the content is parsed into a more practical format
                (json dictionary or PIL Image for example).

        Returns:
            The raw response if `raw_response=True`; otherwise a `PIL.Image.Image` for image
            responses or the decoded JSON body for JSON responses.

        Raises:
            `ImportError`: If the server returns an image and Pillow is not installed.
            `NotImplementedError`: If the response media type is neither an image nor JSON.
        """
        # Build payload. Falsy `inputs` / `params` (None, "", {}, []) are omitted.
        payload: Dict[str, Any] = {
            "options": self.options,
        }
        if inputs:
            payload["inputs"] = inputs
        if params:
            payload["parameters"] = params

        # Make API call
        response = get_session().post(self.api_url, headers=self.headers, json=payload, data=data)

        # Let the user handle the response
        if raw_response:
            return response

        # By default, parse the response for the user.
        content_type = response.headers.get("Content-Type") or ""
        # Dispatch on the bare media type: the header may carry parameters
        # (e.g. "application/json; charset=utf-8") and media types are case-insensitive.
        media_type = content_type.split(";", 1)[0].strip().lower()
        if media_type.startswith("image"):
            if not is_pillow_available():
                raise ImportError(
                    f"Task '{self.task}' returned as image but Pillow is not installed."
                    " Please install it (`pip install Pillow`) or pass"
                    " `raw_response=True` to get the raw `Response` object and parse"
                    " the image by yourself."
                )

            from PIL import Image

            return Image.open(io.BytesIO(response.content))
        elif media_type == "application/json":
            return response.json()
        else:
            raise NotImplementedError(
                f"{content_type} output type is not implemented yet. You can pass"
                " `raw_response=True` to get the raw `Response` object and parse the"
                " output by yourself."
            )
def _requires_keras_2_model(fn: CallableT) -> CallableT:
    """Decorate `fn` so that it refuses models that do not look like Keras 2.x models.

    The wrapped function must take the model as its first positional argument.
    A model is accepted when it has a `history` attribute; otherwise
    `NotImplementedError` is raised and `fn` is never called.
    """

    @wraps(fn)
    def _guarded(model, *args, **kwargs):
        # hacky way to check if model is Keras 2.x: only those expose `history`
        if hasattr(model, "history"):
            return fn(model, *args, **kwargs)
        raise NotImplementedError(
            f"Cannot use '{fn.__name__}': Keras 3.x is not supported."
            " Please save models manually and upload them using `upload_folder` or `huggingface-cli upload`."
        )

    return _guarded  # type: ignore [return-value]
def _create_hyperparameter_table(model):
    """Render the optimizer configuration of `model` as a markdown table.

    Returns:
        The table as a string (header plus one row per flattened hyperparameter,
        including a `training_precision` row), or `None` if the model has no optimizer.
    """
    if model.optimizer is None:
        return None

    # Nested optimizer settings are flattened into dotted keys.
    hyperparameters = _flatten_dict(model.optimizer.get_config())
    hyperparameters["training_precision"] = keras.mixed_precision.global_policy().name

    header = "| Hyperparameters | Value |\n| :-- | :-- |\n"
    rows = [f"| {name} | {setting} |\n" for name, setting in hyperparameters.items()]
    return header + "".join(rows)
@_requires_keras_2_model
def save_pretrained_keras(
    model,
    save_directory: Union[str, Path],
    config: Optional[Dict[str, Any]] = None,
    include_optimizer: bool = False,
    plot_model: bool = True,
    tags: Optional[Union[list, str]] = None,
    **model_save_kwargs,
):
    """
    Saves a Keras model to save_directory in SavedModel format. Use this if
    you're using the Functional or Sequential APIs.

    Besides the model itself, this writes the optional `config` as JSON, the
    training history as `history.json` (when the model has a non-empty one) and a
    model card (`README.md`, only if none exists yet).

    Args:
        model (`Keras.Model`):
            The [Keras
            model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
            you'd like to save. The model must be compiled and built.
        save_directory (`str` or `Path`):
            Specify directory in which you want to save the Keras model.
        config (`dict`, *optional*):
            Configuration object to be saved alongside the model weights.
        include_optimizer(`bool`, *optional*, defaults to `False`):
            Whether or not to include optimizer in serialization.
        plot_model (`bool`, *optional*, defaults to `True`):
            Setting this to `True` will plot the model and put it in the model
            card. Requires graphviz and pydot to be installed.
        tags (Union[`str`,`list`], *optional*):
            List of tags that are related to model or string of a single tag. See example tags
            [here](https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1).
            The list is not modified.
        model_save_kwargs(`dict`, *optional*):
            model_save_kwargs will be passed to
            [`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model).
            The deprecated `task_name` key is popped and appended to the tags.

    Raises:
        `ImportError`: If Keras could not be imported.
        `ValueError`: If the model is not built.
        `RuntimeError`: If `config` is provided but is not a dict.
    """
    if keras is None:
        raise ImportError("Called a Tensorflow-specific function but could not import it.")

    if not model.built:
        raise ValueError("Model should be built before trying to save")

    save_directory = Path(save_directory)
    save_directory.mkdir(parents=True, exist_ok=True)

    # saving config
    if config:
        if not isinstance(config, dict):
            raise RuntimeError(f"Provided config to save_pretrained_keras should be a dict. Got: '{type(config)}'")

        with (save_directory / constants.CONFIG_NAME).open("w") as f:
            json.dump(config, f)

    metadata = {}
    if isinstance(tags, list):
        # Copy: the deprecated `task_name` path below appends to this list and must
        # not leak that change into the caller's own `tags` object.
        metadata["tags"] = list(tags)
    elif isinstance(tags, str):
        metadata["tags"] = [tags]

    task_name = model_save_kwargs.pop("task_name", None)
    if task_name is not None:
        warnings.warn(
            "`task_name` input argument is deprecated. Pass `tags` instead.",
            FutureWarning,
        )
        if "tags" in metadata:
            metadata["tags"].append(task_name)
        else:
            metadata["tags"] = [task_name]

    # Persist the training history only when there is something to record.
    if model.history is not None and model.history.history != {}:
        path = save_directory / "history.json"
        if path.exists():
            warnings.warn(
                "`history.json` file already exists, it will be overwritten by the history of this version.",
                UserWarning,
            )
        with path.open("w", encoding="utf-8") as f:
            json.dump(model.history.history, f, indent=2, sort_keys=True)

    _create_model_card(model, save_directory, plot_model, metadata)
    keras.models.save_model(model, save_directory, include_optimizer=include_optimizer, **model_save_kwargs)
+ force_download (`bool`, *optional*, defaults to `False`): + Whether to force the (re-)download of the model weights and + configuration files, overriding the cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., + `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The + proxies are used on each request. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If + `True`, will use the token generated when running `transformers-cli + login` (stored in `~/.huggingface`). + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model + configuration should be cached if the standard cache should not be + used. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether to only look at local files (i.e., do not try to download + the model). + model_kwargs (`Dict`, *optional*): + model_kwargs will be passed to the model during initialization + +View Model Plot
@validate_hf_hub_args
@_requires_keras_2_model
def push_to_hub_keras(
    model,
    repo_id: str,
    *,
    config: Optional[dict] = None,
    commit_message: str = "Push Keras model using huggingface_hub.",
    private: bool = False,
    api_endpoint: Optional[str] = None,
    token: Optional[str] = None,
    branch: Optional[str] = None,
    create_pr: Optional[bool] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    delete_patterns: Optional[Union[List[str], str]] = None,
    log_dir: Optional[str] = None,
    include_optimizer: bool = False,
    tags: Optional[Union[list, str]] = None,
    plot_model: bool = True,
    **model_save_kwargs,
):
    """
    Upload model checkpoint to the Hub.

    Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
    `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
    details.

    Args:
        model (`Keras.Model`):
            The [Keras model](`https://www.tensorflow.org/api_docs/python/tf/keras/Model`) you'd like to push to the
            Hub. The model must be compiled and built.
        repo_id (`str`):
            ID of the repository to push to (example: `"username/my-model"`).
        commit_message (`str`, *optional*, defaults to "Push Keras model using huggingface_hub."):
            Message to commit while pushing.
        private (`bool`, *optional*, defaults to `False`):
            Whether the repository created should be private.
        api_endpoint (`str`, *optional*):
            The API endpoint to use when pushing the model to the hub.
        token (`str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If
            not set, will use the token set when logging in with
            `huggingface-cli login` (stored in `~/.huggingface`).
        branch (`str`, *optional*):
            The git branch on which to push the model. This defaults to
            the default branch as specified in your repository, which
            defaults to `"main"`.
        create_pr (`boolean`, *optional*):
            Whether or not to create a Pull Request from `branch` with that commit.
            Defaults to `False`.
        config (`dict`, *optional*):
            Configuration object to be saved alongside the model weights.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are pushed.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not pushed.
        delete_patterns (`List[str]` or `str`, *optional*):
            If provided, remote files matching any of the patterns will be deleted from the repo.
            The object passed in is never modified.
        log_dir (`str`, *optional*):
            TensorBoard logging directory to be pushed. The Hub automatically
            hosts and displays a TensorBoard instance if log files are included
            in the repository. When set, remote files matching `logs/*` are deleted
            in the same commit.
        include_optimizer (`bool`, *optional*, defaults to `False`):
            Whether or not to include optimizer during serialization.
        tags (Union[`list`, `str`], *optional*):
            List of tags that are related to model or string of a single tag. See example tags
            [here](https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1).
        plot_model (`bool`, *optional*, defaults to `True`):
            Setting this to `True` will plot the model and put it in the model
            card. Requires graphviz and pydot to be installed.
        model_save_kwargs(`dict`, *optional*):
            model_save_kwargs will be passed to
            [`tf.keras.models.save_model()`](https://www.tensorflow.org/api_docs/python/tf/keras/models/save_model).

    Returns:
        The url of the commit of your model in the given repository.
    """
    api = HfApi(endpoint=api_endpoint)
    # Use the canonical repo id returned by the Hub; the repo is created if missing.
    repo_id = api.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True).repo_id

    # Push the files to the repo in a single commit
    with SoftTemporaryDirectory() as tmp:
        saved_path = Path(tmp) / repo_id
        save_pretrained_keras(
            model,
            saved_path,
            config=config,
            include_optimizer=include_optimizer,
            tags=tags,
            plot_model=plot_model,
            **model_save_kwargs,
        )

        # If `log_dir` provided, delete remote logs and upload new ones
        if log_dir is not None:
            if delete_patterns is None:
                existing_patterns: List[str] = []
            elif isinstance(delete_patterns, str):
                existing_patterns = [delete_patterns]
            else:
                existing_patterns = list(delete_patterns)
            # Build a new list instead of appending in place, so a list passed by the
            # caller does not silently gain a "logs/*" entry on every call.
            delete_patterns = [*existing_patterns, "logs/*"]
            copytree(log_dir, saved_path / "logs")

        # Upload while the temporary directory still exists.
        return api.upload_folder(
            repo_type="model",
            repo_id=repo_id,
            folder_path=saved_path,
            commit_message=commit_message,
            token=token,
            revision=branch,
            create_pr=create_pr,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            delete_patterns=delete_patterns,
        )
+ >>> # Build the graph by training it or passing dummy inputs + >>> _ = model(model.dummy_inputs) + >>> # Save model weights to local directory + >>> model.save_pretrained("my-awesome-model") + >>> # Push model weights to the Hub + >>> model.push_to_hub("my-awesome-model") + >>> # Download and initialize weights from the Hub + >>> model = MyModel.from_pretrained("username/super-cool-model") + ``` + """ + + def _save_pretrained(self, save_directory): + save_pretrained_keras(self, save_directory) + + @classmethod + def _from_pretrained( + cls, + model_id, + revision, + cache_dir, + force_download, + proxies, + resume_download, + local_files_only, + token, + config: Optional[Dict[str, Any]] = None, + **model_kwargs, + ): + """Here we just call [`from_pretrained_keras`] function so both the mixin and + functional APIs stay in sync. + + TODO - Some args above aren't used since we are calling + snapshot_download instead of hf_hub_download. + """ + if keras is None: + raise ImportError("Called a TensorFlow-specific function but could not import it.") + + # Root is either a local filepath matching model_id or a cached snapshot + if not os.path.isdir(model_id): + storage_folder = snapshot_download( + repo_id=model_id, + revision=revision, + cache_dir=cache_dir, + library_name="keras", + library_version=get_tf_version(), + ) + else: + storage_folder = model_id + + # TODO: change this in a future PR. We are not returning a KerasModelHubMixin instance here... + model = keras.models.load_model(storage_folder) + + # For now, we add a new attribute, config, to store the config loaded from the hub/a local dir. + model.config = config + + return model diff --git a/huggingface_hub/lfs.py b/huggingface_hub/lfs.py new file mode 100644 index 0000000000000000000000000000000000000000..2ea852601e8c8dd653f5cf70ea21b5f47fa195a5 --- /dev/null +++ b/huggingface_hub/lfs.py @@ -0,0 +1,463 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Git LFS related type definitions and utilities""" + +import inspect +import io +import re +import warnings +from dataclasses import dataclass +from math import ceil +from os.path import getsize +from pathlib import Path +from typing import TYPE_CHECKING, BinaryIO, Dict, Iterable, List, Optional, Tuple, TypedDict +from urllib.parse import unquote + +from huggingface_hub import constants + +from .utils import ( + build_hf_headers, + fix_hf_endpoint_in_url, + get_session, + hf_raise_for_status, + http_backoff, + logging, + tqdm, + validate_hf_hub_args, +) +from .utils._lfs import SliceFileObj +from .utils.sha import sha256, sha_fileobj + + +if TYPE_CHECKING: + from ._commit_api import CommitOperationAdd + +logger = logging.get_logger(__name__) + +OID_REGEX = re.compile(r"^[0-9a-f]{40}$") + +LFS_MULTIPART_UPLOAD_COMMAND = "lfs-multipart-upload" + +LFS_HEADERS = { + "Accept": "application/vnd.git-lfs+json", + "Content-Type": "application/vnd.git-lfs+json", +} + + +@dataclass +class UploadInfo: + """ + Dataclass holding required information to determine whether a blob + should be uploaded to the hub using the LFS protocol or the regular protocol + + Args: + sha256 (`bytes`): + SHA256 hash of the blob + size (`int`): + Size in bytes of the blob + sample (`bytes`): + First 512 bytes of the blob + """ + + sha256: bytes + size: int + sample: bytes + + @classmethod + def from_path(cls, path: str): + size = getsize(path) + with 
io.open(path, "rb") as file: + sample = file.peek(512)[:512] + sha = sha_fileobj(file) + return cls(size=size, sha256=sha, sample=sample) + + @classmethod + def from_bytes(cls, data: bytes): + sha = sha256(data).digest() + return cls(size=len(data), sample=data[:512], sha256=sha) + + @classmethod + def from_fileobj(cls, fileobj: BinaryIO): + sample = fileobj.read(512) + fileobj.seek(0, io.SEEK_SET) + sha = sha_fileobj(fileobj) + size = fileobj.tell() + fileobj.seek(0, io.SEEK_SET) + return cls(size=size, sha256=sha, sample=sample) + + +@validate_hf_hub_args +def post_lfs_batch_info( + upload_infos: Iterable[UploadInfo], + token: Optional[str], + repo_type: str, + repo_id: str, + revision: Optional[str] = None, + endpoint: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, +) -> Tuple[List[dict], List[dict]]: + """ + Requests the LFS batch endpoint to retrieve upload instructions + + Learn more: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md + + Args: + upload_infos (`Iterable` of `UploadInfo`): + `UploadInfo` for the files that are being uploaded, typically obtained + from `CommitOperationAdd.upload_info` + repo_type (`str`): + Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`. + repo_id (`str`): + A namespace (user or an organization) and a repo name separated + by a `/`. + revision (`str`, *optional*): + The git revision to upload to. + headers (`dict`, *optional*): + Additional headers to include in the request + + Returns: + `LfsBatchInfo`: 2-tuple: + - First element is the list of upload instructions from the server + - Second element is a list of errors, if any + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If an argument is invalid or the server response is malformed. + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + If the server returned an error.
+ """ + endpoint = endpoint if endpoint is not None else constants.ENDPOINT + url_prefix = "" + if repo_type in constants.REPO_TYPES_URL_PREFIXES: + url_prefix = constants.REPO_TYPES_URL_PREFIXES[repo_type] + batch_url = f"{endpoint}/{url_prefix}{repo_id}.git/info/lfs/objects/batch" + payload: Dict = { + "operation": "upload", + "transfers": ["basic", "multipart"], + "objects": [ + { + "oid": upload.sha256.hex(), + "size": upload.size, + } + for upload in upload_infos + ], + "hash_algo": "sha256", + } + if revision is not None: + payload["ref"] = {"name": unquote(revision)} # revision has been previously 'quoted' + + headers = { + **LFS_HEADERS, + **build_hf_headers(token=token), + **(headers or {}), + } + resp = get_session().post(batch_url, headers=headers, json=payload) + hf_raise_for_status(resp) + batch_info = resp.json() + + objects = batch_info.get("objects", None) + if not isinstance(objects, list): + raise ValueError("Malformed response from server") + + return ( + [_validate_batch_actions(obj) for obj in objects if "error" not in obj], + [_validate_batch_error(obj) for obj in objects if "error" in obj], + ) + + +class PayloadPartT(TypedDict): + partNumber: int + etag: str + + +class CompletionPayloadT(TypedDict): + """Payload that will be sent to the Hub when uploading multi-part.""" + + oid: str + parts: List[PayloadPartT] + + +def lfs_upload( + operation: "CommitOperationAdd", + lfs_batch_action: Dict, + token: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + endpoint: Optional[str] = None, +) -> None: + """ + Handles uploading a given object to the Hub with the LFS protocol. + + Can be a No-op if the content of the file is already present on the hub large file storage. + + Args: + operation (`CommitOperationAdd`): + The add operation triggering this upload. + lfs_batch_action (`dict`): + Upload instructions from the LFS batch endpoint for this object. See [`~utils.lfs.post_lfs_batch_info`] for + more details. 
+ headers (`dict`, *optional*): + Headers to include in the request, including authentication and user agent headers. + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `lfs_batch_action` is improperly formatted + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + If the upload resulted in an error + """ + # 0. If LFS file is already present, skip upload + _validate_batch_actions(lfs_batch_action) + actions = lfs_batch_action.get("actions") + if actions is None: + # The file was already uploaded + logger.debug(f"Content of file {operation.path_in_repo} is already present upstream - skipping upload") + return + + # 1. Validate server response (check required keys in dict) + upload_action = lfs_batch_action["actions"]["upload"] + _validate_lfs_action(upload_action) + verify_action = lfs_batch_action["actions"].get("verify") + if verify_action is not None: + _validate_lfs_action(verify_action) + + # 2. Upload file (either single part or multi-part) + header = upload_action.get("header", {}) + chunk_size = header.get("chunk_size") + upload_url = fix_hf_endpoint_in_url(upload_action["href"], endpoint=endpoint) + if chunk_size is not None: + try: + chunk_size = int(chunk_size) + except (ValueError, TypeError): + raise ValueError( + f"Malformed response from LFS batch endpoint: `chunk_size` should be an integer. Got '{chunk_size}'." + ) + _upload_multi_part(operation=operation, header=header, chunk_size=chunk_size, upload_url=upload_url) + else: + _upload_single_part(operation=operation, upload_url=upload_url) + + # 3. 
Verify upload went well + if verify_action is not None: + _validate_lfs_action(verify_action) + verify_url = fix_hf_endpoint_in_url(verify_action["href"], endpoint) + verify_resp = get_session().post( + verify_url, + headers=build_hf_headers(token=token, headers=headers), + json={"oid": operation.upload_info.sha256.hex(), "size": operation.upload_info.size}, + ) + hf_raise_for_status(verify_resp) + logger.debug(f"{operation.path_in_repo}: Upload successful") + + +def _validate_lfs_action(lfs_action: dict): + """validates response from the LFS batch endpoint""" + if not ( + isinstance(lfs_action.get("href"), str) + and (lfs_action.get("header") is None or isinstance(lfs_action.get("header"), dict)) + ): + raise ValueError("lfs_action is improperly formatted") + return lfs_action + + +def _validate_batch_actions(lfs_batch_actions: dict): + """validates response from the LFS batch endpoint""" + if not (isinstance(lfs_batch_actions.get("oid"), str) and isinstance(lfs_batch_actions.get("size"), int)): + raise ValueError("lfs_batch_actions is improperly formatted") + + upload_action = lfs_batch_actions.get("actions", {}).get("upload") + verify_action = lfs_batch_actions.get("actions", {}).get("verify") + if upload_action is not None: + _validate_lfs_action(upload_action) + if verify_action is not None: + _validate_lfs_action(verify_action) + return lfs_batch_actions + + +def _validate_batch_error(lfs_batch_error: dict): + """validates response from the LFS batch endpoint""" + if not (isinstance(lfs_batch_error.get("oid"), str) and isinstance(lfs_batch_error.get("size"), int)): + raise ValueError("lfs_batch_error is improperly formatted") + error_info = lfs_batch_error.get("error") + if not ( + isinstance(error_info, dict) + and isinstance(error_info.get("message"), str) + and isinstance(error_info.get("code"), int) + ): + raise ValueError("lfs_batch_error is improperly formatted") + return lfs_batch_error + + +def _upload_single_part(operation: "CommitOperationAdd", 
upload_url: str) -> None: + """ + Uploads `fileobj` as a single PUT HTTP request (basic LFS transfer protocol) + + Args: + upload_url (`str`): + The URL to PUT the file to. + fileobj: + The file-like object holding the data to upload. + + Returns: `requests.Response` + + Raises: + [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + If the upload resulted in an error. + """ + with operation.as_file(with_tqdm=True) as fileobj: + # S3 might raise a transient 500 error -> let's retry if that happens + response = http_backoff("PUT", upload_url, data=fileobj, retry_on_status_codes=(500, 502, 503, 504)) + hf_raise_for_status(response) + + +def _upload_multi_part(operation: "CommitOperationAdd", header: Dict, chunk_size: int, upload_url: str) -> None: + """ + Uploads file using HF multipart LFS transfer protocol. + """ + # 1. Get upload URLs for each part + sorted_parts_urls = _get_sorted_parts_urls(header=header, upload_info=operation.upload_info, chunk_size=chunk_size) + + # 2. Upload parts (either with hf_transfer or in pure Python) + use_hf_transfer = constants.HF_HUB_ENABLE_HF_TRANSFER + if ( + constants.HF_HUB_ENABLE_HF_TRANSFER + and not isinstance(operation.path_or_fileobj, str) + and not isinstance(operation.path_or_fileobj, Path) + ): + warnings.warn( + "hf_transfer is enabled but does not support uploading from bytes or BinaryIO, falling back to regular" + " upload" + ) + use_hf_transfer = False + + response_headers = ( + _upload_parts_hf_transfer(operation=operation, sorted_parts_urls=sorted_parts_urls, chunk_size=chunk_size) + if use_hf_transfer + else _upload_parts_iteratively(operation=operation, sorted_parts_urls=sorted_parts_urls, chunk_size=chunk_size) + ) + + # 3. 
Send completion request + completion_res = get_session().post( + upload_url, + json=_get_completion_payload(response_headers, operation.upload_info.sha256.hex()), + headers=LFS_HEADERS, + ) + hf_raise_for_status(completion_res) + + +def _get_sorted_parts_urls(header: Dict, upload_info: UploadInfo, chunk_size: int) -> List[str]: + sorted_part_upload_urls = [ + upload_url + for _, upload_url in sorted( + [ + (int(part_num, 10), upload_url) + for part_num, upload_url in header.items() + if part_num.isdigit() and len(part_num) > 0 + ], + key=lambda t: t[0], + ) + ] + num_parts = len(sorted_part_upload_urls) + if num_parts != ceil(upload_info.size / chunk_size): + raise ValueError("Invalid server response to upload large LFS file") + return sorted_part_upload_urls + + +def _get_completion_payload(response_headers: List[Dict], oid: str) -> CompletionPayloadT: + parts: List[PayloadPartT] = [] + for part_number, header in enumerate(response_headers): + etag = header.get("etag") + if etag is None or etag == "": + raise ValueError(f"Invalid etag (`{etag}`) returned for part {part_number + 1}") + parts.append( + { + "partNumber": part_number + 1, + "etag": etag, + } + ) + return {"oid": oid, "parts": parts} + + +def _upload_parts_iteratively( + operation: "CommitOperationAdd", sorted_parts_urls: List[str], chunk_size: int +) -> List[Dict]: + headers = [] + with operation.as_file(with_tqdm=True) as fileobj: + for part_idx, part_upload_url in enumerate(sorted_parts_urls): + with SliceFileObj( + fileobj, + seek_from=chunk_size * part_idx, + read_limit=chunk_size, + ) as fileobj_slice: + # S3 might raise a transient 500 error -> let's retry if that happens + part_upload_res = http_backoff( + "PUT", part_upload_url, data=fileobj_slice, retry_on_status_codes=(500, 502, 503, 504) + ) + hf_raise_for_status(part_upload_res) + headers.append(part_upload_res.headers) + return headers # type: ignore + + +def _upload_parts_hf_transfer( + operation: "CommitOperationAdd", sorted_parts_urls: 
List[str], chunk_size: int +) -> List[Dict]: + # Upload file using an external Rust-based package. Upload is faster but supports fewer features (no progress bars). + try: + from hf_transfer import multipart_upload + except ImportError: + raise ValueError( + "Fast uploading using 'hf_transfer' is enabled (HF_HUB_ENABLE_HF_TRANSFER=1) but 'hf_transfer' package is" + " not available in your environment. Try `pip install hf_transfer`." + ) + + supports_callback = "callback" in inspect.signature(multipart_upload).parameters + if not supports_callback: + warnings.warn( + "You are using an outdated version of `hf_transfer`. Consider upgrading to latest version to enable progress bars using `pip install -U hf_transfer`." + ) + + total = operation.upload_info.size + desc = operation.path_in_repo + if len(desc) > 40: + desc = f"(…){desc[-40:]}" + + # set `disable=None` rather than `disable=False` by default to disable progress bar when no TTY attached + # see https://github.com/huggingface/huggingface_hub/pull/2000 + disable = True if (logger.getEffectiveLevel() == logging.NOTSET) else None + + with tqdm( + unit="B", + unit_scale=True, + total=total, + initial=0, + desc=desc, + disable=disable, + name="huggingface_hub.lfs_upload", + ) as progress: + try: + output = multipart_upload( + file_path=operation.path_or_fileobj, + parts_urls=sorted_parts_urls, + chunk_size=chunk_size, + max_files=128, + parallel_failures=127, # could be removed + max_retries=5, + **({"callback": progress.update} if supports_callback else {}), + ) + except Exception as e: + raise RuntimeError( + "An error occurred while uploading using `hf_transfer`. Consider disabling HF_HUB_ENABLE_HF_TRANSFER for" + " better error handling."
+ ) from e + if not supports_callback: + progress.update(total) + return output diff --git a/huggingface_hub/repocard.py b/huggingface_hub/repocard.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ae591f40dead7e7d03f36a0dd36546f723c23b --- /dev/null +++ b/huggingface_hub/repocard.py @@ -0,0 +1,829 @@ +import os +import re +from pathlib import Path +from typing import Any, Dict, Literal, Optional, Type, Union + +import requests +import yaml + +from huggingface_hub.file_download import hf_hub_download +from huggingface_hub.hf_api import upload_file +from huggingface_hub.repocard_data import ( + CardData, + DatasetCardData, + EvalResult, + ModelCardData, + SpaceCardData, + eval_results_to_model_index, + model_index_to_eval_results, +) +from huggingface_hub.utils import get_session, is_jinja_available, yaml_dump + +from . import constants +from .errors import EntryNotFoundError +from .utils import SoftTemporaryDirectory, logging, validate_hf_hub_args + + +logger = logging.get_logger(__name__) + + +TEMPLATE_MODELCARD_PATH = Path(__file__).parent / "templates" / "modelcard_template.md" +TEMPLATE_DATASETCARD_PATH = Path(__file__).parent / "templates" / "datasetcard_template.md" + +# exact same regex as in the Hub server. Please keep in sync. +# See https://github.com/huggingface/moon-landing/blob/main/server/lib/ViewMarkdown.ts#L18 +REGEX_YAML_BLOCK = re.compile(r"^(\s*---[\r\n]+)([\S\s]*?)([\r\n]+---(\r\n|\n|$))") + + +class RepoCard: + card_data_class = CardData + default_template_path = TEMPLATE_MODELCARD_PATH + repo_type = "model" + + def __init__(self, content: str, ignore_metadata_errors: bool = False): + """Initialize a RepoCard from string content. The content should be a + Markdown file with a YAML block at the beginning and a Markdown body. + + Args: + content (`str`): The content of the Markdown file. + + Example: + ```python + >>> from huggingface_hub.repocard import RepoCard + >>> text = ''' + ... --- + ... language: en + ... 
license: mit + ... --- + ... + ... # My repo + ... ''' + >>> card = RepoCard(text) + >>> card.data.to_dict() + {'language': 'en', 'license': 'mit'} + >>> card.text + '\\n# My repo\\n' + + ``` ++ Raises the following error: + + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + when the content of the repo card metadata is not a dictionary. + + + """ + + # Set the content of the RepoCard, as well as underlying .data and .text attributes. + # See the `content` property setter for more details. + self.ignore_metadata_errors = ignore_metadata_errors + self.content = content + + @property + def content(self): + """The content of the RepoCard, including the YAML block and the Markdown body.""" + line_break = _detect_line_ending(self._content) or "\n" + return f"---{line_break}{self.data.to_yaml(line_break=line_break)}{line_break}---{line_break}{self.text}" + + @content.setter + def content(self, content: str): + """Set the content of the RepoCard.""" + self._content = content + + match = REGEX_YAML_BLOCK.search(content) + if match: + # Metadata found in the YAML block + yaml_block = match.group(2) + self.text = content[match.end() :] + data_dict = yaml.safe_load(yaml_block) + + if data_dict is None: + data_dict = {} + + # The YAML block's data should be a dictionary + if not isinstance(data_dict, dict): + raise ValueError("repo card metadata block should be a dict") + else: + # Model card without metadata... create empty metadata + logger.warning("Repo card metadata block was not found. Setting CardData to empty.") + data_dict = {} + self.text = content + + self.data = self.card_data_class(**data_dict, ignore_metadata_errors=self.ignore_metadata_errors) + + def __str__(self): + return self.content + + def save(self, filepath: Union[Path, str]): + r"""Save a RepoCard to a file. + + Args: + filepath (`Union[Path, str]`): Filepath to the markdown file to save. 
+ + Example: + ```python + >>> from huggingface_hub.repocard import RepoCard + >>> card = RepoCard("---\nlanguage: en\n---\n# This is a test repo card") + >>> card.save("/tmp/test.md") + + ``` + """ + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + # Preserve newlines as in the existing file. + with open(filepath, mode="w", newline="", encoding="utf-8") as f: + f.write(str(self)) + + @classmethod + def load( + cls, + repo_id_or_path: Union[str, Path], + repo_type: Optional[str] = None, + token: Optional[str] = None, + ignore_metadata_errors: bool = False, + ): + """Initialize a RepoCard from a Hugging Face Hub repo's README.md or a local filepath. + + Args: + repo_id_or_path (`Union[str, Path]`): + The repo ID associated with a Hugging Face Hub repo or a local filepath. + repo_type (`str`, *optional*): + The type of Hugging Face repo to load the card from. Defaults to None, which will use "model". Other options + are "dataset" and "space". Not used when loading from a local filepath. If this is called from a child + class, the default value will be the child class's `repo_type`. + token (`str`, *optional*): + Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token. + ignore_metadata_errors (`bool`, *optional*, defaults to `False`): + If True, errors while parsing the metadata section will be ignored. Some information might be lost during + the process. Use it at your own risk. + + Returns: + [`huggingface_hub.repocard.RepoCard`]: The RepoCard (or subclass) initialized from the repo's + README.md file or filepath.
+ + Example: + ```python + >>> from huggingface_hub.repocard import RepoCard + >>> card = RepoCard.load("nateraw/food") + >>> assert card.data.tags == ["generated_from_trainer", "image-classification", "pytorch"] + + ``` + """ + + if Path(repo_id_or_path).exists(): + card_path = Path(repo_id_or_path) + elif isinstance(repo_id_or_path, str): + card_path = Path( + hf_hub_download( + repo_id_or_path, + constants.REPOCARD_NAME, + repo_type=repo_type or cls.repo_type, + token=token, + ) + ) + else: + raise ValueError(f"Cannot load RepoCard: path not found on disk ({repo_id_or_path}).") + + # Preserve newlines in the existing file. + with card_path.open(mode="r", newline="", encoding="utf-8") as f: + return cls(f.read(), ignore_metadata_errors=ignore_metadata_errors) + + def validate(self, repo_type: Optional[str] = None): + """Validates card against Hugging Face Hub's card validation logic. + Using this function requires access to the internet, so it is only called + internally by [`huggingface_hub.repocard.RepoCard.push_to_hub`]. + + Args: + repo_type (`str`, *optional*, defaults to "model"): + The type of Hugging Face repo to push to. Options are "model", "dataset", and "space". + If this function is called from a child class, the default will be the child class's `repo_type`. + ++ Raises the following errors: + + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if the card fails validation checks. + - [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError) + if the request to the Hub API fails for any other reason. + + + """ + + # If repo type is provided, otherwise, use the repo type of the card. 
+ repo_type = repo_type or self.repo_type + + body = { + "repoType": repo_type, + "content": str(self), + } + headers = {"Accept": "text/plain"} + + try: + r = get_session().post("https://huggingface.co/api/validate-yaml", body, headers=headers) + r.raise_for_status() + except requests.exceptions.HTTPError as exc: + if r.status_code == 400: + raise ValueError(r.text) + else: + raise exc + + def push_to_hub( + self, + repo_id: str, + token: Optional[str] = None, + repo_type: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = None, + parent_commit: Optional[str] = None, + ): + """Push a RepoCard to a Hugging Face Hub repo. + + Args: + repo_id (`str`): + The repo ID of the Hugging Face Hub repo to push to. Example: "nateraw/food". + token (`str`, *optional*): + Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to + the stored token. + repo_type (`str`, *optional*, defaults to "model"): + The type of Hugging Face repo to push to. Options are "model", "dataset", and "space". If this + function is called by a child class, it will default to the child class's `repo_type`. + commit_message (`str`, *optional*): + The summary / title / first line of the generated commit. + commit_description (`str`, *optional*): + The description of the generated commit. + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the `"main"` branch. + create_pr (`bool`, *optional*): + Whether or not to create a Pull Request with this commit. Defaults to `False`. + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. + If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`.
+ If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. + Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be + especially useful if the repo is updated / committed to concurrently. + Returns: + `str`: URL of the commit which updated the card metadata. + """ + + # If repo type is provided, otherwise, use the repo type of the card. + repo_type = repo_type or self.repo_type + + # Validate card before pushing to hub + self.validate(repo_type=repo_type) + + with SoftTemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) / constants.REPOCARD_NAME + tmp_path.write_text(str(self)) + url = upload_file( + path_or_fileobj=str(tmp_path), + path_in_repo=constants.REPOCARD_NAME, + repo_id=repo_id, + token=token, + repo_type=repo_type, + commit_message=commit_message, + commit_description=commit_description, + create_pr=create_pr, + revision=revision, + parent_commit=parent_commit, + ) + return url + + @classmethod + def from_template( + cls, + card_data: CardData, + template_path: Optional[str] = None, + template_str: Optional[str] = None, + **template_kwargs, + ): + """Initialize a RepoCard from a template. By default, it uses the default template. + + Templates are Jinja2 templates that can be customized by passing keyword arguments. + + Args: + card_data (`huggingface_hub.CardData`): + A huggingface_hub.CardData instance containing the metadata you want to include in the YAML + header of the repo card on the Hugging Face Hub. + template_path (`str`, *optional*): + A path to a markdown file with optional Jinja template variables that can be filled + in with `template_kwargs`. Defaults to the default template. + + Returns: + [`huggingface_hub.repocard.RepoCard`]: A RepoCard instance with the specified card data and content from the + template. + """ + if is_jinja_available(): + import jinja2 + else: + raise ImportError( + "Using RepoCard.from_template requires Jinja2 to be installed. 
Please" + " install it with `pip install Jinja2`." + ) + + kwargs = card_data.to_dict().copy() + kwargs.update(template_kwargs) # Template_kwargs have priority + + if template_path is not None: + template_str = Path(template_path).read_text() + if template_str is None: + template_str = Path(cls.default_template_path).read_text() + template = jinja2.Template(template_str) + content = template.render(card_data=card_data.to_yaml(), **kwargs) + return cls(content) + + +class ModelCard(RepoCard): + card_data_class = ModelCardData + default_template_path = TEMPLATE_MODELCARD_PATH + repo_type = "model" + + @classmethod + def from_template( # type: ignore # violates Liskov property but easier to use + cls, + card_data: ModelCardData, + template_path: Optional[str] = None, + template_str: Optional[str] = None, + **template_kwargs, + ): + """Initialize a ModelCard from a template. By default, it uses the default template, which can be found here: + https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md + + Templates are Jinja2 templates that can be customized by passing keyword arguments. + + Args: + card_data (`huggingface_hub.ModelCardData`): + A huggingface_hub.ModelCardData instance containing the metadata you want to include in the YAML + header of the model card on the Hugging Face Hub. + template_path (`str`, *optional*): + A path to a markdown file with optional Jinja template variables that can be filled + in with `template_kwargs`. Defaults to the default template. + + Returns: + [`huggingface_hub.ModelCard`]: A ModelCard instance with the specified card data and content from the + template. + + Example: + ```python + >>> from huggingface_hub import ModelCard, ModelCardData, EvalResult + + >>> # Using the Default Template + >>> card_data = ModelCardData( + ... language='en', + ... license='mit', + ... library_name='timm', + ... tags=['image-classification', 'resnet'], + ... datasets=['beans'], + ... 
metrics=['accuracy'], + ... ) + >>> card = ModelCard.from_template( + ... card_data, + ... model_description='This model does x + y...' + ... ) + + >>> # Including Evaluation Results + >>> card_data = ModelCardData( + ... language='en', + ... tags=['image-classification', 'resnet'], + ... eval_results=[ + ... EvalResult( + ... task_type='image-classification', + ... dataset_type='beans', + ... dataset_name='Beans', + ... metric_type='accuracy', + ... metric_value=0.9, + ... ), + ... ], + ... model_name='my-cool-model', + ... ) + >>> card = ModelCard.from_template(card_data) + + >>> # Using a Custom Template + >>> card_data = ModelCardData( + ... language='en', + ... tags=['image-classification', 'resnet'] + ... ) + >>> card = ModelCard.from_template( + ... card_data=card_data, + ... template_path='./src/huggingface_hub/templates/modelcard_template.md', + ... custom_template_var='custom value', # will be replaced in template if it exists + ... ) + + ``` + """ + return super().from_template(card_data, template_path, template_str, **template_kwargs) + + +class DatasetCard(RepoCard): + card_data_class = DatasetCardData + default_template_path = TEMPLATE_DATASETCARD_PATH + repo_type = "dataset" + + @classmethod + def from_template( # type: ignore # violates Liskov property but easier to use + cls, + card_data: DatasetCardData, + template_path: Optional[str] = None, + template_str: Optional[str] = None, + **template_kwargs, + ): + """Initialize a DatasetCard from a template. By default, it uses the default template, which can be found here: + https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/datasetcard_template.md + + Templates are Jinja2 templates that can be customized by passing keyword arguments. + + Args: + card_data (`huggingface_hub.DatasetCardData`): + A huggingface_hub.DatasetCardData instance containing the metadata you want to include in the YAML + header of the dataset card on the Hugging Face Hub. 
+ template_path (`str`, *optional*): + A path to a markdown file with optional Jinja template variables that can be filled + in with `template_kwargs`. Defaults to the default template. + + Returns: + [`huggingface_hub.DatasetCard`]: A DatasetCard instance with the specified card data and content from the + template. + + Example: + ```python + >>> from huggingface_hub import DatasetCard, DatasetCardData + + >>> # Using the Default Template + >>> card_data = DatasetCardData( + ... language='en', + ... license='mit', + ... annotations_creators='crowdsourced', + ... task_categories=['text-classification'], + ... task_ids=['sentiment-classification', 'text-scoring'], + ... multilinguality='monolingual', + ... pretty_name='My Text Classification Dataset', + ... ) + >>> card = DatasetCard.from_template( + ... card_data, + ... pretty_name=card_data.pretty_name, + ... ) + + >>> # Using a Custom Template + >>> card_data = DatasetCardData( + ... language='en', + ... license='mit', + ... ) + >>> card = DatasetCard.from_template( + ... card_data=card_data, + ... template_path='./src/huggingface_hub/templates/datasetcard_template.md', + ... custom_template_var='custom value', # will be replaced in template if it exists + ... ) + + ``` + """ + return super().from_template(card_data, template_path, template_str, **template_kwargs) + + +class SpaceCard(RepoCard): + card_data_class = SpaceCardData + default_template_path = TEMPLATE_MODELCARD_PATH + repo_type = "space" + + +def _detect_line_ending(content: str) -> Literal["\r", "\n", "\r\n", None]: # noqa: F722 + """Detect the line ending of a string. Used by RepoCard to avoid making huge diff on newlines. + + Uses same implementation as in Hub server, keep it in sync. + + Returns: + str: The detected line ending of the string. 
+ """ + cr = content.count("\r") + lf = content.count("\n") + crlf = content.count("\r\n") + if cr + lf == 0: + return None + if crlf == cr and crlf == lf: + return "\r\n" + if cr > lf: + return "\r" + else: + return "\n" + + +def metadata_load(local_path: Union[str, Path]) -> Optional[Dict]: + content = Path(local_path).read_text() + match = REGEX_YAML_BLOCK.search(content) + if match: + yaml_block = match.group(2) + data = yaml.safe_load(yaml_block) + if data is None or isinstance(data, dict): + return data + raise ValueError("repo card metadata block should be a dict") + else: + return None + + +def metadata_save(local_path: Union[str, Path], data: Dict) -> None: + """ + Save the metadata dict in the upper YAML part Trying to preserve newlines as + in the existing file. Docs about open() with newline="" parameter: + https://docs.python.org/3/library/functions.html?highlight=open#open Does + not work with "^M" linebreaks, which are replaced by \n + """ + line_break = "\n" + content = "" + # try to detect existing newline character + if os.path.exists(local_path): + with open(local_path, "r", newline="", encoding="utf8") as readme: + content = readme.read() + if isinstance(readme.newlines, tuple): + line_break = readme.newlines[0] + elif isinstance(readme.newlines, str): + line_break = readme.newlines + + # creates a new file if it not + with open(local_path, "w", newline="", encoding="utf8") as readme: + data_yaml = yaml_dump(data, sort_keys=False, line_break=line_break) + # sort_keys: keep dict order + match = REGEX_YAML_BLOCK.search(content) + if match: + output = content[: match.start()] + f"---{line_break}{data_yaml}---{line_break}" + content[match.end() :] + else: + output = f"---{line_break}{data_yaml}---{line_break}{content}" + + readme.write(output) + readme.close() + + +def metadata_eval_result( + *, + model_pretty_name: str, + task_pretty_name: str, + task_id: str, + metrics_pretty_name: str, + metrics_id: str, + metrics_value: Any, + 
dataset_pretty_name: str, + dataset_id: str, + metrics_config: Optional[str] = None, + metrics_verified: bool = False, + dataset_config: Optional[str] = None, + dataset_split: Optional[str] = None, + dataset_revision: Optional[str] = None, + metrics_verification_token: Optional[str] = None, +) -> Dict: + """ + Creates a metadata dict with the result from a model evaluated on a dataset. + + Args: + model_pretty_name (`str`): + The name of the model in natural language. + task_pretty_name (`str`): + The name of a task in natural language. + task_id (`str`): + Example: automatic-speech-recognition. A task id. + metrics_pretty_name (`str`): + A name for the metric in natural language. Example: Test WER. + metrics_id (`str`): + Example: wer. A metric id from https://hf.co/metrics. + metrics_value (`Any`): + The value from the metric. Example: 20.0 or "20.0 Β± 1.2". + dataset_pretty_name (`str`): + The name of the dataset in natural language. + dataset_id (`str`): + Example: common_voice. A dataset id from https://hf.co/datasets. + metrics_config (`str`, *optional*): + The name of the metric configuration used in `load_metric()`. + Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`. + metrics_verified (`bool`, *optional*, defaults to `False`): + Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set. + dataset_config (`str`, *optional*): + Example: fr. The name of the dataset configuration used in `load_dataset()`. + dataset_split (`str`, *optional*): + Example: test. The name of the dataset split used in `load_dataset()`. + dataset_revision (`str`, *optional*): + Example: 5503434ddd753f426f4b38109466949a1217c2bb. The name of the dataset dataset revision + used in `load_dataset()`. 
+ metrics_verification_token (`bool`, *optional*): + A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. + + Returns: + `dict`: a metadata dict with the result from a model evaluated on a dataset. + + Example: + ```python + >>> from huggingface_hub import metadata_eval_result + >>> results = metadata_eval_result( + ... model_pretty_name="RoBERTa fine-tuned on ReactionGIF", + ... task_pretty_name="Text Classification", + ... task_id="text-classification", + ... metrics_pretty_name="Accuracy", + ... metrics_id="accuracy", + ... metrics_value=0.2662102282047272, + ... dataset_pretty_name="ReactionJPEG", + ... dataset_id="julien-c/reactionjpeg", + ... dataset_config="default", + ... dataset_split="test", + ... ) + >>> results == { + ... 'model-index': [ + ... { + ... 'name': 'RoBERTa fine-tuned on ReactionGIF', + ... 'results': [ + ... { + ... 'task': { + ... 'type': 'text-classification', + ... 'name': 'Text Classification' + ... }, + ... 'dataset': { + ... 'name': 'ReactionJPEG', + ... 'type': 'julien-c/reactionjpeg', + ... 'config': 'default', + ... 'split': 'test' + ... }, + ... 'metrics': [ + ... { + ... 'type': 'accuracy', + ... 'value': 0.2662102282047272, + ... 'name': 'Accuracy', + ... 'verified': False + ... } + ... ] + ... } + ... ] + ... } + ... ] + ... 
} + True + + ``` + """ + + return { + "model-index": eval_results_to_model_index( + model_name=model_pretty_name, + eval_results=[ + EvalResult( + task_name=task_pretty_name, + task_type=task_id, + metric_name=metrics_pretty_name, + metric_type=metrics_id, + metric_value=metrics_value, + dataset_name=dataset_pretty_name, + dataset_type=dataset_id, + metric_config=metrics_config, + verified=metrics_verified, + verify_token=metrics_verification_token, + dataset_config=dataset_config, + dataset_split=dataset_split, + dataset_revision=dataset_revision, + ) + ], + ) + } + + +@validate_hf_hub_args +def metadata_update( + repo_id: str, + metadata: Dict, + *, + repo_type: Optional[str] = None, + overwrite: bool = False, + token: Optional[str] = None, + commit_message: Optional[str] = None, + commit_description: Optional[str] = None, + revision: Optional[str] = None, + create_pr: bool = False, + parent_commit: Optional[str] = None, +) -> str: + """ + Updates the metadata in the README.md of a repository on the Hugging Face Hub. + If the README.md file doesn't exist yet, a new one is created with metadata and an + the default ModelCard or DatasetCard template. For `space` repo, an error is thrown + as a Space cannot exist without a `README.md` file. + + Args: + repo_id (`str`): + The name of the repository. + metadata (`dict`): + A dictionary containing the metadata to be updated. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if updating to a dataset or space, + `None` or `"model"` if updating to a model. Default is `None`. + overwrite (`bool`, *optional*, defaults to `False`): + If set to `True` an existing field can be overwritten, otherwise + attempting to overwrite an existing field will cause an error. + token (`str`, *optional*): + The Hugging Face authentication token. + commit_message (`str`, *optional*): + The summary / title / first line of the generated commit. 
Defaults to + `f"Update metadata with huggingface_hub"` + commit_description (`str` *optional*) + The description of the generated commit + revision (`str`, *optional*): + The git revision to commit from. Defaults to the head of the + `"main"` branch. + create_pr (`boolean`, *optional*): + Whether or not to create a Pull Request from `revision` with that commit. + Defaults to `False`. + parent_commit (`str`, *optional*): + The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported. + If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`. + If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`. + Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be + especially useful if the repo is updated / committed to concurrently. + Returns: + `str`: URL of the commit which updated the card metadata. + + Example: + ```python + >>> from huggingface_hub import metadata_update + >>> metadata = {'model-index': [{'name': 'RoBERTa fine-tuned on ReactionGIF', + ... 'results': [{'dataset': {'name': 'ReactionGIF', + ... 'type': 'julien-c/reactiongif'}, + ... 'metrics': [{'name': 'Recall', + ... 'type': 'recall', + ... 'value': 0.7762102282047272}], + ... 'task': {'name': 'Text Classification', + ... 
'type': 'text-classification'}}]}]} + >>> url = metadata_update("hf-internal-testing/reactiongif-roberta-card", metadata) + + ``` + """ + commit_message = commit_message if commit_message is not None else "Update metadata with huggingface_hub" + + # Card class given repo_type + card_class: Type[RepoCard] + if repo_type is None or repo_type == "model": + card_class = ModelCard + elif repo_type == "dataset": + card_class = DatasetCard + elif repo_type == "space": + card_class = RepoCard + else: + raise ValueError(f"Unknown repo_type: {repo_type}") + + # Either load repo_card from the Hub or create an empty one. + # NOTE: Will not create the repo if it doesn't exist. + try: + card = card_class.load(repo_id, token=token, repo_type=repo_type) + except EntryNotFoundError: + if repo_type == "space": + raise ValueError("Cannot update metadata on a Space that doesn't contain a `README.md` file.") + + # Initialize a ModelCard or DatasetCard from default template and no data. + card = card_class.from_template(CardData()) + + for key, value in metadata.items(): + if key == "model-index": + # if the new metadata doesn't include a name, either use existing one or repo name + if "name" not in value[0]: + value[0]["name"] = getattr(card, "model_name", repo_id) + model_name, new_results = model_index_to_eval_results(value) + if card.data.eval_results is None: + card.data.eval_results = new_results + card.data.model_name = model_name + else: + existing_results = card.data.eval_results + + # Iterate over new results + # Iterate over existing results + # If both results describe the same metric but value is different: + # If overwrite=True: overwrite the metric value + # Else: raise ValueError + # Else: append new result to existing ones. 
+ for new_result in new_results: + result_found = False + for existing_result in existing_results: + if new_result.is_equal_except_value(existing_result): + if new_result != existing_result and not overwrite: + raise ValueError( + "You passed a new value for the existing metric" + f" 'name: {new_result.metric_name}, type: " + f"{new_result.metric_type}'. Set `overwrite=True`" + " to overwrite existing metrics." + ) + result_found = True + existing_result.metric_value = new_result.metric_value + if existing_result.verified is True: + existing_result.verify_token = new_result.verify_token + if not result_found: + card.data.eval_results.append(new_result) + else: + # Any metadata that is not a result metric + if card.data.get(key) is not None and not overwrite and card.data.get(key) != value: + raise ValueError( + f"You passed a new value for the existing meta data field '{key}'." + " Set `overwrite=True` to overwrite existing metadata." + ) + else: + card.data[key] = value + + return card.push_to_hub( + repo_id, + token=token, + repo_type=repo_type, + commit_message=commit_message, + commit_description=commit_description, + create_pr=create_pr, + revision=revision, + parent_commit=parent_commit, + ) diff --git a/huggingface_hub/repocard_data.py b/huggingface_hub/repocard_data.py new file mode 100644 index 0000000000000000000000000000000000000000..b9b93aac70dbad460f71fc403a3898e32440bf03 --- /dev/null +++ b/huggingface_hub/repocard_data.py @@ -0,0 +1,741 @@ +import copy +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +from huggingface_hub.utils import logging, yaml_dump + + +logger = logging.get_logger(__name__) + + +@dataclass +class EvalResult: + """ + Flattened representation of individual evaluation results found in model-index of Model Cards. + + For more information on the model-index spec, see https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1. 
+ + Args: + task_type (`str`): + The task identifier. Example: "image-classification". + dataset_type (`str`): + The dataset identifier. Example: "common_voice". Use dataset id from https://hf.co/datasets. + dataset_name (`str`): + A pretty name for the dataset. Example: "Common Voice (French)". + metric_type (`str`): + The metric identifier. Example: "wer". Use metric id from https://hf.co/metrics. + metric_value (`Any`): + The metric value. Example: 0.9 or "20.0 Β± 1.2". + task_name (`str`, *optional*): + A pretty name for the task. Example: "Speech Recognition". + dataset_config (`str`, *optional*): + The name of the dataset configuration used in `load_dataset()`. + Example: fr in `load_dataset("common_voice", "fr")`. See the `datasets` docs for more info: + https://hf.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset.name + dataset_split (`str`, *optional*): + The split used in `load_dataset()`. Example: "test". + dataset_revision (`str`, *optional*): + The revision (AKA Git Sha) of the dataset used in `load_dataset()`. + Example: 5503434ddd753f426f4b38109466949a1217c2bb + dataset_args (`Dict[str, Any]`, *optional*): + The arguments passed during `Metric.compute()`. Example for `bleu`: `{"max_order": 4}` + metric_name (`str`, *optional*): + A pretty name for the metric. Example: "Test WER". + metric_config (`str`, *optional*): + The name of the metric configuration used in `load_metric()`. + Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`. + See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations + metric_args (`Dict[str, Any]`, *optional*): + The arguments passed during `Metric.compute()`. Example for `bleu`: max_order: 4 + verified (`bool`, *optional*): + Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set. 
+ verify_token (`str`, *optional*): + A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. + source_name (`str`, *optional*): + The name of the source of the evaluation result. Example: "Open LLM Leaderboard". + source_url (`str`, *optional*): + The URL of the source of the evaluation result. Example: "https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard". + """ + + # Required + + # The task identifier + # Example: automatic-speech-recognition + task_type: str + + # The dataset identifier + # Example: common_voice. Use dataset id from https://hf.co/datasets + dataset_type: str + + # A pretty name for the dataset. + # Example: Common Voice (French) + dataset_name: str + + # The metric identifier + # Example: wer. Use metric id from https://hf.co/metrics + metric_type: str + + # Value of the metric. + # Example: 20.0 or "20.0 Β± 1.2" + metric_value: Any + + # Optional + + # A pretty name for the task. + # Example: Speech Recognition + task_name: Optional[str] = None + + # The name of the dataset configuration used in `load_dataset()`. + # Example: fr in `load_dataset("common_voice", "fr")`. + # See the `datasets` docs for more info: + # https://huggingface.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset.name + dataset_config: Optional[str] = None + + # The split used in `load_dataset()`. + # Example: test + dataset_split: Optional[str] = None + + # The revision (AKA Git Sha) of the dataset used in `load_dataset()`. + # Example: 5503434ddd753f426f4b38109466949a1217c2bb + dataset_revision: Optional[str] = None + + # The arguments passed during `Metric.compute()`. + # Example for `bleu`: max_order: 4 + dataset_args: Optional[Dict[str, Any]] = None + + # A pretty name for the metric. 
+ # Example: Test WER + metric_name: Optional[str] = None + + # The name of the metric configuration used in `load_metric()`. + # Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`. + # See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations + metric_config: Optional[str] = None + + # The arguments passed during `Metric.compute()`. + # Example for `bleu`: max_order: 4 + metric_args: Optional[Dict[str, Any]] = None + + # Indicates whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. Automatically computed by Hugging Face, do not set. + verified: Optional[bool] = None + + # A JSON Web Token that is used to verify whether the metrics originate from Hugging Face's [evaluation service](https://huggingface.co/spaces/autoevaluate/model-evaluator) or not. + verify_token: Optional[str] = None + + # The name of the source of the evaluation result. + # Example: Open LLM Leaderboard + source_name: Optional[str] = None + + # The URL of the source of the evaluation result. + # Example: https://huggingface.co/spaces/open-llm-leaderboard/open_llm_leaderboard + source_url: Optional[str] = None + + @property + def unique_identifier(self) -> tuple: + """Returns a tuple that uniquely identifies this evaluation.""" + return ( + self.task_type, + self.dataset_type, + self.dataset_config, + self.dataset_split, + self.dataset_revision, + ) + + def is_equal_except_value(self, other: "EvalResult") -> bool: + """ + Return True if `self` and `other` describe exactly the same metric but with a + different value. + """ + for key, _ in self.__dict__.items(): + if key == "metric_value": + continue + # For metrics computed by Hugging Face's evaluation service, `verify_token` is derived from `metric_value`, + # so we exclude it here in the comparison. 
+ if key != "verify_token" and getattr(self, key) != getattr(other, key): + return False + return True + + def __post_init__(self) -> None: + if self.source_name is not None and self.source_url is None: + raise ValueError("If `source_name` is provided, `source_url` must also be provided.") + + +@dataclass +class CardData: + """Structure containing metadata from a RepoCard. + + [`CardData`] is the parent class of [`ModelCardData`] and [`DatasetCardData`]. + + Metadata can be exported as a dictionary or YAML. Export can be customized to alter the representation of the data + (example: flatten evaluation results). `CardData` behaves as a dictionary (can get, pop, set values) but do not + inherit from `dict` to allow this export step. + """ + + def __init__(self, ignore_metadata_errors: bool = False, **kwargs): + self.__dict__.update(kwargs) + + def to_dict(self) -> Dict[str, Any]: + """Converts CardData to a dict. + + Returns: + `dict`: CardData represented as a dictionary ready to be dumped to a YAML + block for inclusion in a README.md file. + """ + + data_dict = copy.deepcopy(self.__dict__) + self._to_dict(data_dict) + return _remove_none(data_dict) + + def _to_dict(self, data_dict): + """Use this method in child classes to alter the dict representation of the data. Alter the dict in-place. + + Args: + data_dict (`dict`): The raw dict representation of the card data. + """ + pass + + def to_yaml(self, line_break=None) -> str: + """Dumps CardData to a YAML block for inclusion in a README.md file. + + Args: + line_break (str, *optional*): + The line break to use when dumping to yaml. + + Returns: + `str`: CardData represented as a YAML block. 
+ """ + return yaml_dump(self.to_dict(), sort_keys=False, line_break=line_break).strip() + + def __repr__(self): + return repr(self.__dict__) + + def __str__(self): + return self.to_yaml() + + def get(self, key: str, default: Any = None) -> Any: + """Get value for a given metadata key.""" + return self.__dict__.get(key, default) + + def pop(self, key: str, default: Any = None) -> Any: + """Pop value for a given metadata key.""" + return self.__dict__.pop(key, default) + + def __getitem__(self, key: str) -> Any: + """Get value for a given metadata key.""" + return self.__dict__[key] + + def __setitem__(self, key: str, value: Any) -> None: + """Set value for a given metadata key.""" + self.__dict__[key] = value + + def __contains__(self, key: str) -> bool: + """Check if a given metadata key is set.""" + return key in self.__dict__ + + def __len__(self) -> int: + """Return the number of metadata keys set.""" + return len(self.__dict__) + + +class ModelCardData(CardData): + """Model Card Metadata that is used by Hugging Face Hub when included at the top of your README.md + + Args: + base_model (`str` or `List[str]`, *optional*): + The identifier of the base model from which the model derives. This is applicable for example if your model is a + fine-tune or adapter of an existing model. The value must be the ID of a model on the Hub (or a list of IDs + if your model derives from multiple models). Defaults to None. + datasets (`List[str]`, *optional*): + List of datasets that were used to train this model. Should be a dataset ID + found on https://hf.co/datasets. Defaults to None. + eval_results (`Union[List[EvalResult], EvalResult]`, *optional*): + List of `huggingface_hub.EvalResult` that define evaluation results of the model. If provided, + `model_name` is used to as a name on PapersWithCode's leaderboards. Defaults to `None`. + language (`Union[str, List[str]]`, *optional*): + Language of model's training data or metadata. 
It must be an ISO 639-1, 639-2 or + 639-3 code (two/three letters), or a special value like "code", "multilingual". Defaults to `None`. + library_name (`str`, *optional*): + Name of library used by this model. Example: keras or any library from + https://github.com/huggingface/huggingface.js/blob/main/packages/tasks/src/model-libraries.ts. + Defaults to None. + license (`str`, *optional*): + License of this model. Example: apache-2.0 or any license from + https://huggingface.co/docs/hub/repositories-licenses. Defaults to None. + license_name (`str`, *optional*): + Name of the license of this model. Defaults to None. To be used in conjunction with `license_link`. + Common licenses (Apache-2.0, MIT, CC-BY-SA-4.0) do not need a name. In that case, use `license` instead. + license_link (`str`, *optional*): + Link to the license of this model. Defaults to None. To be used in conjunction with `license_name`. + Common licenses (Apache-2.0, MIT, CC-BY-SA-4.0) do not need a link. In that case, use `license` instead. + metrics (`List[str]`, *optional*): + List of metrics used to evaluate this model. Should be a metric name that can be found + at https://hf.co/metrics. Example: 'accuracy'. Defaults to None. + model_name (`str`, *optional*): + A name for this model. It is used along with + `eval_results` to construct the `model-index` within the card's metadata. The name + you supply here is what will be used on PapersWithCode's leaderboards. If None is provided + then the repo name is used as a default. Defaults to None. + tags (`List[str]`, *optional*): + List of tags to add to your model that can be used when filtering on the Hugging + Face Hub. Defaults to None. + ignore_metadata_errors (`str`): + If True, errors while parsing the metadata section will be ignored. Some information might be lost during + the process. Use it at your own risk. + kwargs (`dict`, *optional*): + Additional metadata that will be added to the model card. Defaults to None. 
+ + Example: + ```python + >>> from huggingface_hub import ModelCardData + >>> card_data = ModelCardData( + ... language="en", + ... license="mit", + ... library_name="timm", + ... tags=['image-classification', 'resnet'], + ... ) + >>> card_data.to_dict() + {'language': 'en', 'license': 'mit', 'library_name': 'timm', 'tags': ['image-classification', 'resnet']} + + ``` + """ + + def __init__( + self, + *, + base_model: Optional[Union[str, List[str]]] = None, + datasets: Optional[List[str]] = None, + eval_results: Optional[List[EvalResult]] = None, + language: Optional[Union[str, List[str]]] = None, + library_name: Optional[str] = None, + license: Optional[str] = None, + license_name: Optional[str] = None, + license_link: Optional[str] = None, + metrics: Optional[List[str]] = None, + model_name: Optional[str] = None, + pipeline_tag: Optional[str] = None, + tags: Optional[List[str]] = None, + ignore_metadata_errors: bool = False, + **kwargs, + ): + self.base_model = base_model + self.datasets = datasets + self.eval_results = eval_results + self.language = language + self.library_name = library_name + self.license = license + self.license_name = license_name + self.license_link = license_link + self.metrics = metrics + self.model_name = model_name + self.pipeline_tag = pipeline_tag + self.tags = _to_unique_list(tags) + + model_index = kwargs.pop("model-index", None) + if model_index: + try: + model_name, eval_results = model_index_to_eval_results(model_index) + self.model_name = model_name + self.eval_results = eval_results + except (KeyError, TypeError) as error: + if ignore_metadata_errors: + logger.warning("Invalid model-index. Not loading eval results into CardData.") + else: + raise ValueError( + f"Invalid `model_index` in metadata cannot be parsed: {error.__class__} {error}. Pass" + " `ignore_metadata_errors=True` to ignore this error while loading a Model Card. Warning:" + " some information will be lost. Use it at your own risk." 
+ ) + + super().__init__(**kwargs) + + if self.eval_results: + if isinstance(self.eval_results, EvalResult): + self.eval_results = [self.eval_results] + if self.model_name is None: + raise ValueError("Passing `eval_results` requires `model_name` to be set.") + + def _to_dict(self, data_dict): + """Format the internal data dict. In this case, we convert eval results to a valid model index""" + if self.eval_results is not None: + data_dict["model-index"] = eval_results_to_model_index(self.model_name, self.eval_results) + del data_dict["eval_results"], data_dict["model_name"] + + +class DatasetCardData(CardData): + """Dataset Card Metadata that is used by Hugging Face Hub when included at the top of your README.md + + Args: + language (`List[str]`, *optional*): + Language of dataset's data or metadata. It must be an ISO 639-1, 639-2 or + 639-3 code (two/three letters), or a special value like "code", "multilingual". + license (`Union[str, List[str]]`, *optional*): + License(s) of this dataset. Example: apache-2.0 or any license from + https://huggingface.co/docs/hub/repositories-licenses. + annotations_creators (`Union[str, List[str]]`, *optional*): + How the annotations for the dataset were created. + Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'no-annotation', 'other'. + language_creators (`Union[str, List[str]]`, *optional*): + How the text-based data in the dataset was created. + Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'other' + multilinguality (`Union[str, List[str]]`, *optional*): + Whether the dataset is multilingual. + Options are: 'monolingual', 'multilingual', 'translation', 'other'. + size_categories (`Union[str, List[str]]`, *optional*): + The number of examples in the dataset. Options are: 'n<1K', '1K1T', and 'other'. + source_datasets (`List[str]]`, *optional*): + Indicates whether the dataset is an original dataset or extended from another existing dataset. 
+ Options are: 'original' and 'extended'. + task_categories (`Union[str, List[str]]`, *optional*): + What categories of task does the dataset support? + task_ids (`Union[str, List[str]]`, *optional*): + What specific tasks does the dataset support? + paperswithcode_id (`str`, *optional*): + ID of the dataset on PapersWithCode. + pretty_name (`str`, *optional*): + A more human-readable name for the dataset. (ex. "Cats vs. Dogs") + train_eval_index (`Dict`, *optional*): + A dictionary that describes the necessary spec for doing evaluation on the Hub. + If not provided, it will be gathered from the 'train-eval-index' key of the kwargs. + config_names (`Union[str, List[str]]`, *optional*): + A list of the available dataset configs for the dataset. + """ + + def __init__( + self, + *, + language: Optional[Union[str, List[str]]] = None, + license: Optional[Union[str, List[str]]] = None, + annotations_creators: Optional[Union[str, List[str]]] = None, + language_creators: Optional[Union[str, List[str]]] = None, + multilinguality: Optional[Union[str, List[str]]] = None, + size_categories: Optional[Union[str, List[str]]] = None, + source_datasets: Optional[List[str]] = None, + task_categories: Optional[Union[str, List[str]]] = None, + task_ids: Optional[Union[str, List[str]]] = None, + paperswithcode_id: Optional[str] = None, + pretty_name: Optional[str] = None, + train_eval_index: Optional[Dict] = None, + config_names: Optional[Union[str, List[str]]] = None, + ignore_metadata_errors: bool = False, + **kwargs, + ): + self.annotations_creators = annotations_creators + self.language_creators = language_creators + self.language = language + self.license = license + self.multilinguality = multilinguality + self.size_categories = size_categories + self.source_datasets = source_datasets + self.task_categories = task_categories + self.task_ids = task_ids + self.paperswithcode_id = paperswithcode_id + self.pretty_name = pretty_name + self.config_names = config_names + + # TODO - 
maybe handle this similarly to EvalResult? + self.train_eval_index = train_eval_index or kwargs.pop("train-eval-index", None) + super().__init__(**kwargs) + + def _to_dict(self, data_dict): + data_dict["train-eval-index"] = data_dict.pop("train_eval_index") + + +class SpaceCardData(CardData): + """Space Card Metadata that is used by Hugging Face Hub when included at the top of your README.md + + To get an exhaustive reference of Spaces configuration, please visit https://huggingface.co/docs/hub/spaces-config-reference#spaces-configuration-reference. + + Args: + title (`str`, *optional*) + Title of the Space. + sdk (`str`, *optional*) + SDK of the Space (one of `gradio`, `streamlit`, `docker`, or `static`). + sdk_version (`str`, *optional*) + Version of the used SDK (if Gradio/Streamlit sdk). + python_version (`str`, *optional*) + Python version used in the Space (if Gradio/Streamlit sdk). + app_file (`str`, *optional*) + Path to your main application file (which contains either gradio or streamlit Python code, or static html code). + Path is relative to the root of the repository. + app_port (`str`, *optional*) + Port on which your application is running. Used only if sdk is `docker`. + license (`str`, *optional*) + License of this model. Example: apache-2.0 or any license from + https://huggingface.co/docs/hub/repositories-licenses. + duplicated_from (`str`, *optional*) + ID of the original Space if this is a duplicated Space. + models (List[`str`], *optional*) + List of models related to this Space. Should be a dataset ID found on https://hf.co/models. + datasets (`List[str]`, *optional*) + List of datasets related to this Space. Should be a dataset ID found on https://hf.co/datasets. + tags (`List[str]`, *optional*) + List of tags to add to your Space that can be used when filtering on the Hub. + ignore_metadata_errors (`str`): + If True, errors while parsing the metadata section will be ignored. Some information might be lost during + the process. 
Use it at your own risk. + kwargs (`dict`, *optional*): + Additional metadata that will be added to the space card. + + Example: + ```python + >>> from huggingface_hub import SpaceCardData + >>> card_data = SpaceCardData( + ... title="Dreambooth Training", + ... license="mit", + ... sdk="gradio", + ... duplicated_from="multimodalart/dreambooth-training" + ... ) + >>> card_data.to_dict() + {'title': 'Dreambooth Training', 'sdk': 'gradio', 'license': 'mit', 'duplicated_from': 'multimodalart/dreambooth-training'} + ``` + """ + + def __init__( + self, + *, + title: Optional[str] = None, + sdk: Optional[str] = None, + sdk_version: Optional[str] = None, + python_version: Optional[str] = None, + app_file: Optional[str] = None, + app_port: Optional[int] = None, + license: Optional[str] = None, + duplicated_from: Optional[str] = None, + models: Optional[List[str]] = None, + datasets: Optional[List[str]] = None, + tags: Optional[List[str]] = None, + ignore_metadata_errors: bool = False, + **kwargs, + ): + self.title = title + self.sdk = sdk + self.sdk_version = sdk_version + self.python_version = python_version + self.app_file = app_file + self.app_port = app_port + self.license = license + self.duplicated_from = duplicated_from + self.models = models + self.datasets = datasets + self.tags = _to_unique_list(tags) + super().__init__(**kwargs) + + +def model_index_to_eval_results(model_index: List[Dict[str, Any]]) -> Tuple[str, List[EvalResult]]: + """Takes in a model index and returns the model name and a list of `huggingface_hub.EvalResult` objects. + + A detailed spec of the model index can be found here: + https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 + + Args: + model_index (`List[Dict[str, Any]]`): + A model index data structure, likely coming from a README.md file on the + Hugging Face Hub. + + Returns: + model_name (`str`): + The name of the model as found in the model index. 
This is used as the + identifier for the model on leaderboards like PapersWithCode. + eval_results (`List[EvalResult]`): + A list of `huggingface_hub.EvalResult` objects containing the metrics + reported in the provided model_index. + + Example: + ```python + >>> from huggingface_hub.repocard_data import model_index_to_eval_results + >>> # Define a minimal model index + >>> model_index = [ + ... { + ... "name": "my-cool-model", + ... "results": [ + ... { + ... "task": { + ... "type": "image-classification" + ... }, + ... "dataset": { + ... "type": "beans", + ... "name": "Beans" + ... }, + ... "metrics": [ + ... { + ... "type": "accuracy", + ... "value": 0.9 + ... } + ... ] + ... } + ... ] + ... } + ... ] + >>> model_name, eval_results = model_index_to_eval_results(model_index) + >>> model_name + 'my-cool-model' + >>> eval_results[0].task_type + 'image-classification' + >>> eval_results[0].metric_type + 'accuracy' + + ``` + """ + + eval_results = [] + for elem in model_index: + name = elem["name"] + results = elem["results"] + for result in results: + task_type = result["task"]["type"] + task_name = result["task"].get("name") + dataset_type = result["dataset"]["type"] + dataset_name = result["dataset"]["name"] + dataset_config = result["dataset"].get("config") + dataset_split = result["dataset"].get("split") + dataset_revision = result["dataset"].get("revision") + dataset_args = result["dataset"].get("args") + source_name = result.get("source", {}).get("name") + source_url = result.get("source", {}).get("url") + + for metric in result["metrics"]: + metric_type = metric["type"] + metric_value = metric["value"] + metric_name = metric.get("name") + metric_args = metric.get("args") + metric_config = metric.get("config") + verified = metric.get("verified") + verify_token = metric.get("verifyToken") + + eval_result = EvalResult( + task_type=task_type, # Required + dataset_type=dataset_type, # Required + dataset_name=dataset_name, # Required + metric_type=metric_type, # 
Required + metric_value=metric_value, # Required + task_name=task_name, + dataset_config=dataset_config, + dataset_split=dataset_split, + dataset_revision=dataset_revision, + dataset_args=dataset_args, + metric_name=metric_name, + metric_args=metric_args, + metric_config=metric_config, + verified=verified, + verify_token=verify_token, + source_name=source_name, + source_url=source_url, + ) + eval_results.append(eval_result) + return name, eval_results + + +def _remove_none(obj): + """ + Recursively remove `None` values from a dict. Borrowed from: https://stackoverflow.com/a/20558778 + """ + if isinstance(obj, (list, tuple, set)): + return type(obj)(_remove_none(x) for x in obj if x is not None) + elif isinstance(obj, dict): + return type(obj)((_remove_none(k), _remove_none(v)) for k, v in obj.items() if k is not None and v is not None) + else: + return obj + + +def eval_results_to_model_index(model_name: str, eval_results: List[EvalResult]) -> List[Dict[str, Any]]: + """Takes in given model name and list of `huggingface_hub.EvalResult` and returns a + valid model-index that will be compatible with the format expected by the + Hugging Face Hub. + + Args: + model_name (`str`): + Name of the model (ex. "my-cool-model"). This is used as the identifier + for the model on leaderboards like PapersWithCode. + eval_results (`List[EvalResult]`): + List of `huggingface_hub.EvalResult` objects containing the metrics to be + reported in the model-index. + + Returns: + model_index (`List[Dict[str, Any]]`): The eval_results converted to a model-index. + + Example: + ```python + >>> from huggingface_hub.repocard_data import eval_results_to_model_index, EvalResult + >>> # Define minimal eval_results + >>> eval_results = [ + ... EvalResult( + ... task_type="image-classification", # Required + ... dataset_type="beans", # Required + ... dataset_name="Beans", # Required + ... metric_type="accuracy", # Required + ... metric_value=0.9, # Required + ... ) + ... 
] + >>> eval_results_to_model_index("my-cool-model", eval_results) + [{'name': 'my-cool-model', 'results': [{'task': {'type': 'image-classification'}, 'dataset': {'name': 'Beans', 'type': 'beans'}, 'metrics': [{'type': 'accuracy', 'value': 0.9}]}]}] + + ``` + """ + + # Metrics are reported on a unique task-and-dataset basis. + # Here, we make a map of those pairs and the associated EvalResults. + task_and_ds_types_map: Dict[Any, List[EvalResult]] = defaultdict(list) + for eval_result in eval_results: + task_and_ds_types_map[eval_result.unique_identifier].append(eval_result) + + # Use the map from above to generate the model index data. + model_index_data = [] + for results in task_and_ds_types_map.values(): + # All items from `results` share same metadata + sample_result = results[0] + data = { + "task": { + "type": sample_result.task_type, + "name": sample_result.task_name, + }, + "dataset": { + "name": sample_result.dataset_name, + "type": sample_result.dataset_type, + "config": sample_result.dataset_config, + "split": sample_result.dataset_split, + "revision": sample_result.dataset_revision, + "args": sample_result.dataset_args, + }, + "metrics": [ + { + "type": result.metric_type, + "value": result.metric_value, + "name": result.metric_name, + "config": result.metric_config, + "args": result.metric_args, + "verified": result.verified, + "verifyToken": result.verify_token, + } + for result in results + ], + } + if sample_result.source_url is not None: + source = { + "url": sample_result.source_url, + } + if sample_result.source_name is not None: + source["name"] = sample_result.source_name + data["source"] = source + model_index_data.append(data) + + # TODO - Check if there cases where this list is longer than one? + # Finally, the model index itself is list of dicts. 
+ model_index = [ + { + "name": model_name, + "results": model_index_data, + } + ] + return _remove_none(model_index) + + +def _to_unique_list(tags: Optional[List[str]]) -> Optional[List[str]]: + if tags is None: + return tags + unique_tags = [] # make tags unique + keep order explicitly + for tag in tags: + if tag not in unique_tags: + unique_tags.append(tag) + return unique_tags diff --git a/huggingface_hub/repository.py b/huggingface_hub/repository.py new file mode 100644 index 0000000000000000000000000000000000000000..af1ab72fb458340f3fc211f0c5ef577b6471fda1 --- /dev/null +++ b/huggingface_hub/repository.py @@ -0,0 +1,1477 @@ +import atexit +import os +import re +import subprocess +import threading +import time +from contextlib import contextmanager +from pathlib import Path +from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypedDict, Union +from urllib.parse import urlparse + +from huggingface_hub import constants +from huggingface_hub.repocard import metadata_load, metadata_save + +from .hf_api import HfApi, repo_type_and_id_from_hf_id +from .lfs import LFS_MULTIPART_UPLOAD_COMMAND +from .utils import ( + SoftTemporaryDirectory, + get_token, + logging, + run_subprocess, + tqdm, + validate_hf_hub_args, +) +from .utils._deprecation import _deprecate_method + + +logger = logging.get_logger(__name__) + + +class CommandInProgress: + """ + Utility to follow commands launched asynchronously. + """ + + def __init__( + self, + title: str, + is_done_method: Callable, + status_method: Callable, + process: subprocess.Popen, + post_method: Optional[Callable] = None, + ): + self.title = title + self._is_done = is_done_method + self._status = status_method + self._process = process + self._stderr = "" + self._stdout = "" + self._post_method = post_method + + @property + def is_done(self) -> bool: + """ + Whether the process is done. 
+ """ + result = self._is_done() + + if result and self._post_method is not None: + self._post_method() + self._post_method = None + + return result + + @property + def status(self) -> int: + """ + The exit code/status of the current action. Will return `0` if the + command has completed successfully, and a number between 1 and 255 if + the process errored-out. + + Will return -1 if the command is still ongoing. + """ + return self._status() + + @property + def failed(self) -> bool: + """ + Whether the process errored-out. + """ + return self.status > 0 + + @property + def stderr(self) -> str: + """ + The current output message on the standard error. + """ + if self._process.stderr is not None: + self._stderr += self._process.stderr.read() + return self._stderr + + @property + def stdout(self) -> str: + """ + The current output message on the standard output. + """ + if self._process.stdout is not None: + self._stdout += self._process.stdout.read() + return self._stdout + + def __repr__(self): + status = self.status + + if status == -1: + status = "running" + + return ( + f"[{self.title} command, status code: {status}," + f" {'in progress.' if not self.is_done else 'finished.'} PID:" + f" {self._process.pid}]" + ) + + +def is_git_repo(folder: Union[str, Path]) -> bool: + """ + Check if the folder is the root or part of a git repository + + Args: + folder (`str`): + The folder in which to run the command. + + Returns: + `bool`: `True` if the repository is part of a repository, `False` + otherwise. + """ + folder_exists = os.path.exists(os.path.join(folder, ".git")) + git_branch = subprocess.run("git branch".split(), cwd=folder, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + return folder_exists and git_branch.returncode == 0 + + +def is_local_clone(folder: Union[str, Path], remote_url: str) -> bool: + """ + Check if the folder is a local clone of the remote_url + + Args: + folder (`str` or `Path`): + The folder in which to run the command. 
+ remote_url (`str`): + The url of a git repository. + + Returns: + `bool`: `True` if the repository is a local clone of the remote + repository specified, `False` otherwise. + """ + if not is_git_repo(folder): + return False + + remotes = run_subprocess("git remote -v", folder).stdout + + # Remove token for the test with remotes. + remote_url = re.sub(r"https://.*@", "https://", remote_url) + remotes = [re.sub(r"https://.*@", "https://", remote) for remote in remotes.split()] + return remote_url in remotes + + +def is_tracked_with_lfs(filename: Union[str, Path]) -> bool: + """ + Check if the file passed is tracked with git-lfs. + + Args: + filename (`str` or `Path`): + The filename to check. + + Returns: + `bool`: `True` if the file passed is tracked with git-lfs, `False` + otherwise. + """ + folder = Path(filename).parent + filename = Path(filename).name + + try: + p = run_subprocess("git check-attr -a".split() + [filename], folder) + attributes = p.stdout.strip() + except subprocess.CalledProcessError as exc: + if not is_git_repo(folder): + return False + else: + raise OSError(exc.stderr) + + if len(attributes) == 0: + return False + + found_lfs_tag = {"diff": False, "merge": False, "filter": False} + + for attribute in attributes.split("\n"): + for tag in found_lfs_tag.keys(): + if tag in attribute and "lfs" in attribute: + found_lfs_tag[tag] = True + + return all(found_lfs_tag.values()) + + +def is_git_ignored(filename: Union[str, Path]) -> bool: + """ + Check if file is git-ignored. Supports nested .gitignore files. + + Args: + filename (`str` or `Path`): + The filename to check. + + Returns: + `bool`: `True` if the file passed is ignored by `git`, `False` + otherwise. 
+ """ + folder = Path(filename).parent + filename = Path(filename).name + + try: + p = run_subprocess("git check-ignore".split() + [filename], folder, check=False) + # Will return exit code 1 if not gitignored + is_ignored = not bool(p.returncode) + except subprocess.CalledProcessError as exc: + raise OSError(exc.stderr) + + return is_ignored + + +def is_binary_file(filename: Union[str, Path]) -> bool: + """ + Check if file is a binary file. + + Args: + filename (`str` or `Path`): + The filename to check. + + Returns: + `bool`: `True` if the file passed is a binary file, `False` otherwise. + """ + try: + with open(filename, "rb") as f: + content = f.read(10 * (1024**2)) # Read a maximum of 10MB + + # Code sample taken from the following stack overflow thread + # https://stackoverflow.com/questions/898669/how-can-i-detect-if-a-file-is-binary-non-text-in-python/7392391#7392391 + text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F}) + return bool(content.translate(None, text_chars)) + except UnicodeDecodeError: + return True + + +def files_to_be_staged(pattern: str = ".", folder: Union[str, Path, None] = None) -> List[str]: + """ + Returns a list of filenames that are to be staged. + + Args: + pattern (`str` or `Path`): + The pattern of filenames to check. Put `.` to get all files. + folder (`str` or `Path`): + The folder in which to run the command. + + Returns: + `List[str]`: List of files that are to be staged. + """ + try: + p = run_subprocess("git ls-files --exclude-standard -mo".split() + [pattern], folder) + if len(p.stdout.strip()): + files = p.stdout.strip().split("\n") + else: + files = [] + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + return files + + +def is_tracked_upstream(folder: Union[str, Path]) -> bool: + """ + Check if the current checked-out branch is tracked upstream. + + Args: + folder (`str` or `Path`): + The folder in which to run the command. 
+ + Returns: + `bool`: `True` if the current checked-out branch is tracked upstream, + `False` otherwise. + """ + try: + run_subprocess("git rev-parse --symbolic-full-name --abbrev-ref @{u}", folder) + return True + except subprocess.CalledProcessError as exc: + if "HEAD" in exc.stderr: + raise OSError("No branch checked out") + + return False + + +def commits_to_push(folder: Union[str, Path], upstream: Optional[str] = None) -> int: + """ + Check the number of commits that would be pushed upstream + + Args: + folder (`str` or `Path`): + The folder in which to run the command. + upstream (`str`, *optional*): + The name of the upstream repository with which the comparison should be + made. + + Returns: + `int`: Number of commits that would be pushed upstream were a `git + push` to proceed. + """ + try: + result = run_subprocess(f"git cherry -v {upstream or ''}", folder) + return len(result.stdout.split("\n")) - 1 + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + +class PbarT(TypedDict): + # Used to store an opened progress bar in `_lfs_log_progress` + bar: tqdm + past_bytes: int + + +@contextmanager +def _lfs_log_progress(): + """ + This is a context manager that will log the Git LFS progress of cleaning, + smudging, pulling and pushing. + """ + + if logger.getEffectiveLevel() >= logging.ERROR: + try: + yield + except Exception: + pass + return + + def output_progress(stopping_event: threading.Event): + """ + To be launched as a separate thread with an event meaning it should stop + the tail. + """ + # Key is tuple(state, filename), value is a dict(tqdm bar and a previous value) + pbars: Dict[Tuple[str, str], PbarT] = {} + + def close_pbars(): + for pbar in pbars.values(): + pbar["bar"].update(pbar["bar"].total - pbar["past_bytes"]) + pbar["bar"].refresh() + pbar["bar"].close() + + def tail_file(filename) -> Iterator[str]: + """ + Creates a generator to be iterated through, which will return each + line one by one. 
Will stop tailing the file if the stopping_event is + set. + """ + with open(filename, "r") as file: + current_line = "" + while True: + if stopping_event.is_set(): + close_pbars() + break + + line_bit = file.readline() + if line_bit is not None and not len(line_bit.strip()) == 0: + current_line += line_bit + if current_line.endswith("\n"): + yield current_line + current_line = "" + else: + time.sleep(1) + + # If the file isn't created yet, wait for a few seconds before trying again. + # Can be interrupted with the stopping_event. + while not os.path.exists(os.environ["GIT_LFS_PROGRESS"]): + if stopping_event.is_set(): + close_pbars() + return + + time.sleep(2) + + for line in tail_file(os.environ["GIT_LFS_PROGRESS"]): + try: + state, file_progress, byte_progress, filename = line.split() + except ValueError as error: + # Try/except to ease debugging. See https://github.com/huggingface/huggingface_hub/issues/1373. + raise ValueError(f"Cannot unpack LFS progress line:\n{line}") from error + description = f"{state.capitalize()} file {filename}" + + current_bytes, total_bytes = byte_progress.split("/") + current_bytes_int = int(current_bytes) + total_bytes_int = int(total_bytes) + + pbar = pbars.get((state, filename)) + if pbar is None: + # Initialize progress bar + pbars[(state, filename)] = { + "bar": tqdm( + desc=description, + initial=current_bytes_int, + total=total_bytes_int, + unit="B", + unit_scale=True, + unit_divisor=1024, + name="huggingface_hub.lfs_upload", + ), + "past_bytes": int(current_bytes), + } + else: + # Update progress bar + pbar["bar"].update(current_bytes_int - pbar["past_bytes"]) + pbar["past_bytes"] = current_bytes_int + + current_lfs_progress_value = os.environ.get("GIT_LFS_PROGRESS", "") + + with SoftTemporaryDirectory() as tmpdir: + os.environ["GIT_LFS_PROGRESS"] = os.path.join(tmpdir, "lfs_progress") + logger.debug(f"Following progress in {os.environ['GIT_LFS_PROGRESS']}") + + exit_event = threading.Event() + x = 
threading.Thread(target=output_progress, args=(exit_event,), daemon=True) + x.start() + + try: + yield + finally: + exit_event.set() + x.join() + + os.environ["GIT_LFS_PROGRESS"] = current_lfs_progress_value + + +class Repository: + """ + Helper class to wrap the git and git-lfs commands. + + The aim is to facilitate interacting with huggingface.co hosted model or + dataset repos, though not a lot here (if any) is actually specific to + huggingface.co. + + + + [`Repository`] is deprecated in favor of the http-based alternatives implemented in + [`HfApi`]. Given its large adoption in legacy code, the complete removal of + [`Repository`] will only happen in release `v1.0`. For more details, please read + https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http. + + + """ + + command_queue: List[CommandInProgress] + + @validate_hf_hub_args + @_deprecate_method( + version="1.0", + message=( + "Please prefer the http-based alternatives instead. Given its large adoption in legacy code, the complete" + " removal is only planned on next major release.\nFor more details, please read" + " https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http." + ), + ) + def __init__( + self, + local_dir: Union[str, Path], + clone_from: Optional[str] = None, + repo_type: Optional[str] = None, + token: Union[bool, str] = True, + git_user: Optional[str] = None, + git_email: Optional[str] = None, + revision: Optional[str] = None, + skip_lfs_files: bool = False, + client: Optional[HfApi] = None, + ): + """ + Instantiate a local clone of a git repo. + + If `clone_from` is set, the repo will be cloned from an existing remote repository. + If the remote repo does not exist, a `EnvironmentError` exception will be thrown. + Please create the remote repo first using [`create_repo`]. + + `Repository` uses the local git credentials by default. If explicitly set, the `token` + or the `git_user`/`git_email` pair will be used instead. + + Args: + local_dir (`str` or `Path`): + path (e.g. 
`'my_trained_model/'`) to the local directory, where + the `Repository` will be initialized. + clone_from (`str`, *optional*): + Either a repository url or `repo_id`. + Example: + - `"https://huggingface.co/philschmid/playground-tests"` + - `"philschmid/playground-tests"` + repo_type (`str`, *optional*): + To set when cloning a repo from a repo_id. Default is model. + token (`bool` or `str`, *optional*): + A valid authentication token (see https://huggingface.co/settings/token). + If `None` or `True` and machine is logged in (through `huggingface-cli login` + or [`~huggingface_hub.login`]), token will be retrieved from the cache. + If `False`, token is not sent in the request header. + git_user (`str`, *optional*): + will override the `git config user.name` for committing and + pushing files to the hub. + git_email (`str`, *optional*): + will override the `git config user.email` for committing and + pushing files to the hub. + revision (`str`, *optional*): + Revision to checkout after initializing the repository. If the + revision doesn't exist, a branch will be created with that + revision name from the default branch's current HEAD. + skip_lfs_files (`bool`, *optional*, defaults to `False`): + whether to skip git-LFS files or not. + client (`HfApi`, *optional*): + Instance of [`HfApi`] to use when calling the HF Hub API. A new + instance will be created if this is left to `None`. + + Raises: + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If the remote repository set in `clone_from` does not exist. 
+ """ + if isinstance(local_dir, Path): + local_dir = str(local_dir) + os.makedirs(local_dir, exist_ok=True) + self.local_dir = os.path.join(os.getcwd(), local_dir) + self._repo_type = repo_type + self.command_queue = [] + self.skip_lfs_files = skip_lfs_files + self.client = client if client is not None else HfApi() + + self.check_git_versions() + + if isinstance(token, str): + self.huggingface_token: Optional[str] = token + elif token is False: + self.huggingface_token = None + else: + # if `True` -> explicit use of the cached token + # if `None` -> implicit use of the cached token + self.huggingface_token = get_token() + + if clone_from is not None: + self.clone_from(repo_url=clone_from) + else: + if is_git_repo(self.local_dir): + logger.debug("[Repository] is a valid git repo") + else: + raise ValueError("If not specifying `clone_from`, you need to pass Repository a valid git clone.") + + if self.huggingface_token is not None and (git_email is None or git_user is None): + user = self.client.whoami(self.huggingface_token) + + if git_email is None: + git_email = user.get("email") + + if git_user is None: + git_user = user.get("fullname") + + if git_user is not None or git_email is not None: + self.git_config_username_and_email(git_user, git_email) + + self.lfs_enable_largefiles() + self.git_credential_helper_store() + + if revision is not None: + self.git_checkout(revision, create_branch_ok=True) + + # This ensures that all commands exit before exiting the Python runtime. + # This will ensure all pushes register on the hub, even if other errors happen in subsequent operations. + atexit.register(self.wait_for_commands) + + @property + def current_branch(self) -> str: + """ + Returns the current checked out branch. + + Returns: + `str`: Current checked out branch. 
+ """ + try: + result = run_subprocess("git rev-parse --abbrev-ref HEAD", self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + return result + + def check_git_versions(self): + """ + Checks that `git` and `git-lfs` can be run. + + Raises: + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If `git` or `git-lfs` are not installed. + """ + try: + git_version = run_subprocess("git --version", self.local_dir).stdout.strip() + except FileNotFoundError: + raise EnvironmentError("Looks like you do not have git installed, please install.") + + try: + lfs_version = run_subprocess("git-lfs --version", self.local_dir).stdout.strip() + except FileNotFoundError: + raise EnvironmentError( + "Looks like you do not have git-lfs installed, please install." + " You can install from https://git-lfs.github.com/." + " Then run `git lfs install` (you only have to do this once)." + ) + logger.info(git_version + "\n" + lfs_version) + + @validate_hf_hub_args + def clone_from(self, repo_url: str, token: Union[bool, str, None] = None): + """ + Clone from a remote. If the folder already exists, will try to clone the + repository within it. + + If this folder is a git repository with linked history, will try to + update the repository. + + Args: + repo_url (`str`): + The URL from which to clone the repository + token (`Union[str, bool]`, *optional*): + Whether to use the authentication token. It can be: + - a string which is the token itself + - `False`, which would not use the authentication token + - `True`, which would fetch the authentication token from the + local folder and use it (you should be logged in for this to + work). + - `None`, which would retrieve the value of + `self.huggingface_token`. + ++ + Raises the following error: + + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if an organization token (starts with "api_org") is passed. 
Use must use + your own personal access token (see https://hf.co/settings/tokens). + + - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + if you are trying to clone the repository in a non-empty folder, or if the + `git` operations raise errors. + + + """ + token = ( + token # str -> use it + if isinstance(token, str) + else ( + None # `False` -> explicit no token + if token is False + else self.huggingface_token # `None` or `True` -> use default + ) + ) + if token is not None and token.startswith("api_org"): + raise ValueError( + "You must use your personal access token, not an Organization token" + " (see https://hf.co/settings/tokens)." + ) + + hub_url = self.client.endpoint + if hub_url in repo_url or ("http" not in repo_url and len(repo_url.split("/")) <= 2): + repo_type, namespace, repo_name = repo_type_and_id_from_hf_id(repo_url, hub_url=hub_url) + repo_id = f"{namespace}/{repo_name}" if namespace is not None else repo_name + + if repo_type is not None: + self._repo_type = repo_type + + repo_url = hub_url + "/" + + if self._repo_type in constants.REPO_TYPES_URL_PREFIXES: + repo_url += constants.REPO_TYPES_URL_PREFIXES[self._repo_type] + + if token is not None: + # Add token in git url when provided + scheme = urlparse(repo_url).scheme + repo_url = repo_url.replace(f"{scheme}://", f"{scheme}://user:{token}@") + + repo_url += repo_id + + # For error messages, it's cleaner to show the repo url without the token. 
+ clean_repo_url = re.sub(r"(https?)://.*@", r"\1://", repo_url) + try: + run_subprocess("git lfs install", self.local_dir) + + # checks if repository is initialized in a empty repository or in one with files + if len(os.listdir(self.local_dir)) == 0: + logger.warning(f"Cloning {clean_repo_url} into local empty directory.") + + with _lfs_log_progress(): + env = os.environ.copy() + + if self.skip_lfs_files: + env.update({"GIT_LFS_SKIP_SMUDGE": "1"}) + + run_subprocess( + # 'git lfs clone' is deprecated (will display a warning in the terminal) + # but we still use it as it provides a nicer UX when downloading large + # files (shows progress). + f"{'git clone' if self.skip_lfs_files else 'git lfs clone'} {repo_url} .", + self.local_dir, + env=env, + ) + else: + # Check if the folder is the root of a git repository + if not is_git_repo(self.local_dir): + raise EnvironmentError( + "Tried to clone a repository in a non-empty folder that isn't" + f" a git repository ('{self.local_dir}'). If you really want to" + f" do this, do it manually:\n cd {self.local_dir} && git init" + " && git remote add origin && git pull origin main\n or clone" + " repo to a new folder and move your existing files there" + " afterwards." + ) + + if is_local_clone(self.local_dir, repo_url): + logger.warning( + f"{self.local_dir} is already a clone of {clean_repo_url}." + " Make sure you pull the latest changes with" + " `repo.git_pull()`." + ) + else: + output = run_subprocess("git remote get-url origin", self.local_dir, check=False) + + error_msg = ( + f"Tried to clone {clean_repo_url} in an unrelated git" + " repository.\nIf you believe this is an error, please add" + f" a remote with the following URL: {clean_repo_url}." 
+ ) + if output.returncode == 0: + clean_local_remote_url = re.sub(r"https://.*@", "https://", output.stdout) + error_msg += f"\nLocal path has its origin defined as: {clean_local_remote_url}" + raise EnvironmentError(error_msg) + + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_config_username_and_email(self, git_user: Optional[str] = None, git_email: Optional[str] = None): + """ + Sets git username and email (only in the current repo). + + Args: + git_user (`str`, *optional*): + The username to register through `git`. + git_email (`str`, *optional*): + The email to register through `git`. + """ + try: + if git_user is not None: + run_subprocess("git config user.name".split() + [git_user], self.local_dir) + + if git_email is not None: + run_subprocess(f"git config user.email {git_email}".split(), self.local_dir) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_credential_helper_store(self): + """ + Sets the git credential helper to `store` + """ + try: + run_subprocess("git config credential.helper store", self.local_dir) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_head_hash(self) -> str: + """ + Get commit sha on top of HEAD. + + Returns: + `str`: The current checked out commit SHA. + """ + try: + p = run_subprocess("git rev-parse HEAD", self.local_dir) + return p.stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_remote_url(self) -> str: + """ + Get URL to origin remote. + + Returns: + `str`: The URL of the `origin` remote. + """ + try: + p = run_subprocess("git config --get remote.origin.url", self.local_dir) + url = p.stdout.strip() + # Strip basic auth info. 
+ return re.sub(r"https://.*@", "https://", url) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_head_commit_url(self) -> str: + """ + Get URL to last commit on HEAD. We assume it's been pushed, and the url + scheme is the same one as for GitHub or HuggingFace. + + Returns: + `str`: The URL to the current checked-out commit. + """ + sha = self.git_head_hash() + url = self.git_remote_url() + if url.endswith("/"): + url = url[:-1] + return f"{url}/commit/{sha}" + + def list_deleted_files(self) -> List[str]: + """ + Returns a list of the files that are deleted in the working directory or + index. + + Returns: + `List[str]`: A list of files that have been deleted in the working + directory or index. + """ + try: + git_status = run_subprocess("git status -s", self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + if len(git_status) == 0: + return [] + + # Receives a status like the following + # D .gitignore + # D new_file.json + # AD new_file1.json + # ?? new_file2.json + # ?? new_file4.json + + # Strip each line of whitespaces + modified_files_statuses = [status.strip() for status in git_status.split("\n")] + + # Only keep files that are deleted using the D prefix + deleted_files_statuses = [status for status in modified_files_statuses if "D" in status.split()[0]] + + # Remove the D prefix and strip to keep only the relevant filename + deleted_files = [status.split()[-1].strip() for status in deleted_files_statuses] + + return deleted_files + + def lfs_track(self, patterns: Union[str, List[str]], filename: bool = False): + """ + Tell git-lfs to track files according to a pattern. + + Setting the `filename` argument to `True` will treat the arguments as + literal filenames, not as patterns. Any special glob characters in the + filename will be escaped when writing to the `.gitattributes` file. 
+ + Args: + patterns (`Union[str, List[str]]`): + The pattern, or list of patterns, to track with git-lfs. + filename (`bool`, *optional*, defaults to `False`): + Whether to use the patterns as literal filenames. + """ + if isinstance(patterns, str): + patterns = [patterns] + try: + for pattern in patterns: + run_subprocess( + f"git lfs track {'--filename' if filename else ''} {pattern}", + self.local_dir, + ) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def lfs_untrack(self, patterns: Union[str, List[str]]): + """ + Tell git-lfs to untrack those files. + + Args: + patterns (`Union[str, List[str]]`): + The pattern, or list of patterns, to untrack with git-lfs. + """ + if isinstance(patterns, str): + patterns = [patterns] + try: + for pattern in patterns: + run_subprocess("git lfs untrack".split() + [pattern], self.local_dir) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def lfs_enable_largefiles(self): + """ + HF-specific. This enables upload support of files >5GB. + """ + try: + lfs_config = "git config lfs.customtransfer.multipart" + run_subprocess(f"{lfs_config}.path huggingface-cli", self.local_dir) + run_subprocess( + f"{lfs_config}.args {LFS_MULTIPART_UPLOAD_COMMAND}", + self.local_dir, + ) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def auto_track_binary_files(self, pattern: str = ".") -> List[str]: + """ + Automatically track binary files with git-lfs. + + Args: + pattern (`str`, *optional*, defaults to "."): + The pattern with which to track files that are binary. 
+ + Returns: + `List[str]`: List of filenames that are now tracked due to being + binary files + """ + files_to_be_tracked_with_lfs = [] + + deleted_files = self.list_deleted_files() + + for filename in files_to_be_staged(pattern, folder=self.local_dir): + if filename in deleted_files: + continue + + path_to_file = os.path.join(os.getcwd(), self.local_dir, filename) + + if not (is_tracked_with_lfs(path_to_file) or is_git_ignored(path_to_file)): + size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024) + + if size_in_mb >= 10: + logger.warning( + "Parsing a large file to check if binary or not. Tracking large" + " files using `repository.auto_track_large_files` is" + " recommended so as to not load the full file in memory." + ) + + is_binary = is_binary_file(path_to_file) + + if is_binary: + self.lfs_track(filename) + files_to_be_tracked_with_lfs.append(filename) + + # Cleanup the .gitattributes if files were deleted + self.lfs_untrack(deleted_files) + + return files_to_be_tracked_with_lfs + + def auto_track_large_files(self, pattern: str = ".") -> List[str]: + """ + Automatically track large files (files that weigh more than 10MBs) with + git-lfs. + + Args: + pattern (`str`, *optional*, defaults to "."): + The pattern with which to track files that are above 10MBs. + + Returns: + `List[str]`: List of filenames that are now tracked due to their + size. 
+ """ + files_to_be_tracked_with_lfs = [] + + deleted_files = self.list_deleted_files() + + for filename in files_to_be_staged(pattern, folder=self.local_dir): + if filename in deleted_files: + continue + + path_to_file = os.path.join(os.getcwd(), self.local_dir, filename) + size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024) + + if size_in_mb >= 10 and not is_tracked_with_lfs(path_to_file) and not is_git_ignored(path_to_file): + self.lfs_track(filename) + files_to_be_tracked_with_lfs.append(filename) + + # Cleanup the .gitattributes if files were deleted + self.lfs_untrack(deleted_files) + + return files_to_be_tracked_with_lfs + + def lfs_prune(self, recent=False): + """ + git lfs prune + + Args: + recent (`bool`, *optional*, defaults to `False`): + Whether to prune files even if they were referenced by recent + commits. See the following + [link](https://github.com/git-lfs/git-lfs/blob/f3d43f0428a84fc4f1e5405b76b5a73ec2437e65/docs/man/git-lfs-prune.1.ronn#recent-files) + for more information. + """ + try: + with _lfs_log_progress(): + result = run_subprocess(f"git lfs prune {'--recent' if recent else ''}", self.local_dir) + logger.info(result.stdout) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_pull(self, rebase: bool = False, lfs: bool = False): + """ + git pull + + Args: + rebase (`bool`, *optional*, defaults to `False`): + Whether to rebase the current branch on top of the upstream + branch after fetching. + lfs (`bool`, *optional*, defaults to `False`): + Whether to fetch the LFS files too. This option only changes the + behavior when a repository was cloned without fetching the LFS + files; calling `repo.git_pull(lfs=True)` will then fetch the LFS + file from the remote repository. 
+ """ + command = "git pull" if not lfs else "git lfs pull" + if rebase: + command += " --rebase" + try: + with _lfs_log_progress(): + result = run_subprocess(command, self.local_dir) + logger.info(result.stdout) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_add(self, pattern: str = ".", auto_lfs_track: bool = False): + """ + git add + + Setting the `auto_lfs_track` parameter to `True` will automatically + track files that are larger than 10MB with `git-lfs`. + + Args: + pattern (`str`, *optional*, defaults to "."): + The pattern with which to add files to staging. + auto_lfs_track (`bool`, *optional*, defaults to `False`): + Whether to automatically track large and binary files with + git-lfs. Any file over 10MB in size, or in binary format, will + be automatically tracked. + """ + if auto_lfs_track: + # Track files according to their size (>=10MB) + tracked_files = self.auto_track_large_files(pattern) + + # Read the remaining files and track them if they're binary + tracked_files.extend(self.auto_track_binary_files(pattern)) + + if tracked_files: + logger.warning( + f"Adding files tracked by Git LFS: {tracked_files}. This may take a" + " bit of time if the files are large." + ) + + try: + result = run_subprocess("git add -v".split() + [pattern], self.local_dir) + logger.info(f"Adding to index:\n{result.stdout}\n") + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def git_commit(self, commit_message: str = "commit files to HF hub"): + """ + git commit + + Args: + commit_message (`str`, *optional*, defaults to "commit files to HF hub"): + The message attributed to the commit. 
+ """ + try: + result = run_subprocess("git commit -v -m".split() + [commit_message], self.local_dir) + logger.info(f"Committed:\n{result.stdout}\n") + except subprocess.CalledProcessError as exc: + if len(exc.stderr) > 0: + raise EnvironmentError(exc.stderr) + else: + raise EnvironmentError(exc.stdout) + + def git_push( + self, + upstream: Optional[str] = None, + blocking: bool = True, + auto_lfs_prune: bool = False, + ) -> Union[str, Tuple[str, CommandInProgress]]: + """ + git push + + If used without setting `blocking`, will return url to commit on remote + repo. If used with `blocking=True`, will return a tuple containing the + url to commit and the command object to follow for information about the + process. + + Args: + upstream (`str`, *optional*): + Upstream to which this should push. If not specified, will push + to the lastly defined upstream or to the default one (`origin + main`). + blocking (`bool`, *optional*, defaults to `True`): + Whether the function should return only when the push has + finished. Setting this to `False` will return an + `CommandInProgress` object which has an `is_done` property. This + property will be set to `True` when the push is finished. + auto_lfs_prune (`bool`, *optional*, defaults to `False`): + Whether to automatically prune files once they have been pushed + to the remote. 
+ """ + command = "git push" + + if upstream: + command += f" --set-upstream {upstream}" + + number_of_commits = commits_to_push(self.local_dir, upstream) + + if number_of_commits > 1: + logger.warning(f"Several commits ({number_of_commits}) will be pushed upstream.") + if blocking: + logger.warning("The progress bars may be unreliable.") + + try: + with _lfs_log_progress(): + process = subprocess.Popen( + command.split(), + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + encoding="utf-8", + cwd=self.local_dir, + ) + + if blocking: + stdout, stderr = process.communicate() + return_code = process.poll() + process.kill() + + if len(stderr): + logger.warning(stderr) + + if return_code: + raise subprocess.CalledProcessError(return_code, process.args, output=stdout, stderr=stderr) + + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + if not blocking: + + def status_method(): + status = process.poll() + if status is None: + return -1 + else: + return status + + command_in_progress = CommandInProgress( + "push", + is_done_method=lambda: process.poll() is not None, + status_method=status_method, + process=process, + post_method=self.lfs_prune if auto_lfs_prune else None, + ) + + self.command_queue.append(command_in_progress) + + return self.git_head_commit_url(), command_in_progress + + if auto_lfs_prune: + self.lfs_prune() + + return self.git_head_commit_url() + + def git_checkout(self, revision: str, create_branch_ok: bool = False): + """ + git checkout a given revision + + Specifying `create_branch_ok` to `True` will create the branch to the + given revision if that revision doesn't exist. + + Args: + revision (`str`): + The revision to checkout. + create_branch_ok (`str`, *optional*, defaults to `False`): + Whether creating a branch named with the `revision` passed at + the current checked-out reference if `revision` isn't an + existing revision is allowed. 
+ """ + try: + result = run_subprocess(f"git checkout {revision}", self.local_dir) + logger.warning(f"Checked out {revision} from {self.current_branch}.") + logger.warning(result.stdout) + except subprocess.CalledProcessError as exc: + if not create_branch_ok: + raise EnvironmentError(exc.stderr) + else: + try: + result = run_subprocess(f"git checkout -b {revision}", self.local_dir) + logger.warning( + f"Revision `{revision}` does not exist. Created and checked out branch `{revision}`." + ) + logger.warning(result.stdout) + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def tag_exists(self, tag_name: str, remote: Optional[str] = None) -> bool: + """ + Check if a tag exists or not. + + Args: + tag_name (`str`): + The name of the tag to check. + remote (`str`, *optional*): + Whether to check if the tag exists on a remote. This parameter + should be the identifier of the remote. + + Returns: + `bool`: Whether the tag exists. + """ + if remote: + try: + result = run_subprocess(f"git ls-remote origin refs/tags/{tag_name}", self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + return len(result) != 0 + else: + try: + git_tags = run_subprocess("git tag", self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + git_tags = git_tags.split("\n") + return tag_name in git_tags + + def delete_tag(self, tag_name: str, remote: Optional[str] = None) -> bool: + """ + Delete a tag, both local and remote, if it exists + + Args: + tag_name (`str`): + The tag name to delete. + remote (`str`, *optional*): + The remote on which to delete the tag. + + Returns: + `bool`: `True` if deleted, `False` if the tag didn't exist. 
+ If remote is not passed, will just be updated locally + """ + delete_locally = True + delete_remotely = True + + if not self.tag_exists(tag_name): + delete_locally = False + + if not self.tag_exists(tag_name, remote=remote): + delete_remotely = False + + if delete_locally: + try: + run_subprocess(["git", "tag", "-d", tag_name], self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + if remote and delete_remotely: + try: + run_subprocess(f"git push {remote} --delete {tag_name}", self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + return True + + def add_tag(self, tag_name: str, message: Optional[str] = None, remote: Optional[str] = None): + """ + Add a tag at the current head and push it + + If remote is None, will just be updated locally + + If no message is provided, the tag will be lightweight. if a message is + provided, the tag will be annotated. + + Args: + tag_name (`str`): + The name of the tag to be added. + message (`str`, *optional*): + The message that accompanies the tag. The tag will turn into an + annotated tag if a message is passed. + remote (`str`, *optional*): + The remote on which to add the tag. + """ + if message: + tag_args = ["git", "tag", "-a", tag_name, "-m", message] + else: + tag_args = ["git", "tag", tag_name] + + try: + run_subprocess(tag_args, self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + if remote: + try: + run_subprocess(f"git push {remote} {tag_name}", self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + def is_repo_clean(self) -> bool: + """ + Return whether or not the git status is clean or not + + Returns: + `bool`: `True` if the git status is clean, `False` otherwise. 
+ """ + try: + git_status = run_subprocess("git status --porcelain", self.local_dir).stdout.strip() + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + return len(git_status) == 0 + + def push_to_hub( + self, + commit_message: str = "commit files to HF hub", + blocking: bool = True, + clean_ok: bool = True, + auto_lfs_prune: bool = False, + ) -> Union[None, str, Tuple[str, CommandInProgress]]: + """ + Helper to add, commit, and push files to remote repository on the + HuggingFace Hub. Will automatically track large files (>10MB). + + Args: + commit_message (`str`): + Message to use for the commit. + blocking (`bool`, *optional*, defaults to `True`): + Whether the function should return only when the `git push` has + finished. + clean_ok (`bool`, *optional*, defaults to `True`): + If True, this function will return None if the repo is + untouched. Default behavior is to fail because the git command + fails. + auto_lfs_prune (`bool`, *optional*, defaults to `False`): + Whether to automatically prune files once they have been pushed + to the remote. + """ + if clean_ok and self.is_repo_clean(): + logger.info("Repo currently clean. Ignoring push_to_hub") + return None + self.git_add(auto_lfs_track=True) + self.git_commit(commit_message) + return self.git_push( + upstream=f"origin {self.current_branch}", + blocking=blocking, + auto_lfs_prune=auto_lfs_prune, + ) + + @contextmanager + def commit( + self, + commit_message: str, + branch: Optional[str] = None, + track_large_files: bool = True, + blocking: bool = True, + auto_lfs_prune: bool = False, + ): + """ + Context manager utility to handle committing to a repository. This + automatically tracks large files (>10Mb) with git-lfs. Set the + `track_large_files` argument to `False` if you wish to ignore that + behavior. + + Args: + commit_message (`str`): + Message to use for the commit. + branch (`str`, *optional*): + The branch on which the commit will appear. 
This branch will be + checked-out before any operation. + track_large_files (`bool`, *optional*, defaults to `True`): + Whether to automatically track large files or not. Will do so by + default. + blocking (`bool`, *optional*, defaults to `True`): + Whether the function should return only when the `git push` has + finished. + auto_lfs_prune (`bool`, defaults to `True`): + Whether to automatically prune files once they have been pushed + to the remote. + + Examples: + + ```python + >>> with Repository( + ... "text-files", + ... clone_from="/text-files", + ... token=True, + >>> ).commit("My first file :)"): + ... with open("file.txt", "w+") as f: + ... f.write(json.dumps({"hey": 8})) + + >>> import torch + + >>> model = torch.nn.Transformer() + >>> with Repository( + ... "torch-model", + ... clone_from=" /torch-model", + ... token=True, + >>> ).commit("My cool model :)"): + ... torch.save(model.state_dict(), "model.pt") + ``` + + """ + + files_to_stage = files_to_be_staged(".", folder=self.local_dir) + + if len(files_to_stage): + files_in_msg = str(files_to_stage[:5])[:-1] + ", ...]" if len(files_to_stage) > 5 else str(files_to_stage) + logger.error( + "There exists some updated files in the local repository that are not" + f" committed: {files_in_msg}. This may lead to errors if checking out" + " a branch. These files and their modifications will be added to the" + " current commit." + ) + + if branch is not None: + self.git_checkout(branch, create_branch_ok=True) + + if is_tracked_upstream(self.local_dir): + logger.warning("Pulling changes ...") + self.git_pull(rebase=True) + else: + logger.warning(f"The current branch has no upstream branch. 
Will push to 'origin {self.current_branch}'") + + current_working_directory = os.getcwd() + os.chdir(os.path.join(current_working_directory, self.local_dir)) + + try: + yield self + finally: + self.git_add(auto_lfs_track=track_large_files) + + try: + self.git_commit(commit_message) + except OSError as e: + # If no changes are detected, there is nothing to commit. + if "nothing to commit" not in str(e): + raise e + + try: + self.git_push( + upstream=f"origin {self.current_branch}", + blocking=blocking, + auto_lfs_prune=auto_lfs_prune, + ) + except OSError as e: + # If no changes are detected, there is nothing to commit. + if "could not read Username" in str(e): + raise OSError("Couldn't authenticate user for push. Did you set `token` to `True`?") from e + else: + raise e + + os.chdir(current_working_directory) + + def repocard_metadata_load(self) -> Optional[Dict]: + filepath = os.path.join(self.local_dir, constants.REPOCARD_NAME) + if os.path.isfile(filepath): + return metadata_load(filepath) + return None + + def repocard_metadata_save(self, data: Dict) -> None: + return metadata_save(os.path.join(self.local_dir, constants.REPOCARD_NAME), data) + + @property + def commands_failed(self): + """ + Returns the asynchronous commands that failed. + """ + return [c for c in self.command_queue if c.status > 0] + + @property + def commands_in_progress(self): + """ + Returns the asynchronous commands that are currently in progress. + """ + return [c for c in self.command_queue if not c.is_done] + + def wait_for_commands(self): + """ + Blocking method: blocks all subsequent execution until all commands have + been processed. 
+ """ + index = 0 + for command_failed in self.commands_failed: + logger.error(f"The {command_failed.title} command with PID {command_failed._process.pid} failed.") + logger.error(command_failed.stderr) + + while self.commands_in_progress: + if index % 10 == 0: + logger.warning( + f"Waiting for the following commands to finish before shutting down: {self.commands_in_progress}." + ) + + index += 1 + + time.sleep(1) diff --git a/huggingface_hub/serialization/__init__.py b/huggingface_hub/serialization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e30ce175cf215f78701be6879869b5f0e45db5b --- /dev/null +++ b/huggingface_hub/serialization/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ruff: noqa: F401 +"""Contains helpers to serialize tensors.""" + +from ._base import StateDictSplit, split_state_dict_into_shards_factory +from ._tensorflow import get_tf_storage_size, split_tf_state_dict_into_shards +from ._torch import ( + get_torch_storage_id, + get_torch_storage_size, + save_torch_model, + save_torch_state_dict, + split_torch_state_dict_into_shards, +) diff --git a/huggingface_hub/serialization/_base.py b/huggingface_hub/serialization/_base.py new file mode 100644 index 0000000000000000000000000000000000000000..c30df3c3245846d97cc98639862624ec088e74c0 --- /dev/null +++ b/huggingface_hub/serialization/_base.py @@ -0,0 +1,210 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains helpers to split tensors into shards.""" + +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union + +from .. 
import logging + + +TensorT = TypeVar("TensorT") +TensorSizeFn_T = Callable[[TensorT], int] +StorageIDFn_T = Callable[[TensorT], Optional[Any]] + +MAX_SHARD_SIZE = "5GB" +SIZE_UNITS = { + "TB": 10**12, + "GB": 10**9, + "MB": 10**6, + "KB": 10**3, +} + + +logger = logging.get_logger(__file__) + + +@dataclass +class StateDictSplit: + is_sharded: bool = field(init=False) + metadata: Dict[str, Any] + filename_to_tensors: Dict[str, List[str]] + tensor_to_filename: Dict[str, str] + + def __post_init__(self): + self.is_sharded = len(self.filename_to_tensors) > 1 + + +def split_state_dict_into_shards_factory( + state_dict: Dict[str, TensorT], + *, + get_storage_size: TensorSizeFn_T, + filename_pattern: str, + get_storage_id: StorageIDFn_T = lambda tensor: None, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, +) -> StateDictSplit: + """ + Split a model state dictionary in shards so that each shard is smaller than a given size. + + The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization + made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we + have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not + [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, Tensor]`): + The state dictionary to save. + get_storage_size (`Callable[[Tensor], int]`): + A function that returns the size of a tensor when saved on disk in bytes. + get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*): + A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the + same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage + during its lifetime. 
Two tensor storages with non-overlapping lifetimes may have the same id. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + + Returns: + [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. + """ + storage_id_to_tensors: Dict[Any, List[str]] = {} + + shard_list: List[Dict[str, TensorT]] = [] + current_shard: Dict[str, TensorT] = {} + current_shard_size = 0 + total_size = 0 + + if isinstance(max_shard_size, str): + max_shard_size = parse_size_to_int(max_shard_size) + + for key, tensor in state_dict.items(): + # when bnb serialization is used the weights in the state dict can be strings + # check: https://github.com/huggingface/transformers/pull/24416 for more details + if isinstance(tensor, str): + logger.info("Skipping tensor %s as it is a string (bnb serialization)", key) + continue + + # If a `tensor` shares the same underlying storage as another tensor, we put `tensor` in the same `block` + storage_id = get_storage_id(tensor) + if storage_id is not None: + if storage_id in storage_id_to_tensors: + # We skip this tensor for now and will reassign to correct shard later + storage_id_to_tensors[storage_id].append(key) + continue + else: + # This is the first tensor with this storage_id, we create a new entry + # in the storage_id_to_tensors dict => we will assign the shard id later + storage_id_to_tensors[storage_id] = [key] + + # Compute tensor size + tensor_size = get_storage_size(tensor) + + # If this tensor is bigger than the maximal size, we put it in its own shard + if tensor_size > max_shard_size: + total_size += tensor_size + shard_list.append({key: tensor}) + continue + + # If this tensor is going to tip up over 
the maximal size, we split. + # Current shard already has some tensors, we add it to the list of shards and create a new one. + if current_shard_size + tensor_size > max_shard_size: + shard_list.append(current_shard) + current_shard = {} + current_shard_size = 0 + + # Add the tensor to the current shard + current_shard[key] = tensor + current_shard_size += tensor_size + total_size += tensor_size + + # Add the last shard + if len(current_shard) > 0: + shard_list.append(current_shard) + nb_shards = len(shard_list) + + # Loop over the tensors that share the same storage and assign them together + for storage_id, keys in storage_id_to_tensors.items(): + # Let's try to find the shard where the first tensor of this storage is and put all tensors in the same shard + for shard in shard_list: + if keys[0] in shard: + for key in keys: + shard[key] = state_dict[key] + break + + # If we only have one shard, we return it => no need to build the index + if nb_shards == 1: + filename = filename_pattern.format(suffix="") + return StateDictSplit( + metadata={"total_size": total_size}, + filename_to_tensors={filename: list(state_dict.keys())}, + tensor_to_filename={key: filename for key in state_dict.keys()}, + ) + + # Now that each tensor is assigned to a shard, let's assign a filename to each shard + tensor_name_to_filename = {} + filename_to_tensors = {} + for idx, shard in enumerate(shard_list): + filename = filename_pattern.format(suffix=f"-{idx+1:05d}-of-{nb_shards:05d}") + for key in shard: + tensor_name_to_filename[key] = filename + filename_to_tensors[filename] = list(shard.keys()) + + # Build the index and return + return StateDictSplit( + metadata={"total_size": total_size}, + filename_to_tensors=filename_to_tensors, + tensor_to_filename=tensor_name_to_filename, + ) + + +def parse_size_to_int(size_as_str: str) -> int: + """ + Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes). + + Supported units are "TB", "GB", "MB", "KB". 
+ + Args: + size_as_str (`str`): The size to convert. Will be directly returned if an `int`. + + Example: + + ```py + >>> parse_size_to_int("5MB") + 5000000 + ``` + """ + size_as_str = size_as_str.strip() + + # Parse unit + unit = size_as_str[-2:].upper() + if unit not in SIZE_UNITS: + raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.") + multiplier = SIZE_UNITS[unit] + + # Parse value + try: + value = float(size_as_str[:-2].strip()) + except ValueError as e: + raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e + + return int(value * multiplier) diff --git a/huggingface_hub/serialization/_tensorflow.py b/huggingface_hub/serialization/_tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..59ed8110b28f4891d67e754fdfbfa47a26f85be1 --- /dev/null +++ b/huggingface_hub/serialization/_tensorflow.py @@ -0,0 +1,95 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains tensorflow-specific helpers.""" + +import math +import re +from typing import TYPE_CHECKING, Dict, Union + +from .. 
import constants +from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory + + +if TYPE_CHECKING: + import tensorflow as tf + + +def split_tf_state_dict_into_shards( + state_dict: Dict[str, "tf.Tensor"], + *, + filename_pattern: str = constants.TF2_WEIGHTS_FILE_PATTERN, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, +) -> StateDictSplit: + """ + Split a model state dictionary in shards so that each shard is smaller than a given size. + + The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization + made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we + have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not + [6+2+2GB], [6+2GB], [6GB]. + ++ + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, Tensor]`): + The state dictionary to save. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + Defaults to `"tf_model{suffix}.h5"`. + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + + Returns: + [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. + """ + return split_state_dict_into_shards_factory( + state_dict, + max_shard_size=max_shard_size, + filename_pattern=filename_pattern, + get_storage_size=get_tf_storage_size, + ) + + +def get_tf_storage_size(tensor: "tf.Tensor") -> int: + # Return `math.ceil` since dtype byte size can be a float (e.g., 0.125 for tf.bool). + # Better to overestimate than underestimate. 
+ return math.ceil(tensor.numpy().size * _dtype_byte_size_tf(tensor.dtype)) + + +def _dtype_byte_size_tf(dtype) -> float: + """ + Returns the size (in bytes) occupied by one parameter of type `dtype`. + Taken from https://github.com/huggingface/transformers/blob/74d9d0cebb0263a3f8ab9c280569170cc74651d0/src/transformers/modeling_tf_utils.py#L608. + NOTE: why not `tensor.numpy().nbytes`? + Example: + ```py + >>> _dtype_byte_size(tf.float32) + 4 + ``` + """ + import tensorflow as tf + + if dtype == tf.bool: + return 1 / 8 + bit_search = re.search(r"[^\d](\d+)$", dtype.name) + if bit_search is None: + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") + bit_size = int(bit_search.groups()[0]) + return bit_size // 8 diff --git a/huggingface_hub/serialization/_torch.py b/huggingface_hub/serialization/_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..3c8b2b33f3ed64c409647a41695f73cbff4147cb --- /dev/null +++ b/huggingface_hub/serialization/_torch.py @@ -0,0 +1,630 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains pytorch-specific helpers.""" + +import importlib +import json +import os +import re +from collections import defaultdict +from functools import lru_cache +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union + +from .. 
import constants, logging +from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory + + +logger = logging.get_logger(__file__) + +if TYPE_CHECKING: + import torch + + +def save_torch_model( + model: "torch.nn.Module", + save_directory: Union[str, Path], + *, + filename_pattern: Optional[str] = None, + force_contiguous: bool = True, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, + metadata: Optional[Dict[str, str]] = None, + safe_serialization: bool = True, +): + """ + Saves a given torch model to disk, handling sharding and shared tensors issues. + + See also [`save_torch_state_dict`] to save a state dict with more flexibility. + + For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors). + + The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are + saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard, + an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses + [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as + safetensors (the default). Otherwise, the shards are saved as pickle. + + Before saving the model, the `save_directory` is cleaned from any previous shard files. + ++ + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + Args: + model (`torch.nn.Module`): + The model to save on disk. + save_directory (`str` or `Path`): + The directory in which the model will be saved. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. 
Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization` + parameter. + force_contiguous (`boolean`, *optional*): + Forcing the state_dict to be saved as contiguous tensors. This has no effect on the correctness of the + model, but it could potentially change performance if the layout of the tensor was chosen specifically for + that reason. Defaults to `True`. + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + metadata (`Dict[str, str]`, *optional*): + Extra information to save along with the model. Some metadata will be added for each dropped tensors. + This information will not be enough to recover the entire shared structure but might help understanding + things. + safe_serialization (`bool`, *optional*): + Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle. + Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed + in a future version. + + Example: + + ```py + >>> from huggingface_hub import save_torch_model + >>> model = ... # A PyTorch model + + # Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors. 
+ >>> save_torch_model(model, "path/to/folder") + + # Load model back + >>> from huggingface_hub import load_torch_model # TODO + >>> load_torch_model(model, "path/to/folder") + >>> + ``` + """ + save_torch_state_dict( + state_dict=model.state_dict(), + filename_pattern=filename_pattern, + force_contiguous=force_contiguous, + max_shard_size=max_shard_size, + metadata=metadata, + safe_serialization=safe_serialization, + save_directory=save_directory, + ) + + +def save_torch_state_dict( + state_dict: Dict[str, "torch.Tensor"], + save_directory: Union[str, Path], + *, + filename_pattern: Optional[str] = None, + force_contiguous: bool = True, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, + metadata: Optional[Dict[str, str]] = None, + safe_serialization: bool = True, +) -> None: + """ + Save a model state dictionary to the disk, handling sharding and shared tensors issues. + + See also [`save_torch_model`] to directly save a PyTorch model. + + For more information about tensor sharing, check out [this guide](https://huggingface.co/docs/safetensors/torch_shared_tensors). + + The model state dictionary is split into shards so that each shard is smaller than a given size. The shards are + saved in the `save_directory` with the given `filename_pattern`. If the model is too big to fit in a single shard, + an index file is saved in the `save_directory` to indicate where each tensor is saved. This helper uses + [`split_torch_state_dict_into_shards`] under the hood. If `safe_serialization` is `True`, the shards are saved as + safetensors (the default). Otherwise, the shards are saved as pickle. + + Before saving the model, the `save_directory` is cleaned from any previous shard files. + ++ + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, torch.Tensor]`): + The state dictionary to save. 
+ save_directory (`str` or `Path`): + The directory in which the model will be saved. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + Defaults to `"model{suffix}.safetensors"` or `pytorch_model{suffix}.bin` depending on `safe_serialization` + parameter. + force_contiguous (`boolean`, *optional*): + Forcing the state_dict to be saved as contiguous tensors. This has no effect on the correctness of the + model, but it could potentially change performance if the layout of the tensor was chosen specifically for + that reason. Defaults to `True`. + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + metadata (`Dict[str, str]`, *optional*): + Extra information to save along with the model. Some metadata will be added for each dropped tensors. + This information will not be enough to recover the entire shared structure but might help understanding + things. + safe_serialization (`bool`, *optional*): + Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle. + Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed + in a future version. + + Example: + + ```py + >>> from huggingface_hub import save_torch_state_dict + >>> model = ... # A PyTorch model + + # Save state dict to "path/to/folder". The model will be split into shards of 5GB each and saved as safetensors. 
+ >>> state_dict = model_to_save.state_dict() + >>> save_torch_state_dict(state_dict, "path/to/folder") + ``` + """ + save_directory = str(save_directory) + + if filename_pattern is None: + filename_pattern = ( + constants.SAFETENSORS_WEIGHTS_FILE_PATTERN + if safe_serialization + else constants.PYTORCH_WEIGHTS_FILE_PATTERN + ) + + # Imports correct library + if safe_serialization: + try: + from safetensors.torch import save_file as save_file_fn + except ImportError as e: + raise ImportError( + "Please install `safetensors` to use safe serialization. " + "You can install it with `pip install safetensors`." + ) from e + + else: + from torch import save as save_file_fn # type: ignore[assignment] + + logger.warning( + "You are using unsafe serialization. Due to security reasons, it is recommended not to load " + "pickled models from untrusted sources. If you intend to share your model, we strongly recommend " + "using safe serialization by installing `safetensors` with `pip install safetensors`." + ) + + # Clean state dict for safetensors + if metadata is None: + metadata = {} + if safe_serialization: + state_dict = _clean_state_dict_for_safetensors(state_dict, metadata, force_contiguous=force_contiguous) + + # Split dict + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size + ) + + # Clean the folder from previous save + existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?") + for filename in os.listdir(save_directory): + if existing_files_regex.match(filename): + try: + logger.debug(f"Removing existing file '{filename}' from folder.") + os.remove(os.path.join(save_directory, filename)) + except Exception as e: + logger.warning(f"Error when trying to remove existing '{filename}' from folder: {e}. 
Continuing...") + + # Save each shard + per_file_metadata = {"format": "pt"} + if not state_dict_split.is_sharded: + per_file_metadata.update(metadata) + safe_file_kwargs = {"metadata": per_file_metadata} if safe_serialization else {} + for filename, tensors in state_dict_split.filename_to_tensors.items(): + shard = {tensor: state_dict[tensor] for tensor in tensors} + save_file_fn(shard, os.path.join(save_directory, filename), **safe_file_kwargs) + logger.debug(f"Shard saved to {filename}") + + # Save the index (if any) + if state_dict_split.is_sharded: + index_path = filename_pattern.format(suffix="") + ".index.json" + index = { + "metadata": {**state_dict_split.metadata, **metadata}, + "weight_map": state_dict_split.tensor_to_filename, + } + with open(os.path.join(save_directory, index_path), "w") as f: + json.dump(index, f, indent=2) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}). " + f"Model weighs have been saved in {len(state_dict_split.filename_to_tensors)} checkpoint shards. " + f"You can find where each parameters has been saved in the index located at {index_path}." + ) + + logger.info(f"Model weights successfully saved to {save_directory}!") + + +def split_torch_state_dict_into_shards( + state_dict: Dict[str, "torch.Tensor"], + *, + filename_pattern: str = constants.SAFETENSORS_WEIGHTS_FILE_PATTERN, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, +) -> StateDictSplit: + """ + Split a model state dictionary in shards so that each shard is smaller than a given size. + + The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization + made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we + have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not + [6+2+2GB], [6+2GB], [6GB]. 
+ + ++ + To save a model state dictionary to the disk, see [`save_torch_state_dict`]. This helper uses + `split_torch_state_dict_into_shards` under the hood. + + + ++ + If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a + size greater than `max_shard_size`. + + + + Args: + state_dict (`Dict[str, torch.Tensor]`): + The state dictionary to save. + filename_pattern (`str`, *optional*): + The pattern to generate the files names in which the model will be saved. Pattern must be a string that + can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix` + Defaults to `"model{suffix}.safetensors"`. + max_shard_size (`int` or `str`, *optional*): + The maximum size of each shard, in bytes. Defaults to 5GB. + + Returns: + [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them. + + Example: + ```py + >>> import json + >>> import os + >>> from safetensors.torch import save_file as safe_save_file + >>> from huggingface_hub import split_torch_state_dict_into_shards + + >>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str): + ... state_dict_split = split_torch_state_dict_into_shards(state_dict) + ... for filename, tensors in state_dict_split.filename_to_tensors.items(): + ... shard = {tensor: state_dict[tensor] for tensor in tensors} + ... safe_save_file( + ... shard, + ... os.path.join(save_directory, filename), + ... metadata={"format": "pt"}, + ... ) + ... if state_dict_split.is_sharded: + ... index = { + ... "metadata": state_dict_split.metadata, + ... "weight_map": state_dict_split.tensor_to_filename, + ... } + ... with open(os.path.join(save_directory, "model.safetensors.index.json"), "w") as f: + ... 
f.write(json.dumps(index, indent=2)) + ``` + """ + return split_state_dict_into_shards_factory( + state_dict, + max_shard_size=max_shard_size, + filename_pattern=filename_pattern, + get_storage_size=get_torch_storage_size, + get_storage_id=get_torch_storage_id, + ) + + +def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]: + """Returns a unique id for plain tensor + or a (potentially nested) Tuple of unique id for the flattened Tensor + if the input is a wrapper tensor subclass Tensor + """ + + try: + # for torch 2.1 and above we can also handle tensor subclasses + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if is_traceable_wrapper_subclass(tensor): + attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined] + return tuple(_get_unique_id(getattr(tensor, attr)) for attr in attrs) + + except ImportError: + # for torch version less than 2.1, we can fallback to original implementation + pass + + if tensor.device.type == "xla" and is_torch_tpu_available(): + # NOTE: xla tensors dont have storage + # use some other unique id to distinguish. + # this is a XLA tensor, it must be created using torch_xla's + # device. So the following import is safe: + import torch_xla + + unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor) + else: + unique_id = storage_ptr(tensor) + + return unique_id + + +def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", Union[int, Tuple[Any, ...]], int]: + """ + Return unique identifier to a tensor storage. + + Multiple different tensors can share the same underlying storage. For + example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is + guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with + non-overlapping lifetimes may have the same id. 
+ + Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278. + """ + return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor) + + +def get_torch_storage_size(tensor: "torch.Tensor") -> int: + """ + Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59 + """ + try: + # for torch 2.1 and above we can also handle tensor subclasses + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if is_traceable_wrapper_subclass(tensor): + attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined] + return sum(get_torch_storage_size(getattr(tensor, attr)) for attr in attrs) + except ImportError: + # for torch version less than 2.1, we can fallback to original implementation + pass + + try: + return tensor.untyped_storage().nbytes() + except AttributeError: + # Fallback for torch==1.10 + try: + return tensor.storage().size() * _get_dtype_size(tensor.dtype) + except NotImplementedError: + # Fallback for meta storage + # On torch >=2.0 this is the tensor size + return tensor.nelement() * _get_dtype_size(tensor.dtype) + + +@lru_cache() +def is_torch_tpu_available(check_device=True): + """ + Checks if `torch_xla` is installed and potentially if a TPU is in the environment + + Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/utils/import_utils.py#L463. 
+ """ + if importlib.util.find_spec("torch_xla") is not None: + if check_device: + # We need to check if `xla_device` can be found, will raise a RuntimeError if not + try: + import torch_xla.core.xla_model as xm + + _ = xm.xla_device() + return True + except RuntimeError: + return False + return True + return False + + +def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11. + """ + try: + # for torch 2.1 and above we can also handle tensor subclasses + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if is_traceable_wrapper_subclass(tensor): + return _get_unique_id(tensor) + except ImportError: + # for torch version less than 2.1, we can fallback to original implementation + pass + + try: + return tensor.untyped_storage().data_ptr() + except Exception: + # Fallback for torch==1.10 + try: + return tensor.storage().data_ptr() + except NotImplementedError: + # Fallback for meta storage + return 0 + + +def _clean_state_dict_for_safetensors( + state_dict: Dict[str, "torch.Tensor"], metadata: Dict[str, str], force_contiguous: bool = True +): + """Remove shared tensors from state_dict and update metadata accordingly (for reloading). + + Warning: `state_dict` and `metadata` are mutated in-place! + + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L155. 
+ """ + to_removes = _remove_duplicate_names(state_dict) + for kept_name, to_remove_group in to_removes.items(): + for to_remove in to_remove_group: + if metadata is None: + metadata = {} + + if to_remove not in metadata: + # Do not override user data + metadata[to_remove] = kept_name + del state_dict[to_remove] + if force_contiguous: + state_dict = {k: v.contiguous() for k, v in state_dict.items()} + return state_dict + + +def _end_ptr(tensor: "torch.Tensor") -> int: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L23. + """ + if tensor.nelement(): + stop = tensor.view(-1)[-1].data_ptr() + _get_dtype_size(tensor.dtype) + else: + stop = tensor.data_ptr() + return stop + + +def _filter_shared_not_shared(tensors: List[Set[str]], state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L44 + """ + filtered_tensors = [] + for shared in tensors: + if len(shared) < 2: + filtered_tensors.append(shared) + continue + + areas = [] + for name in shared: + tensor = state_dict[name] + areas.append((tensor.data_ptr(), _end_ptr(tensor), name)) + areas.sort() + + _, last_stop, last_name = areas[0] + filtered_tensors.append({last_name}) + for start, stop, name in areas[1:]: + if start >= last_stop: + filtered_tensors.append({name}) + else: + filtered_tensors[-1].add(name) + last_stop = stop + + return filtered_tensors + + +def _find_shared_tensors(state_dict: Dict[str, "torch.Tensor"]) -> List[Set[str]]: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L69. 
+ """ + import torch + + tensors_dict = defaultdict(set) + for k, v in state_dict.items(): + if v.device != torch.device("meta") and storage_ptr(v) != 0 and get_torch_storage_size(v) != 0: + # Need to add device as key because of multiple GPU. + tensors_dict[(v.device, storage_ptr(v), get_torch_storage_size(v))].add(k) + tensors = list(sorted(tensors_dict.values())) + tensors = _filter_shared_not_shared(tensors, state_dict) + return tensors + + +def _is_complete(tensor: "torch.Tensor") -> bool: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80 + """ + try: + # for torch 2.1 and above we can also handle tensor subclasses + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if is_traceable_wrapper_subclass(tensor): + attrs, _ = tensor.__tensor_flatten__() # type: ignore[attr-defined] + return all(_is_complete(getattr(tensor, attr)) for attr in attrs) + except ImportError: + # for torch version less than 2.1, we can fallback to original implementation + pass + + return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _get_dtype_size( + tensor.dtype + ) == get_torch_storage_size(tensor) + + +def _remove_duplicate_names( + state_dict: Dict[str, "torch.Tensor"], + *, + preferred_names: Optional[List[str]] = None, + discard_names: Optional[List[str]] = None, +) -> Dict[str, List[str]]: + """ + Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80 + """ + if preferred_names is None: + preferred_names = [] + unique_preferred_names = set(preferred_names) + if discard_names is None: + discard_names = [] + unique_discard_names = set(discard_names) + + shareds = _find_shared_tensors(state_dict) + to_remove = defaultdict(list) + for shared in shareds: + complete_names = set([name for name in shared if _is_complete(state_dict[name])]) + if not 
complete_names: + raise RuntimeError( + "Error while trying to find names to remove to save state dict, but found no suitable name to keep" + f" for saving amongst: {shared}. None is covering the entire storage. Refusing to save/load the model" + " since you could be storing much more memory than needed. Please refer to" + " https://huggingface.co/docs/safetensors/torch_shared_tensors for more information. Or open an" + " issue." + ) + + keep_name = sorted(list(complete_names))[0] + + # Mechanism to preferentially select keys to keep + # coming from the on-disk file to allow + # loading models saved with a different choice + # of keep_name + preferred = complete_names.difference(unique_discard_names) + if preferred: + keep_name = sorted(list(preferred))[0] + + if unique_preferred_names: + preferred = unique_preferred_names.intersection(complete_names) + if preferred: + keep_name = sorted(list(preferred))[0] + for name in sorted(shared): + if name != keep_name: + to_remove[keep_name].append(name) + return to_remove + + +@lru_cache() +def _get_dtype_size(dtype: "torch.dtype") -> int: + """ + Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L344 + """ + import torch + + # torch.float8 formats require 2.1; we do not support these dtypes on earlier versions + _float8_e4m3fn = getattr(torch, "float8_e4m3fn", None) + _float8_e5m2 = getattr(torch, "float8_e5m2", None) + _SIZE = { + torch.int64: 8, + torch.float32: 4, + torch.int32: 4, + torch.bfloat16: 2, + torch.float16: 2, + torch.int16: 2, + torch.uint8: 1, + torch.int8: 1, + torch.bool: 1, + torch.float64: 8, + _float8_e4m3fn: 1, + _float8_e5m2: 1, + } + return _SIZE[dtype] diff --git a/huggingface_hub/templates/datasetcard_template.md b/huggingface_hub/templates/datasetcard_template.md new file mode 100644 index 0000000000000000000000000000000000000000..9af29ebbed93653ec74a8952e314e7554323ef15 --- /dev/null +++ 
b/huggingface_hub/templates/datasetcard_template.md @@ -0,0 +1,143 @@ +--- +# For reference on dataset card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/datasetcard.md?plain=1 +# Doc / guide: https://huggingface.co/docs/hub/datasets-cards +{{ card_data }} +--- + +# Dataset Card for {{ pretty_name | default("Dataset Name", true) }} + + + +{{ dataset_summary | default("", true) }} + +## Dataset Details + +### Dataset Description + + + +{{ dataset_description | default("", true) }} + +- **Curated by:** {{ curators | default("[More Information Needed]", true)}} +- **Funded by [optional]:** {{ funded_by | default("[More Information Needed]", true)}} +- **Shared by [optional]:** {{ shared_by | default("[More Information Needed]", true)}} +- **Language(s) (NLP):** {{ language | default("[More Information Needed]", true)}} +- **License:** {{ license | default("[More Information Needed]", true)}} + +### Dataset Sources [optional] + + + +- **Repository:** {{ repo | default("[More Information Needed]", true)}} +- **Paper [optional]:** {{ paper | default("[More Information Needed]", true)}} +- **Demo [optional]:** {{ demo | default("[More Information Needed]", true)}} + +## Uses + + + +### Direct Use + + + +{{ direct_use | default("[More Information Needed]", true)}} + +### Out-of-Scope Use + + + +{{ out_of_scope_use | default("[More Information Needed]", true)}} + +## Dataset Structure + + + +{{ dataset_structure | default("[More Information Needed]", true)}} + +## Dataset Creation + +### Curation Rationale + + + +{{ curation_rationale_section | default("[More Information Needed]", true)}} + +### Source Data + + + +#### Data Collection and Processing + + + +{{ data_collection_and_processing_section | default("[More Information Needed]", true)}} + +#### Who are the source data producers? 
+ + + +{{ source_data_producers_section | default("[More Information Needed]", true)}} + +### Annotations [optional] + + + +#### Annotation process + + + +{{ annotation_process_section | default("[More Information Needed]", true)}} + +#### Who are the annotators? + + + +{{ who_are_annotators_section | default("[More Information Needed]", true)}} + +#### Personal and Sensitive Information + + + +{{ personal_and_sensitive_information | default("[More Information Needed]", true)}} + +## Bias, Risks, and Limitations + + + +{{ bias_risks_limitations | default("[More Information Needed]", true)}} + +### Recommendations + + + +{{ bias_recommendations | default("Users should be made aware of the risks, biases and limitations of the dataset. More information needed for further recommendations.", true)}} + +## Citation [optional] + + + +**BibTeX:** + +{{ citation_bibtex | default("[More Information Needed]", true)}} + +**APA:** + +{{ citation_apa | default("[More Information Needed]", true)}} + +## Glossary [optional] + + + +{{ glossary | default("[More Information Needed]", true)}} + +## More Information [optional] + +{{ more_information | default("[More Information Needed]", true)}} + +## Dataset Card Authors [optional] + +{{ dataset_card_authors | default("[More Information Needed]", true)}} + +## Dataset Card Contact + +{{ dataset_card_contact | default("[More Information Needed]", true)}} diff --git a/huggingface_hub/templates/modelcard_template.md b/huggingface_hub/templates/modelcard_template.md new file mode 100644 index 0000000000000000000000000000000000000000..79ca15e4547debac763b390ef8e4b715e6f6403f --- /dev/null +++ b/huggingface_hub/templates/modelcard_template.md @@ -0,0 +1,200 @@ +--- +# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 +# Doc / guide: https://huggingface.co/docs/hub/model-cards +{{ card_data }} +--- + +# Model Card for {{ model_id | default("Model ID", true) }} + + + +{{ 
model_summary | default("", true) }} + +## Model Details + +### Model Description + + + +{{ model_description | default("", true) }} + +- **Developed by:** {{ developers | default("[More Information Needed]", true)}} +- **Funded by [optional]:** {{ funded_by | default("[More Information Needed]", true)}} +- **Shared by [optional]:** {{ shared_by | default("[More Information Needed]", true)}} +- **Model type:** {{ model_type | default("[More Information Needed]", true)}} +- **Language(s) (NLP):** {{ language | default("[More Information Needed]", true)}} +- **License:** {{ license | default("[More Information Needed]", true)}} +- **Finetuned from model [optional]:** {{ base_model | default("[More Information Needed]", true)}} + +### Model Sources [optional] + + + +- **Repository:** {{ repo | default("[More Information Needed]", true)}} +- **Paper [optional]:** {{ paper | default("[More Information Needed]", true)}} +- **Demo [optional]:** {{ demo | default("[More Information Needed]", true)}} + +## Uses + + + +### Direct Use + + + +{{ direct_use | default("[More Information Needed]", true)}} + +### Downstream Use [optional] + + + +{{ downstream_use | default("[More Information Needed]", true)}} + +### Out-of-Scope Use + + + +{{ out_of_scope_use | default("[More Information Needed]", true)}} + +## Bias, Risks, and Limitations + + + +{{ bias_risks_limitations | default("[More Information Needed]", true)}} + +### Recommendations + + + +{{ bias_recommendations | default("Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.", true)}} + +## How to Get Started with the Model + +Use the code below to get started with the model. 
+ +{{ get_started_code | default("[More Information Needed]", true)}} + +## Training Details + +### Training Data + + + +{{ training_data | default("[More Information Needed]", true)}} + +### Training Procedure + + + +#### Preprocessing [optional] + +{{ preprocessing | default("[More Information Needed]", true)}} + + +#### Training Hyperparameters + +- **Training regime:** {{ training_regime | default("[More Information Needed]", true)}} + +#### Speeds, Sizes, Times [optional] + + + +{{ speeds_sizes_times | default("[More Information Needed]", true)}} + +## Evaluation + + + +### Testing Data, Factors & Metrics + +#### Testing Data + + + +{{ testing_data | default("[More Information Needed]", true)}} + +#### Factors + + + +{{ testing_factors | default("[More Information Needed]", true)}} + +#### Metrics + + + +{{ testing_metrics | default("[More Information Needed]", true)}} + +### Results + +{{ results | default("[More Information Needed]", true)}} + +#### Summary + +{{ results_summary | default("", true) }} + +## Model Examination [optional] + + + +{{ model_examination | default("[More Information Needed]", true)}} + +## Environmental Impact + + + +Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). 
+ +- **Hardware Type:** {{ hardware_type | default("[More Information Needed]", true)}} +- **Hours used:** {{ hours_used | default("[More Information Needed]", true)}} +- **Cloud Provider:** {{ cloud_provider | default("[More Information Needed]", true)}} +- **Compute Region:** {{ cloud_region | default("[More Information Needed]", true)}} +- **Carbon Emitted:** {{ co2_emitted | default("[More Information Needed]", true)}} + +## Technical Specifications [optional] + +### Model Architecture and Objective + +{{ model_specs | default("[More Information Needed]", true)}} + +### Compute Infrastructure + +{{ compute_infrastructure | default("[More Information Needed]", true)}} + +#### Hardware + +{{ hardware_requirements | default("[More Information Needed]", true)}} + +#### Software + +{{ software | default("[More Information Needed]", true)}} + +## Citation [optional] + + + +**BibTeX:** + +{{ citation_bibtex | default("[More Information Needed]", true)}} + +**APA:** + +{{ citation_apa | default("[More Information Needed]", true)}} + +## Glossary [optional] + + + +{{ glossary | default("[More Information Needed]", true)}} + +## More Information [optional] + +{{ more_information | default("[More Information Needed]", true)}} + +## Model Card Authors [optional] + +{{ model_card_authors | default("[More Information Needed]", true)}} + +## Model Card Contact + +{{ model_card_contact | default("[More Information Needed]", true)}} diff --git a/huggingface_hub/utils/__init__.py b/huggingface_hub/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4efea4e253772e03925f4632ed40b8d408c0aeec --- /dev/null +++ b/huggingface_hub/utils/__init__.py @@ -0,0 +1,126 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +# ruff: noqa: F401 + +from huggingface_hub.errors import ( + BadRequestError, + CacheNotFound, + CorruptedCacheException, + DisabledRepoError, + EntryNotFoundError, + FileMetadataError, + GatedRepoError, + HfHubHTTPError, + HFValidationError, + LocalEntryNotFoundError, + LocalTokenNotFoundError, + NotASafetensorsRepoError, + OfflineModeIsEnabled, + RepositoryNotFoundError, + RevisionNotFoundError, + SafetensorsParsingError, +) + +from . import tqdm as _tqdm # _tqdm is the module +from ._cache_assets import cached_assets_path +from ._cache_manager import ( + CachedFileInfo, + CachedRepoInfo, + CachedRevisionInfo, + DeleteCacheStrategy, + HFCacheInfo, + scan_cache_dir, +) +from ._chunk_utils import chunk_iterable +from ._datetime import parse_datetime +from ._experimental import experimental +from ._fixes import SoftTemporaryDirectory, WeakFileLock, yaml_dump +from ._git_credential import list_credential_helpers, set_git_credential, unset_git_credential +from ._headers import build_hf_headers, get_token_to_send +from ._hf_folder import HfFolder +from ._http import ( + configure_http_backend, + fix_hf_endpoint_in_url, + get_session, + hf_raise_for_status, + http_backoff, + reset_sessions, +) +from ._pagination import paginate +from ._paths import DEFAULT_IGNORE_PATTERNS, FORBIDDEN_FOLDERS, filter_repo_objects +from ._runtime import ( + dump_environment_info, + get_aiohttp_version, + get_fastai_version, + get_fastapi_version, + get_fastcore_version, + get_gradio_version, + get_graphviz_version, + get_hf_hub_version, + get_hf_transfer_version, + 
get_jinja_version, + get_minijinja_version, + get_numpy_version, + get_pillow_version, + get_pydantic_version, + get_pydot_version, + get_python_version, + get_tensorboard_version, + get_tf_version, + get_torch_version, + is_aiohttp_available, + is_colab_enterprise, + is_fastai_available, + is_fastapi_available, + is_fastcore_available, + is_google_colab, + is_gradio_available, + is_graphviz_available, + is_hf_transfer_available, + is_jinja_available, + is_minijinja_available, + is_notebook, + is_numpy_available, + is_package_available, + is_pillow_available, + is_pydantic_available, + is_pydot_available, + is_safetensors_available, + is_tensorboard_available, + is_tf_available, + is_torch_available, +) +from ._safetensors import ( + SafetensorsFileMetadata, + SafetensorsRepoMetadata, + TensorInfo, +) +from ._subprocess import capture_output, run_interactive_subprocess, run_subprocess +from ._telemetry import send_telemetry +from ._token import get_token +from ._typing import is_jsonable, is_simple_optional_type, unwrap_simple_optional_type +from ._validators import ( + smoothly_deprecate_use_auth_token, + validate_hf_hub_args, + validate_repo_id, +) +from .tqdm import ( + are_progress_bars_disabled, + disable_progress_bars, + enable_progress_bars, + tqdm, + tqdm_stream_file, +) diff --git a/huggingface_hub/utils/_cache_assets.py b/huggingface_hub/utils/_cache_assets.py new file mode 100644 index 0000000000000000000000000000000000000000..e5d435df9b0bb0c67c0bcb5ef65711e9aef367f6 --- /dev/null +++ b/huggingface_hub/utils/_cache_assets.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def cached_assets_path(
    library_name: str,
    namespace: str = "default",
    subfolder: str = "default",
    *,
    assets_dir: Union[str, Path, None] = None,
) -> Path:
    """Return a folder path to cache arbitrary files.

    `huggingface_hub` provides a canonical folder path to store assets. This is the
    recommended way to integrate cache in a downstream library as it will benefit from
    the built-in tools to scan and delete the cache properly.

    The distinction is made between files cached from the Hub and assets. Files from the
    Hub are cached in a git-aware manner and entirely managed by `huggingface_hub`. See
    [related documentation](https://huggingface.co/docs/huggingface_hub/how-to-cache).
    All other files that a downstream library caches are considered to be "assets"
    (files downloaded from external sources, extracted from a .tar archive, preprocessed
    for training,...).

    Once the folder path is generated, it is guaranteed to exist and to be a directory.
    The path is based on 3 levels of depth: the library name, a namespace and a
    subfolder. Those 3 levels grant flexibility while allowing `huggingface_hub` to
    expect folders when scanning/deleting parts of the assets cache. Within a library,
    it is expected that all namespaces share the same subset of subfolder names but this
    is not a mandatory rule. The downstream library has then full control on which file
    structure to adopt within its cache. Namespace and subfolder are optional (would
    default to a `"default/"` subfolder) but library name is mandatory as we want every
    downstream library to manage its own cache.

    Expected tree:
    ```text
        assets/
        ├── datasets/
        │   ├── SQuAD/
        │   │   ├── downloaded/
        │   │   ├── extracted/
        │   │   └── processed/
        │   └── Helsinki-NLP--tatoeba_mt/
        │       ├── downloaded/
        │       ├── extracted/
        │       └── processed/
        └── transformers/
            ├── default/
            │   └── something/
            └── bert-base-cased/
                ├── default/
                └── training/
        hub/
        └── models--julien-c--EsperBERTo-small/
            ├── blobs/
            │   ├── (...)
            │   └── (...)
            ├── refs/
            │   └── (...)
            └── snapshots/
                ├── 2439f60ef33a0d46d85da5001d52aeda5b00ce9f/
                │   └── (...)
                └── bbc77c8132af1cc5cf678da3f1ddf2de43606d48/
                    └── (...)
    ```

    Args:
        library_name (`str`):
            Name of the library that will manage the cache folder. Example: `"datasets"`.
        namespace (`str`, *optional*, defaults to "default"):
            Namespace to which the data belongs. Example: `"SQuAD"`.
        subfolder (`str`, *optional*, defaults to "default"):
            Subfolder in which the data will be stored. Example: `extracted`.
        assets_dir (`str`, `Path`, *optional*):
            Path to the folder where assets are cached. This must not be the same folder
            where Hub files are cached. Defaults to `HF_HOME / "assets"` if not provided.
            Can also be set with `HF_ASSETS_CACHE` environment variable.

    Returns:
        Path to the cache folder (`Path`).

    Raises:
        `ValueError`:
            If the directory cannot be created because a regular file already occupies
            the target path or one of its parents. The underlying `OSError` is attached
            as the exception cause.

    Example:
    ```py
    >>> from huggingface_hub import cached_assets_path

    >>> cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="download")
    PosixPath('/home/wauplin/.cache/huggingface/assets/datasets/SQuAD/download')

    >>> cached_assets_path(library_name="datasets", namespace="SQuAD", subfolder="extracted")
    PosixPath('/home/wauplin/.cache/huggingface/assets/datasets/SQuAD/extracted')

    >>> cached_assets_path(library_name="datasets", namespace="Helsinki-NLP/tatoeba_mt")
    PosixPath('/home/wauplin/.cache/huggingface/assets/datasets/Helsinki-NLP--tatoeba_mt/default')

    >>> cached_assets_path(library_name="datasets", assets_dir="/tmp/tmp123456")
    PosixPath('/tmp/tmp123456/datasets/default/default')
    ```
    """
    # Resolve assets_dir
    if assets_dir is None:
        assets_dir = HF_ASSETS_CACHE
    assets_dir = Path(assets_dir).expanduser().resolve()

    # Avoid names that could create path issues: spaces and both kinds of path
    # separators are each replaced by "--". A single translate pass is equivalent to
    # the chained replaces because "--" contains none of the replaced characters.
    unsafe_chars = str.maketrans({" ": "--", "/": "--", "\\": "--"})
    path = (
        assets_dir
        / library_name.translate(unsafe_chars)
        / namespace.translate(unsafe_chars)
        / subfolder.translate(unsafe_chars)
    )

    # Path to subfolder is created
    try:
        path.mkdir(exist_ok=True, parents=True)
    except (FileExistsError, NotADirectoryError) as err:
        # Chain explicitly so the original OS error stays visible to callers.
        raise ValueError(
            f"Corrupted assets folder: cannot create directory because of an existing file ({path})."
        ) from err

    return path
REPO_TYPE_T = Literal["model", "dataset", "space"]

# List of OS-created helper files that need to be ignored
FILES_TO_IGNORE = [".DS_Store"]


@dataclass(frozen=True)
class CachedFileInfo:
    """Immutable record describing one file of a cached snapshot.

    Args:
        file_name (`str`):
            Name of the file. Example: `config.json`.
        file_path (`Path`):
            Path of the file in the `snapshots` directory. The file path is a symlink
            referring to a blob in the `blobs` folder.
        blob_path (`Path`):
            Path of the blob file. This is equivalent to `file_path.resolve()`.
        size_on_disk (`int`):
            Size of the blob file in bytes.
        blob_last_accessed (`float`):
            Timestamp of the last time the blob file has been accessed (from any
            revision).
        blob_last_modified (`float`):
            Timestamp of the last time the blob file has been modified/created.

    Warning:
        `blob_last_accessed` and `blob_last_modified` reliability can depend on the OS
        you are using. See
        [python documentation](https://docs.python.org/3/library/os.html#os.stat_result)
        for more details.
    """

    file_name: str
    file_path: Path
    blob_path: Path
    size_on_disk: int

    blob_last_accessed: float
    blob_last_modified: float

    @property
    def blob_last_accessed_str(self) -> str:
        """
        (property) Last access time of the blob file (from any revision), rendered as
        a human-readable string.

        Example: "2 weeks ago".
        """
        timestamp = self.blob_last_accessed
        return _format_timesince(timestamp)

    @property
    def blob_last_modified_str(self) -> str:
        """
        (property) Last modification time of the blob file, rendered as a
        human-readable string.

        Example: "2 weeks ago".
        """
        timestamp = self.blob_last_modified
        return _format_timesince(timestamp)

    @property
    def size_on_disk_str(self) -> str:
        """
        (property) Size of the blob file, rendered as a human-readable string.

        Example: "42.2K".
        """
        nb_bytes = self.size_on_disk
        return _format_size(nb_bytes)


@dataclass(frozen=True)
class CachedRevisionInfo:
    """Immutable record describing one cached revision of a repo.

    A revision corresponds to a folder in the `snapshots` folder and is populated with
    the exact tree structure as the repo on the Hub but contains only symlinks. A
    revision can be either referenced by 1 or more `refs` or be "detached" (no refs).

    Args:
        commit_hash (`str`):
            Hash of the revision (unique).
            Example: `"9338f7b671827df886678df2bdd7cc7b4f36dffd"`.
        snapshot_path (`Path`):
            Path to the revision directory in the `snapshots` folder. It contains the
            exact tree structure as the repo on the Hub.
        size_on_disk (`int`):
            Sum of the blob file sizes that are symlink-ed by the revision.
        files (`FrozenSet[CachedFileInfo]`):
            Set of [`~CachedFileInfo`] describing all files contained in the snapshot.
        refs (`FrozenSet[str]`):
            Set of `refs` pointing to this revision. If the revision has no `refs`, it
            is considered detached.
            Example: `{"main", "2.4.0"}` or `{"refs/pr/1"}`.
        last_modified (`float`):
            Timestamp of the last time the revision has been created/modified.

    Warning:
        `last_accessed` cannot be determined correctly on a single revision as blob
        files are shared across revisions.

    Warning:
        `size_on_disk` is not necessarily the sum of all file sizes because of possible
        duplicated files. Besides, only blobs are taken into account, not the
        (negligible) size of folders and symlinks.
    """

    commit_hash: str
    snapshot_path: Path
    size_on_disk: int
    files: FrozenSet[CachedFileInfo]
    refs: FrozenSet[str]

    last_modified: float

    @property
    def last_modified_str(self) -> str:
        """
        (property) Last modification time of the revision, rendered as a
        human-readable string.

        Example: "2 weeks ago".
        """
        timestamp = self.last_modified
        return _format_timesince(timestamp)

    @property
    def size_on_disk_str(self) -> str:
        """
        (property) Sum of the blob file sizes, rendered as a human-readable string.

        Example: "42.2K".
        """
        nb_bytes = self.size_on_disk
        return _format_size(nb_bytes)

    @property
    def nb_files(self) -> int:
        """
        (property) Total number of files in the revision.
        """
        file_count = len(self.files)
        return file_count


@dataclass(frozen=True)
class CachedRepoInfo:
    """Immutable record describing one cached repository.

    Args:
        repo_id (`str`):
            Repo id of the repo on the Hub. Example: `"google/fleurs"`.
        repo_type (`Literal["dataset", "model", "space"]`):
            Type of the cached repo.
        repo_path (`Path`):
            Local path to the cached repo.
        size_on_disk (`int`):
            Sum of the blob file sizes in the cached repo.
        nb_files (`int`):
            Total number of blob files in the cached repo.
        revisions (`FrozenSet[CachedRevisionInfo]`):
            Set of [`~CachedRevisionInfo`] describing all revisions cached in the repo.
        last_accessed (`float`):
            Timestamp of the last time a blob file of the repo has been accessed.
        last_modified (`float`):
            Timestamp of the last time a blob file of the repo has been modified/created.

    Warning:
        `size_on_disk` is not necessarily the sum of all revisions sizes because of
        duplicated files. Besides, only blobs are taken into account, not the
        (negligible) size of folders and symlinks.

    Warning:
        `last_accessed` and `last_modified` reliability can depend on the OS you are
        using. See
        [python documentation](https://docs.python.org/3/library/os.html#os.stat_result)
        for more details.
    """

    repo_id: str
    repo_type: REPO_TYPE_T
    repo_path: Path
    size_on_disk: int
    nb_files: int
    revisions: FrozenSet[CachedRevisionInfo]

    last_accessed: float
    last_modified: float

    @property
    def last_accessed_str(self) -> str:
        """
        (property) Last time a blob file of the repo has been accessed, rendered as a
        human-readable string.

        Example: "2 weeks ago".
        """
        timestamp = self.last_accessed
        return _format_timesince(timestamp)

    @property
    def last_modified_str(self) -> str:
        """
        (property) Last time a blob file of the repo has been modified, rendered as a
        human-readable string.

        Example: "2 weeks ago".
        """
        timestamp = self.last_modified
        return _format_timesince(timestamp)

    @property
    def size_on_disk_str(self) -> str:
        """
        (property) Sum of the blob file sizes, rendered as a human-readable string.

        Example: "42.2K".
        """
        nb_bytes = self.size_on_disk
        return _format_size(nb_bytes)

    @property
    def refs(self) -> Dict[str, CachedRevisionInfo]:
        """
        (property) Mapping between `refs` and revision data structures.
        """
        revision_by_ref: Dict[str, CachedRevisionInfo] = {}
        for revision in self.revisions:
            for ref in revision.refs:
                revision_by_ref[ref] = revision
        return revision_by_ref
@dataclass(frozen=True)
class HFCacheInfo:
    """Frozen data structure holding information about the entire cache-system.

    This data structure is returned by [`scan_cache_dir`] and is immutable.

    Args:
        size_on_disk (`int`):
            Sum of all valid repo sizes in the cache-system.
        repos (`FrozenSet[CachedRepoInfo]`):
            Set of [`~CachedRepoInfo`] describing all valid cached repos found on the
            cache-system while scanning.
        warnings (`List[CorruptedCacheException]`):
            List of [`~CorruptedCacheException`] that occurred while scanning the cache.
            Those exceptions are captured so that the scan can continue. Corrupted repos
            are skipped from the scan.

    Warning:
        Here `size_on_disk` is equal to the sum of all repo sizes (only blobs). However
        if some cached repos are corrupted, their sizes are not taken into account.
    """

    size_on_disk: int
    repos: FrozenSet[CachedRepoInfo]
    warnings: List[CorruptedCacheException]

    @property
    def size_on_disk_str(self) -> str:
        """
        (property) Sum of all valid repo sizes in the cache-system as a human-readable
        string.

        Example: "42.2K".
        """
        return _format_size(self.size_on_disk)

    def delete_revisions(self, *revisions: str) -> DeleteCacheStrategy:
        """Prepare the strategy to delete one or more revisions cached locally.

        Input revisions can be any revision hash. If a revision hash is not found in the
        local cache, a warning is logged but no error is raised. Revisions can be from
        different cached repos since hashes are unique across repos.

        Examples:
        ```py
        >>> from huggingface_hub import scan_cache_dir
        >>> cache_info = scan_cache_dir()
        >>> delete_strategy = cache_info.delete_revisions(
        ...     "81fd1d6e7847c99f5862c9fb81387956d99ec7aa"
        ... )
        >>> print(f"Will free {delete_strategy.expected_freed_size_str}.")
        Will free 7.9K.
        >>> delete_strategy.execute()
        Cache deletion done. Saved 7.9K.
        ```

        ```py
        >>> from huggingface_hub import scan_cache_dir
        >>> scan_cache_dir().delete_revisions(
        ...     "81fd1d6e7847c99f5862c9fb81387956d99ec7aa",
        ...     "e2983b237dccf3ab4937c97fa717319a9ca1a96d",
        ...     "6c0e6080953db56375760c0471a8c5f2929baf11",
        ... ).execute()
        Cache deletion done. Saved 8.6G.
        ```

        Warning:
            `delete_revisions` returns a [`~utils.DeleteCacheStrategy`] object that needs
            to be executed. The [`~utils.DeleteCacheStrategy`] is not meant to be modified
            but allows having a dry run before actually executing the deletion.

        Args:
            *revisions (`str`):
                Commit hashes of the revisions to delete.

        Returns:
            [`~utils.DeleteCacheStrategy`]: the paths to delete and the size expected to
            be freed. Nothing is deleted until the strategy is executed.
        """
        hashes_to_delete: Set[str] = set(revisions)

        repos_with_revisions: Dict[CachedRepoInfo, Set[CachedRevisionInfo]] = defaultdict(set)

        for repo in self.repos:
            for revision in repo.revisions:
                if revision.commit_hash in hashes_to_delete:
                    repos_with_revisions[repo].add(revision)
                    hashes_to_delete.remove(revision.commit_hash)

        if hashes_to_delete:
            # Sorted so that the warning text is reproducible across runs.
            logger.warning(f"Revision(s) not found - cannot delete them: {', '.join(sorted(hashes_to_delete))}")

        delete_strategy_blobs: Set[Path] = set()
        delete_strategy_refs: Set[Path] = set()
        delete_strategy_repos: Set[Path] = set()
        delete_strategy_snapshots: Set[Path] = set()
        delete_strategy_expected_freed_size = 0

        for affected_repo, revisions_to_delete in repos_with_revisions.items():
            other_revisions = affected_repo.revisions - revisions_to_delete

            # If no other revisions, it means all revisions are deleted
            # -> delete the entire cached repo
            if not other_revisions:
                delete_strategy_repos.add(affected_repo.repo_path)
                delete_strategy_expected_freed_size += affected_repo.size_on_disk
                continue

            # Some revisions of the repo will be deleted but not all. Blobs still
            # linked by a remaining revision must be kept: collect them once instead
            # of rescanning every remaining revision for each candidate file.
            retained_blobs: Set[Path] = {
                kept_file.blob_path for kept_revision in other_revisions for kept_file in kept_revision.files
            }

            for revision_to_delete in revisions_to_delete:
                # Snapshot dir
                delete_strategy_snapshots.add(revision_to_delete.snapshot_path)

                # Refs dir
                for ref in revision_to_delete.refs:
                    delete_strategy_refs.add(affected_repo.repo_path / "refs" / ref)

                # Blobs dir
                for file in revision_to_delete.files:
                    # Skip blobs already scheduled (their size must be counted only
                    # once) and blobs still referenced by remaining revisions.
                    if file.blob_path in delete_strategy_blobs or file.blob_path in retained_blobs:
                        continue
                    delete_strategy_blobs.add(file.blob_path)
                    delete_strategy_expected_freed_size += file.size_on_disk

        # Return the strategy instead of executing it.
        return DeleteCacheStrategy(
            blobs=frozenset(delete_strategy_blobs),
            refs=frozenset(delete_strategy_refs),
            repos=frozenset(delete_strategy_repos),
            snapshots=frozenset(delete_strategy_snapshots),
            expected_freed_size=delete_strategy_expected_freed_size,
        )

    def export_as_table(self, *, verbosity: int = 0) -> str:
        """Generate a table from the [`HFCacheInfo`] object.

        Pass `verbosity=0` to get a table with a single row per repo, with columns
        "repo_id", "repo_type", "size_on_disk", "nb_files", "last_accessed", "last_modified", "refs", "local_path".

        Pass any other `verbosity` (e.g. `verbosity=1`) to get a table with a row per
        repo and revision (thus multiple rows can appear for a single repo), with columns
        "repo_id", "repo_type", "revision", "size_on_disk", "nb_files", "last_modified", "refs", "local_path".

        Example:
        ```py
        >>> from huggingface_hub.utils import scan_cache_dir

        >>> hf_cache_info = scan_cache_dir()
        HFCacheInfo(...)

        >>> print(hf_cache_info.export_as_table())
        REPO ID      REPO TYPE SIZE ON DISK NB FILES LAST_ACCESSED LAST_MODIFIED REFS LOCAL PATH
        ------------ --------- ------------ -------- ------------- ------------- ---- ---------------------------------------------
        roberta-base model             2.7M        5 1 day ago     1 week ago    main ~/.cache/huggingface/hub/models--roberta-base
        suno/bark    model             8.8K        1 1 week ago    1 week ago    main ~/.cache/huggingface/hub/models--suno--bark
        t5-base      model           893.8M        4 4 days ago    7 months ago  main ~/.cache/huggingface/hub/models--t5-base
        t5-large     model             3.0G        4 5 weeks ago   5 months ago  main ~/.cache/huggingface/hub/models--t5-large

        >>> print(hf_cache_info.export_as_table(verbosity=1))
        REPO ID      REPO TYPE REVISION                                 SIZE ON DISK NB FILES LAST_MODIFIED REFS LOCAL PATH
        ------------ --------- ---------------------------------------- ------------ -------- ------------- ---- ------------------------------------------------------------------------------------------------
        roberta-base model     e2da8e2f811d1448a5b465c236feacd80ffbac7b         2.7M        5 1 week ago    main ~/.cache/huggingface/hub/models--roberta-base/snapshots/e2da8e2f811d1448a5b465c236feacd80ffbac7b
        suno/bark    model     70a8a7d34168586dc5d028fa9666aceade177992         8.8K        1 1 week ago    main ~/.cache/huggingface/hub/models--suno--bark/snapshots/70a8a7d34168586dc5d028fa9666aceade177992
        t5-base      model     a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1       893.8M        4 7 months ago  main ~/.cache/huggingface/hub/models--t5-base/snapshots/a9723ea7f1b39c1eae772870f3b547bf6ef7e6c1
        t5-large     model     150ebc2c4b72291e770f58e6057481c8d2ed331a         3.0G        4 5 months ago  main ~/.cache/huggingface/hub/models--t5-large/snapshots/150ebc2c4b72291e770f58e6057481c8d2ed331a
        ```

        Args:
            verbosity (`int`, *optional*):
                The verbosity level. Defaults to 0.

        Returns:
            `str`: The table as a string.
        """
        # Repos are always listed in a stable order (by local path).
        repos_by_path = sorted(self.repos, key=lambda repo: repo.repo_path)

        if verbosity == 0:
            return tabulate(
                rows=[
                    [
                        repo.repo_id,
                        repo.repo_type,
                        f"{repo.size_on_disk_str:>12}",
                        repo.nb_files,
                        repo.last_accessed_str,
                        repo.last_modified_str,
                        ", ".join(sorted(repo.refs)),
                        str(repo.repo_path),
                    ]
                    for repo in repos_by_path
                ],
                headers=[
                    "REPO ID",
                    "REPO TYPE",
                    "SIZE ON DISK",
                    "NB FILES",
                    "LAST_ACCESSED",
                    "LAST_MODIFIED",
                    "REFS",
                    "LOCAL PATH",
                ],
            )

        # Any non-zero verbosity: one row per (repo, revision) pair.
        return tabulate(
            rows=[
                [
                    repo.repo_id,
                    repo.repo_type,
                    revision.commit_hash,
                    f"{revision.size_on_disk_str:>12}",
                    revision.nb_files,
                    revision.last_modified_str,
                    ", ".join(sorted(revision.refs)),
                    str(revision.snapshot_path),
                ]
                for repo in repos_by_path
                for revision in sorted(repo.revisions, key=lambda revision: revision.commit_hash)
            ],
            headers=[
                "REPO ID",
                "REPO TYPE",
                "REVISION",
                "SIZE ON DISK",
                "NB FILES",
                "LAST_MODIFIED",
                "REFS",
                "LOCAL PATH",
            ],
        )
+ + ```py + >>> from huggingface_hub import scan_cache_dir + + >>> hf_cache_info = scan_cache_dir() + HFCacheInfo( + size_on_disk=3398085269, + repos=frozenset({ + CachedRepoInfo( + repo_id='t5-small', + repo_type='model', + repo_path=PosixPath(...), + size_on_disk=970726914, + nb_files=11, + revisions=frozenset({ + CachedRevisionInfo( + commit_hash='d78aea13fa7ecd06c29e3e46195d6341255065d5', + size_on_disk=970726339, + snapshot_path=PosixPath(...), + files=frozenset({ + CachedFileInfo( + file_name='config.json', + size_on_disk=1197 + file_path=PosixPath(...), + blob_path=PosixPath(...), + ), + CachedFileInfo(...), + ... + }), + ), + CachedRevisionInfo(...), + ... + }), + ), + CachedRepoInfo(...), + ... + }), + warnings=[ + CorruptedCacheException("Snapshots dir doesn't exist in cached repo: ..."), + CorruptedCacheException(...), + ... + ], + ) + ``` + + You can also print a detailed report directly from the `huggingface-cli` using: + ```text + > huggingface-cli scan-cache + REPO ID REPO TYPE SIZE ON DISK NB FILES REFS LOCAL PATH + --------------------------- --------- ------------ -------- ------------------- ------------------------------------------------------------------------- + glue dataset 116.3K 15 1.17.0, main, 2.4.0 /Users/lucain/.cache/huggingface/hub/datasets--glue + google/fleurs dataset 64.9M 6 main, refs/pr/1 /Users/lucain/.cache/huggingface/hub/datasets--google--fleurs + Jean-Baptiste/camembert-ner model 441.0M 7 main /Users/lucain/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner + bert-base-cased model 1.9G 13 main /Users/lucain/.cache/huggingface/hub/models--bert-base-cased + t5-base model 10.1K 3 main /Users/lucain/.cache/huggingface/hub/models--t5-base + t5-small model 970.7M 11 refs/pr/1, main /Users/lucain/.cache/huggingface/hub/models--t5-small + + Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G. + Got 1 warning(s) while scanning. Use -vvv to print details. 
+ ``` + + Args: + cache_dir (`str` or `Path`, `optional`): + Cache directory to cache. Defaults to the default HF cache directory. + ++ + Raises: + + `CacheNotFound` + If the cache directory does not exist. + + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If the cache directory is a file, instead of a directory. + + + + Returns: a [`~HFCacheInfo`] object. + """ + if cache_dir is None: + cache_dir = HF_HUB_CACHE + + cache_dir = Path(cache_dir).expanduser().resolve() + if not cache_dir.exists(): + raise CacheNotFound( + f"Cache directory not found: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable.", + cache_dir=cache_dir, + ) + + if cache_dir.is_file(): + raise ValueError( + f"Scan cache expects a directory but found a file: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable." + ) + + repos: Set[CachedRepoInfo] = set() + warnings: List[CorruptedCacheException] = [] + for repo_path in cache_dir.iterdir(): + if repo_path.name == ".locks": # skip './.locks/' folder + continue + try: + repos.add(_scan_cached_repo(repo_path)) + except CorruptedCacheException as e: + warnings.append(e) + + return HFCacheInfo( + repos=frozenset(repos), + size_on_disk=sum(repo.size_on_disk for repo in repos), + warnings=warnings, + ) + + +def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: + """Scan a single cache repo and return information about it. + + Any unexpected behavior will raise a [`~CorruptedCacheException`]. 
+ """ + if not repo_path.is_dir(): + raise CorruptedCacheException(f"Repo path is not a directory: {repo_path}") + + if "--" not in repo_path.name: + raise CorruptedCacheException(f"Repo path is not a valid HuggingFace cache directory: {repo_path}") + + repo_type, repo_id = repo_path.name.split("--", maxsplit=1) + repo_type = repo_type[:-1] # "models" -> "model" + repo_id = repo_id.replace("--", "/") # google/fleurs -> "google/fleurs" + + if repo_type not in {"dataset", "model", "space"}: + raise CorruptedCacheException( + f"Repo type must be `dataset`, `model` or `space`, found `{repo_type}` ({repo_path})." + ) + + blob_stats: Dict[Path, os.stat_result] = {} # Key is blob_path, value is blob stats + + snapshots_path = repo_path / "snapshots" + refs_path = repo_path / "refs" + + if not snapshots_path.exists() or not snapshots_path.is_dir(): + raise CorruptedCacheException(f"Snapshots dir doesn't exist in cached repo: {snapshots_path}") + + # Scan over `refs` directory + + # key is revision hash, value is set of refs + refs_by_hash: Dict[str, Set[str]] = defaultdict(set) + if refs_path.exists(): + # Example of `refs` directory + # ββ refs + # βββ main + # βββ refs + # βββ pr + # βββ 1 + if refs_path.is_file(): + raise CorruptedCacheException(f"Refs directory cannot be a file: {refs_path}") + + for ref_path in refs_path.glob("**/*"): + # glob("**/*") iterates over all files and directories -> skip directories + if ref_path.is_dir(): + continue + + ref_name = str(ref_path.relative_to(refs_path)) + with ref_path.open() as f: + commit_hash = f.read() + + refs_by_hash[commit_hash].add(ref_name) + + # Scan snapshots directory + cached_revisions: Set[CachedRevisionInfo] = set() + for revision_path in snapshots_path.iterdir(): + # Ignore OS-created helper files + if revision_path.name in FILES_TO_IGNORE: + continue + if revision_path.is_file(): + raise CorruptedCacheException(f"Snapshots folder corrupted. 
Found a file: {revision_path}") + + cached_files = set() + for file_path in revision_path.glob("**/*"): + # glob("**/*") iterates over all files and directories -> skip directories + if file_path.is_dir(): + continue + + blob_path = Path(file_path).resolve() + if not blob_path.exists(): + raise CorruptedCacheException(f"Blob missing (broken symlink): {blob_path}") + + if blob_path not in blob_stats: + blob_stats[blob_path] = blob_path.stat() + + cached_files.add( + CachedFileInfo( + file_name=file_path.name, + file_path=file_path, + size_on_disk=blob_stats[blob_path].st_size, + blob_path=blob_path, + blob_last_accessed=blob_stats[blob_path].st_atime, + blob_last_modified=blob_stats[blob_path].st_mtime, + ) + ) + + # Last modified is either the last modified blob file or the revision folder + # itself if it is empty + if len(cached_files) > 0: + revision_last_modified = max(blob_stats[file.blob_path].st_mtime for file in cached_files) + else: + revision_last_modified = revision_path.stat().st_mtime + + cached_revisions.add( + CachedRevisionInfo( + commit_hash=revision_path.name, + files=frozenset(cached_files), + refs=frozenset(refs_by_hash.pop(revision_path.name, set())), + size_on_disk=sum( + blob_stats[blob_path].st_size for blob_path in set(file.blob_path for file in cached_files) + ), + snapshot_path=revision_path, + last_modified=revision_last_modified, + ) + ) + + # Check that all refs referred to an existing revision + if len(refs_by_hash) > 0: + raise CorruptedCacheException( + f"Reference(s) refer to missing commit hashes: {dict(refs_by_hash)} ({repo_path})." + ) + + # Last modified is either the last modified blob file or the repo folder itself if + # no blob files has been found. Same for last accessed. 
+ if len(blob_stats) > 0: + repo_last_accessed = max(stat.st_atime for stat in blob_stats.values()) + repo_last_modified = max(stat.st_mtime for stat in blob_stats.values()) + else: + repo_stats = repo_path.stat() + repo_last_accessed = repo_stats.st_atime + repo_last_modified = repo_stats.st_mtime + + # Build and return frozen structure + return CachedRepoInfo( + nb_files=len(blob_stats), + repo_id=repo_id, + repo_path=repo_path, + repo_type=repo_type, # type: ignore + revisions=frozenset(cached_revisions), + size_on_disk=sum(stat.st_size for stat in blob_stats.values()), + last_accessed=repo_last_accessed, + last_modified=repo_last_modified, + ) + + +def _format_size(num: int) -> str: + """Format size in bytes into a human-readable string. + + Taken from https://stackoverflow.com/a/1094933 + """ + num_f = float(num) + for unit in ["", "K", "M", "G", "T", "P", "E", "Z"]: + if abs(num_f) < 1000.0: + return f"{num_f:3.1f}{unit}" + num_f /= 1000.0 + return f"{num_f:.1f}Y" + + +_TIMESINCE_CHUNKS = ( + # Label, divider, max value + ("second", 1, 60), + ("minute", 60, 60), + ("hour", 60 * 60, 24), + ("day", 60 * 60 * 24, 6), + ("week", 60 * 60 * 24 * 7, 6), + ("month", 60 * 60 * 24 * 30, 11), + ("year", 60 * 60 * 24 * 365, None), +) + + +def _format_timesince(ts: float) -> str: + """Format timestamp in seconds into a human-readable string, relative to now. + + Vaguely inspired by Django's `timesince` formatter. + """ + delta = time.time() - ts + if delta < 20: + return "a few seconds ago" + for label, divider, max_value in _TIMESINCE_CHUNKS: # noqa: B007 + value = round(delta / divider) + if max_value is not None and value <= max_value: + break + return f"{value} {label}{'s' if value > 1 else ''} ago" + + +def _try_delete_path(path: Path, path_type: str) -> None: + """Try to delete a local file or folder. + + If the path does not exists, error is logged as a warning and then ignored. + + Args: + path (`Path`) + Path to delete. Can be a file or a folder. 
+ path_type (`str`) + What path are we deleting ? Only for logging purposes. Example: "snapshot". + """ + logger.info(f"Delete {path_type}: {path}") + try: + if path.is_file(): + os.remove(path) + else: + shutil.rmtree(path) + except FileNotFoundError: + logger.warning(f"Couldn't delete {path_type}: file not found ({path})", exc_info=True) + except PermissionError: + logger.warning(f"Couldn't delete {path_type}: permission denied ({path})", exc_info=True) diff --git a/huggingface_hub/utils/_chunk_utils.py b/huggingface_hub/utils/_chunk_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b0af032ae6a68f03676ad7fdb8e483248d9853f8 --- /dev/null +++ b/huggingface_hub/utils/_chunk_utils.py @@ -0,0 +1,65 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains a utility to iterate by chunks over an iterator.""" + +import itertools +from typing import Iterable, TypeVar + + +T = TypeVar("T") + + +def chunk_iterable(iterable: Iterable[T], chunk_size: int) -> Iterable[Iterable[T]]: + """Iterates over an iterator chunk by chunk. + + Taken from https://stackoverflow.com/a/8998040. + See also https://github.com/huggingface/huggingface_hub/pull/920#discussion_r938793088. + + Args: + iterable (`Iterable`): + The iterable on which we want to iterate. + chunk_size (`int`): + Size of the chunks. Must be a strictly positive integer (e.g. >0). 
+ + Example: + + ```python + >>> from huggingface_hub.utils import chunk_iterable + + >>> for items in chunk_iterable(range(17), chunk_size=8): + ... print(items) + # [0, 1, 2, 3, 4, 5, 6, 7] + # [8, 9, 10, 11, 12, 13, 14, 15] + # [16] # smaller last chunk + ``` + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If `chunk_size` <= 0. + ++ The last chunk can be smaller than `chunk_size`. + + """ + if not isinstance(chunk_size, int) or chunk_size <= 0: + raise ValueError("`chunk_size` must be a strictly positive integer (>0).") + + iterator = iter(iterable) + while True: + try: + next_item = next(iterator) + except StopIteration: + return + yield itertools.chain((next_item,), itertools.islice(iterator, chunk_size - 1)) diff --git a/huggingface_hub/utils/_datetime.py b/huggingface_hub/utils/_datetime.py new file mode 100644 index 0000000000000000000000000000000000000000..e544884b8793d8d409303cafd34586523fc3fb1c --- /dev/null +++ b/huggingface_hub/utils/_datetime.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to handle datetimes in Huggingface Hub.""" + +from datetime import datetime, timezone + + +def parse_datetime(date_string: str) -> datetime: + """ + Parses a date_string returned from the server to a datetime object. 
+ + This parser is a weak-parser is the sense that it handles only a single format of + date_string. It is expected that the server format will never change. The + implementation depends only on the standard lib to avoid an external dependency + (python-dateutil). See full discussion about this decision on PR: + https://github.com/huggingface/huggingface_hub/pull/999. + + Example: + ```py + > parse_datetime('2022-08-19T07:19:38.123Z') + datetime.datetime(2022, 8, 19, 7, 19, 38, 123000, tzinfo=timezone.utc) + ``` + + Args: + date_string (`str`): + A string representing a datetime returned by the Hub server. + String is expected to follow '%Y-%m-%dT%H:%M:%S.%fZ' pattern. + + Returns: + A python datetime object. + + Raises: + :class:`ValueError`: + If `date_string` cannot be parsed. + """ + try: + # Datetime ending with a Z means "UTC". We parse the date and then explicitly + # set the timezone to UTC. + # See https://en.wikipedia.org/wiki/ISO_8601#Coordinated_Universal_Time_(UTC) + # Taken from https://stackoverflow.com/a/3168394. + if len(date_string) == 30: + # Means timezoned-timestamp with nanoseconds precision. We need to truncate the last 3 digits. + date_string = date_string[:-4] + "Z" + dt = datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S.%fZ") + return dt.replace(tzinfo=timezone.utc) # Set explicit timezone + except ValueError as e: + raise ValueError( + f"Cannot parse '{date_string}' as a datetime. Date string is expected to" + " follow '%Y-%m-%dT%H:%M:%S.%fZ' pattern." 
+ ) from e diff --git a/huggingface_hub/utils/_deprecation.py b/huggingface_hub/utils/_deprecation.py new file mode 100644 index 0000000000000000000000000000000000000000..4cb8d6e418c76accd1ecd61158b4bdd265e12f71 --- /dev/null +++ b/huggingface_hub/utils/_deprecation.py @@ -0,0 +1,136 @@ +import warnings +from functools import wraps +from inspect import Parameter, signature +from typing import Iterable, Optional + + +def _deprecate_positional_args(*, version: str): + """Decorator for methods that issues warnings for positional arguments. + Using the keyword-only argument syntax in pep 3102, arguments after the + * will issue a warning when passed as a positional argument. + + Args: + version (`str`): + The version when positional arguments will result in error. + """ + + def _inner_deprecate_positional_args(f): + sig = signature(f) + kwonly_args = [] + all_args = [] + for name, param in sig.parameters.items(): + if param.kind == Parameter.POSITIONAL_OR_KEYWORD: + all_args.append(name) + elif param.kind == Parameter.KEYWORD_ONLY: + kwonly_args.append(name) + + @wraps(f) + def inner_f(*args, **kwargs): + extra_args = len(args) - len(all_args) + if extra_args <= 0: + return f(*args, **kwargs) + # extra_args > 0 + args_msg = [ + f"{name}='{arg}'" if isinstance(arg, str) else f"{name}={arg}" + for name, arg in zip(kwonly_args[:extra_args], args[-extra_args:]) + ] + args_msg = ", ".join(args_msg) + warnings.warn( + f"Deprecated positional argument(s) used in '{f.__name__}': pass" + f" {args_msg} as keyword args. From version {version} passing these" + " as positional arguments will result in an error,", + FutureWarning, + ) + kwargs.update(zip(sig.parameters, args)) + return f(**kwargs) + + return inner_f + + return _inner_deprecate_positional_args + + +def _deprecate_arguments( + *, + version: str, + deprecated_args: Iterable[str], + custom_message: Optional[str] = None, +): + """Decorator to issue warnings when using deprecated arguments. 
+ + TODO: could be useful to be able to set a custom error message. + + Args: + version (`str`): + The version when deprecated arguments will result in error. + deprecated_args (`List[str]`): + List of the arguments to be deprecated. + custom_message (`str`, *optional*): + Warning message that is raised. If not passed, a default warning message + will be created. + """ + + def _inner_deprecate_positional_args(f): + sig = signature(f) + + @wraps(f) + def inner_f(*args, **kwargs): + # Check for used deprecated arguments + used_deprecated_args = [] + for _, parameter in zip(args, sig.parameters.values()): + if parameter.name in deprecated_args: + used_deprecated_args.append(parameter.name) + for kwarg_name, kwarg_value in kwargs.items(): + if ( + # If argument is deprecated but still used + kwarg_name in deprecated_args + # And then the value is not the default value + and kwarg_value != sig.parameters[kwarg_name].default + ): + used_deprecated_args.append(kwarg_name) + + # Warn and proceed + if len(used_deprecated_args) > 0: + message = ( + f"Deprecated argument(s) used in '{f.__name__}':" + f" {', '.join(used_deprecated_args)}. Will not be supported from" + f" version '{version}'." + ) + if custom_message is not None: + message += "\n\n" + custom_message + warnings.warn(message, FutureWarning) + return f(*args, **kwargs) + + return inner_f + + return _inner_deprecate_positional_args + + +def _deprecate_method(*, version: str, message: Optional[str] = None): + """Decorator to issue warnings when using a deprecated method. + + Args: + version (`str`): + The version when deprecated arguments will result in error. + message (`str`, *optional*): + Warning message that is raised. If not passed, a default warning message + will be created. 
+ """ + + def _inner_deprecate_method(f): + name = f.__name__ + if name == "__init__": + name = f.__qualname__.split(".")[0] # class name instead of method name + + @wraps(f) + def inner_f(*args, **kwargs): + warning_message = ( + f"'{name}' (from '{f.__module__}') is deprecated and will be removed from version '{version}'." + ) + if message is not None: + warning_message += " " + message + warnings.warn(warning_message, FutureWarning) + return f(*args, **kwargs) + + return inner_f + + return _inner_deprecate_method diff --git a/huggingface_hub/utils/_experimental.py b/huggingface_hub/utils/_experimental.py new file mode 100644 index 0000000000000000000000000000000000000000..34141eba09123c06fbca55c929a19a0264e5788e --- /dev/null +++ b/huggingface_hub/utils/_experimental.py @@ -0,0 +1,66 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to flag a feature as "experimental" in Huggingface Hub.""" + +import warnings +from functools import wraps +from typing import Callable + +from .. import constants + + +def experimental(fn: Callable) -> Callable: + """Decorator to flag a feature as experimental. + + An experimental feature trigger a warning when used as it might be subject to breaking changes in the future. + Warnings can be disabled by setting the environment variable `HF_EXPERIMENTAL_WARNING` to `0`. + + Args: + fn (`Callable`): + The function to flag as experimental. 
+ + Returns: + `Callable`: The decorated function. + + Example: + + ```python + >>> from huggingface_hub.utils import experimental + + >>> @experimental + ... def my_function(): + ... print("Hello world!") + + >>> my_function() + UserWarning: 'my_function' is experimental and might be subject to breaking changes in the future. You can disable + this warning by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment variable. + Hello world! + ``` + """ + # For classes, put the "experimental" around the "__new__" method => __new__ will be removed in warning message + name = fn.__qualname__[: -len(".__new__")] if fn.__qualname__.endswith(".__new__") else fn.__qualname__ + + @wraps(fn) + def _inner_fn(*args, **kwargs): + if not constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING: + warnings.warn( + f"'{name}' is experimental and might be subject to breaking changes in the future." + " You can disable this warning by setting `HF_HUB_DISABLE_EXPERIMENTAL_WARNING=1` as environment" + " variable.", + UserWarning, + ) + return fn(*args, **kwargs) + + return _inner_fn diff --git a/huggingface_hub/utils/_fixes.py b/huggingface_hub/utils/_fixes.py new file mode 100644 index 0000000000000000000000000000000000000000..259c976c40f677514fce0ed42b26a6b6d838d07d --- /dev/null +++ b/huggingface_hub/utils/_fixes.py @@ -0,0 +1,121 @@ +# JSONDecodeError was introduced in requests=2.27 released in 2022. 
+# This allows us to support older requests for users +# More information: https://github.com/psf/requests/pull/5856 +try: + from requests import JSONDecodeError # type: ignore # noqa: F401 +except ImportError: + try: + from simplejson import JSONDecodeError # type: ignore # noqa: F401 + except ImportError: + from json import JSONDecodeError # type: ignore # noqa: F401 +import contextlib +import os +import shutil +import stat +import tempfile +from functools import partial +from pathlib import Path +from typing import Callable, Generator, Optional, Union + +import yaml +from filelock import BaseFileLock, FileLock, SoftFileLock, Timeout + +from .. import constants +from . import logging + + +logger = logging.get_logger(__name__) + +# Wrap `yaml.dump` to set `allow_unicode=True` by default. +# +# Example: +# ```py +# >>> yaml.dump({"emoji": "π", "some unicode": "ζ₯ζ¬γ"}) +# 'emoji: "\\U0001F440"\nsome unicode: "\\u65E5\\u672C\\u304B"\n' +# +# >>> yaml_dump({"emoji": "π", "some unicode": "ζ₯ζ¬γ"}) +# 'emoji: "π"\nsome unicode: "ζ₯ζ¬γ"\n' +# ``` +yaml_dump: Callable[..., str] = partial(yaml.dump, stream=None, allow_unicode=True) # type: ignore + + +@contextlib.contextmanager +def SoftTemporaryDirectory( + suffix: Optional[str] = None, + prefix: Optional[str] = None, + dir: Optional[Union[Path, str]] = None, + **kwargs, +) -> Generator[Path, None, None]: + """ + Context manager to create a temporary directory and safely delete it. + + If tmp directory cannot be deleted normally, we set the WRITE permission and retry. + If cleanup still fails, we give up but don't raise an exception. This is equivalent + to `tempfile.TemporaryDirectory(..., ignore_cleanup_errors=True)` introduced in + Python 3.10. + + See https://www.scivision.dev/python-tempfile-permission-error-windows/. 
+ """ + tmpdir = tempfile.TemporaryDirectory(prefix=prefix, suffix=suffix, dir=dir, **kwargs) + yield Path(tmpdir.name).resolve() + + try: + # First once with normal cleanup + shutil.rmtree(tmpdir.name) + except Exception: + # If failed, try to set write permission and retry + try: + shutil.rmtree(tmpdir.name, onerror=_set_write_permission_and_retry) + except Exception: + pass + + # And finally, cleanup the tmpdir. + # If it fails again, give up but do not throw error + try: + tmpdir.cleanup() + except Exception: + pass + + +def _set_write_permission_and_retry(func, path, excinfo): + os.chmod(path, stat.S_IWRITE) + func(path) + + +@contextlib.contextmanager +def WeakFileLock(lock_file: Union[str, Path]) -> Generator[BaseFileLock, None, None]: + """A filelock with some custom logic. + + This filelock is weaker than the default filelock in that: + 1. It won't raise an exception if release fails. + 2. It will default to a SoftFileLock if the filesystem does not support flock. + + An INFO log message is emitted every 10 seconds if the lock is not acquired immediately. + """ + lock = FileLock(lock_file, timeout=constants.FILELOCK_LOG_EVERY_SECONDS) + while True: + try: + lock.acquire() + except Timeout: + logger.info("still waiting to acquire lock on %s", lock_file) + except NotImplementedError as e: + if "use SoftFileLock instead" in str(e): + # It's possible that the system does support flock, expect for one partition or filesystem. + # In this case, let's default to a SoftFileLock. + logger.warning( + "FileSystem does not appear to support flock. 
Falling back to SoftFileLock for %s", lock_file + ) + lock = SoftFileLock(lock_file, timeout=constants.FILELOCK_LOG_EVERY_SECONDS) + continue + else: + break + + yield lock + + try: + return lock.release() + except OSError: + try: + Path(lock_file).unlink() + except OSError: + pass diff --git a/huggingface_hub/utils/_git_credential.py b/huggingface_hub/utils/_git_credential.py new file mode 100644 index 0000000000000000000000000000000000000000..a8ed77f4e49ca88ff4fa9aba48cbf00195036013 --- /dev/null +++ b/huggingface_hub/utils/_git_credential.py @@ -0,0 +1,121 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to manage Git credentials.""" + +import re +import subprocess +from typing import List, Optional + +from ..constants import ENDPOINT +from ._subprocess import run_interactive_subprocess, run_subprocess + + +GIT_CREDENTIAL_REGEX = re.compile( + r""" + ^\s* # start of line + credential\.helper # credential.helper value + \s*=\s* # separator + (\w+) # the helper name (group 1) + (\s|$) # whitespace or end of line + """, + flags=re.MULTILINE | re.IGNORECASE | re.VERBOSE, +) + + +def list_credential_helpers(folder: Optional[str] = None) -> List[str]: + """Return the list of git credential helpers configured. + + See https://git-scm.com/docs/gitcredentials. + + Credentials are saved in all configured helpers (store, cache, macOS keychain,...). 
+ Calls "`git credential approve`" internally. See https://git-scm.com/docs/git-credential. + + Args: + folder (`str`, *optional*): + The folder in which to check the configured helpers. + """ + try: + output = run_subprocess("git config --list", folder=folder).stdout + parsed = _parse_credential_output(output) + return parsed + except subprocess.CalledProcessError as exc: + raise EnvironmentError(exc.stderr) + + +def set_git_credential(token: str, username: str = "hf_user", folder: Optional[str] = None) -> None: + """Save a username/token pair in git credential for HF Hub registry. + + Credentials are saved in all configured helpers (store, cache, macOS keychain,...). + Calls "`git credential approve`" internally. See https://git-scm.com/docs/git-credential. + + Args: + username (`str`, defaults to `"hf_user"`): + A git username. Defaults to `"hf_user"`, the default user used in the Hub. + token (`str`, defaults to `"hf_user"`): + A git password. In practice, the User Access Token for the Hub. + See https://huggingface.co/settings/tokens. + folder (`str`, *optional*): + The folder in which to check the configured helpers. + """ + with run_interactive_subprocess("git credential approve", folder=folder) as ( + stdin, + _, + ): + stdin.write(f"url={ENDPOINT}\nusername={username.lower()}\npassword={token}\n\n") + stdin.flush() + + +def unset_git_credential(username: str = "hf_user", folder: Optional[str] = None) -> None: + """Erase credentials from git credential for HF Hub registry. + + Credentials are erased from the configured helpers (store, cache, macOS + keychain,...), if any. If `username` is not provided, any credential configured for + HF Hub endpoint is erased. + Calls "`git credential erase`" internally. See https://git-scm.com/docs/git-credential. + + Args: + username (`str`, defaults to `"hf_user"`): + A git username. Defaults to `"hf_user"`, the default user used in the Hub. 
+ folder (`str`, *optional*): + The folder in which to check the configured helpers. + """ + with run_interactive_subprocess("git credential reject", folder=folder) as ( + stdin, + _, + ): + standard_input = f"url={ENDPOINT}\n" + if username is not None: + standard_input += f"username={username.lower()}\n" + standard_input += "\n" + + stdin.write(standard_input) + stdin.flush() + + +def _parse_credential_output(output: str) -> List[str]: + """Parse the output of `git credential fill` to extract the password. + + Args: + output (`str`): + The output of `git credential fill`. + """ + # NOTE: If user has set an helper for a custom URL, it will not we caught here. + # Example: `credential.https://huggingface.co.helper=store` + # See: https://github.com/huggingface/huggingface_hub/pull/1138#discussion_r1013324508 + return sorted( # Sort for nice printing + set( # Might have some duplicates + match[0] for match in GIT_CREDENTIAL_REGEX.findall(output) + ) + ) diff --git a/huggingface_hub/utils/_headers.py b/huggingface_hub/utils/_headers.py new file mode 100644 index 0000000000000000000000000000000000000000..e76afb6ceab094a06ca06e41e406a9236c92e8a3 --- /dev/null +++ b/huggingface_hub/utils/_headers.py @@ -0,0 +1,239 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Contains utilities to handle headers to send in calls to Huggingface Hub.""" + +from typing import Dict, Optional, Union + +from huggingface_hub.errors import LocalTokenNotFoundError + +from .. import constants +from ._runtime import ( + get_fastai_version, + get_fastcore_version, + get_hf_hub_version, + get_python_version, + get_tf_version, + get_torch_version, + is_fastai_available, + is_fastcore_available, + is_tf_available, + is_torch_available, +) +from ._token import get_token +from ._validators import validate_hf_hub_args + + +@validate_hf_hub_args +def build_hf_headers( + *, + token: Optional[Union[bool, str]] = None, + is_write_action: bool = False, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, + headers: Optional[Dict[str, str]] = None, +) -> Dict[str, str]: + """ + Build headers dictionary to send in a HF Hub call. + + By default, authorization token is always provided either from argument (explicit + use) or retrieved from the cache (implicit use). To explicitly avoid sending the + token to the Hub, set `token=False` or set the `HF_HUB_DISABLE_IMPLICIT_TOKEN` + environment variable. + + In case of an API call that requires write access, an error is thrown if token is + `None` or token is an organization token (starting with `"api_org***"`). + + In addition to the auth header, a user-agent is added to provide information about + the installed packages (versions of python, huggingface_hub, torch, tensorflow, + fastai and fastcore). + + Args: + token (`str`, `bool`, *optional*): + The token to be sent in authorization header for the Hub call: + - if a string, it is used as the Hugging Face token + - if `True`, the token is read from the machine (cache or env variable) + - if `False`, authorization header is not set + - if `None`, the token is read from the machine only except if + `HF_HUB_DISABLE_IMPLICIT_TOKEN` env variable is set. 
+ is_write_action (`bool`, default to `False`): + Set to True if the API call requires a write access. If `True`, the token + will be validated (cannot be `None`, cannot start by `"api_org***"`). + library_name (`str`, *optional*): + The name of the library that is making the HTTP request. Will be added to + the user-agent header. + library_version (`str`, *optional*): + The version of the library that is making the HTTP request. Will be added + to the user-agent header. + user_agent (`str`, `dict`, *optional*): + The user agent info in the form of a dictionary or a single string. It will + be completed with information about the installed packages. + headers (`dict`, *optional*): + Additional headers to include in the request. Those headers take precedence + over the ones generated by this function. + + Returns: + A `Dict` of headers to pass in your API call. + + Example: + ```py + >>> build_hf_headers(token="hf_***") # explicit token + {"authorization": "Bearer hf_***", "user-agent": ""} + + >>> build_hf_headers(token=True) # explicitly use cached token + {"authorization": "Bearer hf_***",...} + + >>> build_hf_headers(token=False) # explicitly don't use cached token + {"user-agent": ...} + + >>> build_hf_headers() # implicit use of the cached token + {"authorization": "Bearer hf_***",...} + + # HF_HUB_DISABLE_IMPLICIT_TOKEN=True # to set as env variable + >>> build_hf_headers() # token is not sent + {"user-agent": ...} + + >>> build_hf_headers(token="api_org_***", is_write_action=True) + ValueError: You must use your personal account token for write-access methods. + + >>> build_hf_headers(library_name="transformers", library_version="1.2.3") + {"authorization": ..., "user-agent": "transformers/1.2.3; hf_hub/0.10.2; python/3.10.4; tensorflow/1.55"} + ``` + + Raises: + [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If organization token is passed and "write" access is required. 
+ [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + If "write" access is required but token is not passed and not saved locally. + [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + If `token=True` but token is not saved locally. + """ + # Get auth token to send + token_to_send = get_token_to_send(token) + _validate_token_to_send(token_to_send, is_write_action=is_write_action) + + # Combine headers + hf_headers = { + "user-agent": _http_user_agent( + library_name=library_name, + library_version=library_version, + user_agent=user_agent, + ) + } + if token_to_send is not None: + hf_headers["authorization"] = f"Bearer {token_to_send}" + if headers is not None: + hf_headers.update(headers) + return hf_headers + + +def get_token_to_send(token: Optional[Union[bool, str]]) -> Optional[str]: + """Select the token to send from either `token` or the cache.""" + # Case token is explicitly provided + if isinstance(token, str): + return token + + # Case token is explicitly forbidden + if token is False: + return None + + # Token is not provided: we get it from local cache + cached_token = get_token() + + # Case token is explicitly required + if token is True: + if cached_token is None: + raise LocalTokenNotFoundError( + "Token is required (`token=True`), but no token found. You" + " need to provide a token or be logged in to Hugging Face with" + " `huggingface-cli login` or `huggingface_hub.login`. See" + " https://huggingface.co/settings/tokens." + ) + return cached_token + + # Case implicit use of the token is forbidden by env variable + if constants.HF_HUB_DISABLE_IMPLICIT_TOKEN: + return None + + # Otherwise: we use the cached token as the user has not explicitly forbidden it + return cached_token + + +def _validate_token_to_send(token: Optional[str], is_write_action: bool) -> None: + if is_write_action: + if token is None: + raise ValueError( + "Token is required (write-access action) but no token found. 
You need" + " to provide a token or be logged in to Hugging Face with" + " `huggingface-cli login` or `huggingface_hub.login`. See" + " https://huggingface.co/settings/tokens." + ) + if token.startswith("api_org"): + raise ValueError( + "You must use your personal account token for write-access methods. To" + " generate a write-access token, go to" + " https://huggingface.co/settings/tokens" + ) + + +def _http_user_agent( + *, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, +) -> str: + """Format a user-agent string containing information about the installed packages. + + Args: + library_name (`str`, *optional*): + The name of the library that is making the HTTP request. + library_version (`str`, *optional*): + The version of the library that is making the HTTP request. + user_agent (`str`, `dict`, *optional*): + The user agent info in the form of a dictionary or a single string. + + Returns: + The formatted user-agent string. 
+ """ + if library_name is not None: + ua = f"{library_name}/{library_version}" + else: + ua = "unknown/None" + ua += f"; hf_hub/{get_hf_hub_version()}" + ua += f"; python/{get_python_version()}" + + if not constants.HF_HUB_DISABLE_TELEMETRY: + if is_torch_available(): + ua += f"; torch/{get_torch_version()}" + if is_tf_available(): + ua += f"; tensorflow/{get_tf_version()}" + if is_fastai_available(): + ua += f"; fastai/{get_fastai_version()}" + if is_fastcore_available(): + ua += f"; fastcore/{get_fastcore_version()}" + + if isinstance(user_agent, dict): + ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) + elif isinstance(user_agent, str): + ua += "; " + user_agent + + return _deduplicate_user_agent(ua) + + +def _deduplicate_user_agent(user_agent: str) -> str: + """Deduplicate redundant information in the generated user-agent.""" + # Split around ";" > Strip whitespaces > Store as dict keys (ensure unicity) > format back as string + # Order is implicitly preserved by dictionary structure (see https://stackoverflow.com/a/53657523). + return "; ".join({key.strip(): None for key in user_agent.split(";")}.keys()) diff --git a/huggingface_hub/utils/_hf_folder.py b/huggingface_hub/utils/_hf_folder.py new file mode 100644 index 0000000000000000000000000000000000000000..502b22658b44d2221b535cbd943348bb93213245 --- /dev/null +++ b/huggingface_hub/utils/_hf_folder.py @@ -0,0 +1,96 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Contain helper class to retrieve/store token from/to local cache.""" + +import warnings +from pathlib import Path +from typing import Optional + +from .. import constants +from ._token import get_token + + +class HfFolder: + path_token = Path(constants.HF_TOKEN_PATH) + # Private attribute. Will be removed in v0.15 + _old_path_token = Path(constants._OLD_HF_TOKEN_PATH) + + # TODO: deprecate when adapted in transformers/datasets/gradio + # @_deprecate_method(version="1.0", message="Use `huggingface_hub.login` instead.") + @classmethod + def save_token(cls, token: str) -> None: + """ + Save token, creating folder as needed. + + Token is saved in the huggingface home folder. You can configure it by setting + the `HF_HOME` environment variable. + + Args: + token (`str`): + The token to save to the [`HfFolder`] + """ + cls.path_token.parent.mkdir(parents=True, exist_ok=True) + cls.path_token.write_text(token) + + # TODO: deprecate when adapted in transformers/datasets/gradio + # @_deprecate_method(version="1.0", message="Use `huggingface_hub.get_token` instead.") + @classmethod + def get_token(cls) -> Optional[str]: + """ + Get token or None if not existent. + + This method is deprecated in favor of [`huggingface_hub.get_token`] but is kept for backward compatibility. + Its behavior is the same as [`huggingface_hub.get_token`]. + + Returns: + `str` or `None`: The token, `None` if it doesn't exist. + """ + # 0. Check if token exist in old path but not new location + try: + cls._copy_to_new_path_and_warn() + except Exception: # if not possible (e.g. PermissionError), do not raise + pass + + return get_token() + + # TODO: deprecate when adapted in transformers/datasets/gradio + # @_deprecate_method(version="1.0", message="Use `huggingface_hub.logout` instead.") + @classmethod + def delete_token(cls) -> None: + """ + Deletes the token from storage. 
Does not fail if token does not exist. + """ + try: + cls.path_token.unlink() + except FileNotFoundError: + pass + + try: + cls._old_path_token.unlink() + except FileNotFoundError: + pass + + @classmethod + def _copy_to_new_path_and_warn(cls): + if cls._old_path_token.exists() and not cls.path_token.exists(): + cls.save_token(cls._old_path_token.read_text()) + warnings.warn( + f"A token has been found in `{cls._old_path_token}`. This is the old" + " path where tokens were stored. The new location is" + f" `{cls.path_token}` which is configurable using `HF_HOME` environment" + " variable. Your token has been copied to this new location. You can" + " now safely delete the old token file manually or use" + " `huggingface-cli logout`." + ) diff --git a/huggingface_hub/utils/_http.py b/huggingface_hub/utils/_http.py new file mode 100644 index 0000000000000000000000000000000000000000..076ae557e1ad44ad70eb95921b1ab709322372fb --- /dev/null +++ b/huggingface_hub/utils/_http.py @@ -0,0 +1,545 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Contains utilities to handle HTTP requests in Huggingface Hub.""" + +import io +import os +import re +import threading +import time +import uuid +from functools import lru_cache +from http import HTTPStatus +from typing import Callable, Optional, Tuple, Type, Union + +import requests +from requests import HTTPError, Response +from requests.adapters import HTTPAdapter +from requests.models import PreparedRequest + +from huggingface_hub.errors import OfflineModeIsEnabled + +from .. import constants +from ..errors import ( + BadRequestError, + DisabledRepoError, + EntryNotFoundError, + GatedRepoError, + HfHubHTTPError, + RepositoryNotFoundError, + RevisionNotFoundError, +) +from . import logging +from ._fixes import JSONDecodeError +from ._lfs import SliceFileObj +from ._typing import HTTP_METHOD_T + + +logger = logging.get_logger(__name__) + +# Both headers are used by the Hub to debug failed requests. +# `X_AMZN_TRACE_ID` is better as it also works to debug on Cloudfront and ALB. +# If `X_AMZN_TRACE_ID` is set, the Hub will use it as well. +X_AMZN_TRACE_ID = "X-Amzn-Trace-Id" +X_REQUEST_ID = "x-request-id" + +REPO_API_REGEX = re.compile( + r""" + # staging or production endpoint + ^https://[^/]+ + ( + # on /api/repo_type/repo_id + /api/(models|datasets|spaces)/(.+) + | + # or /repo_id/resolve/revision/... 
+ /(.+)/resolve/(.+) + ) + """, + flags=re.VERBOSE, +) + + +class UniqueRequestIdAdapter(HTTPAdapter): + X_AMZN_TRACE_ID = "X-Amzn-Trace-Id" + + def add_headers(self, request, **kwargs): + super().add_headers(request, **kwargs) + + # Add random request ID => easier for server-side debug + if X_AMZN_TRACE_ID not in request.headers: + request.headers[X_AMZN_TRACE_ID] = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4()) + + # Add debug log + has_token = str(request.headers.get("authorization", "")).startswith("Bearer hf_") + logger.debug( + f"Request {request.headers[X_AMZN_TRACE_ID]}: {request.method} {request.url} (authenticated: {has_token})" + ) + + def send(self, request: PreparedRequest, *args, **kwargs) -> Response: + """Catch any RequestException to append request id to the error message for debugging.""" + try: + return super().send(request, *args, **kwargs) + except requests.RequestException as e: + request_id = request.headers.get(X_AMZN_TRACE_ID) + if request_id is not None: + # Taken from https://stackoverflow.com/a/58270258 + e.args = (*e.args, f"(Request ID: {request_id})") + raise + + +class OfflineAdapter(HTTPAdapter): + def send(self, request: PreparedRequest, *args, **kwargs) -> Response: + raise OfflineModeIsEnabled( + f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable." 
+ ) + + +def _default_backend_factory() -> requests.Session: + session = requests.Session() + if constants.HF_HUB_OFFLINE: + session.mount("http://", OfflineAdapter()) + session.mount("https://", OfflineAdapter()) + else: + session.mount("http://", UniqueRequestIdAdapter()) + session.mount("https://", UniqueRequestIdAdapter()) + return session + + +BACKEND_FACTORY_T = Callable[[], requests.Session] +_GLOBAL_BACKEND_FACTORY: BACKEND_FACTORY_T = _default_backend_factory + + +def configure_http_backend(backend_factory: BACKEND_FACTORY_T = _default_backend_factory) -> None: + """ + Configure the HTTP backend by providing a `backend_factory`. Any HTTP calls made by `huggingface_hub` will use a + Session object instantiated by this factory. This can be useful if you are running your scripts in a specific + environment requiring custom configuration (e.g. custom proxy or certifications). + + Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe, + `huggingface_hub` creates 1 Session instance per thread. They are all instantiated using the same `backend_factory` + set in [`configure_http_backend`]. A LRU cache is used to cache the created sessions (and connections) between + calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned. + + See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`. 
+ + Example: + ```py + import requests + from huggingface_hub import configure_http_backend, get_session + + # Create a factory function that returns a Session with configured proxies + def backend_factory() -> requests.Session: + session = requests.Session() + session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"} + return session + + # Set it as the default session factory + configure_http_backend(backend_factory=backend_factory) + + # In practice, this is mostly done internally in `huggingface_hub` + session = get_session() + ``` + """ + global _GLOBAL_BACKEND_FACTORY + _GLOBAL_BACKEND_FACTORY = backend_factory + reset_sessions() + + +def get_session() -> requests.Session: + """ + Get a `requests.Session` object, using the session factory from the user. + + Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe, + `huggingface_hub` creates 1 Session instance per thread. They are all instantiated using the same `backend_factory` + set in [`configure_http_backend`]. A LRU cache is used to cache the created sessions (and connections) between + calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned. + + See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`. 
+ + Example: + ```py + import requests + from huggingface_hub import configure_http_backend, get_session + + # Create a factory function that returns a Session with configured proxies + def backend_factory() -> requests.Session: + session = requests.Session() + session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"} + return session + + # Set it as the default session factory + configure_http_backend(backend_factory=backend_factory) + + # In practice, this is mostly done internally in `huggingface_hub` + session = get_session() + ``` + """ + return _get_session_from_cache(process_id=os.getpid(), thread_id=threading.get_ident()) + + +def reset_sessions() -> None: + """Reset the cache of sessions. + + Mostly used internally when sessions are reconfigured or an SSLError is raised. + See [`configure_http_backend`] for more details. + """ + _get_session_from_cache.cache_clear() + + +@lru_cache +def _get_session_from_cache(process_id: int, thread_id: int) -> requests.Session: + """ + Create a new session per thread using global factory. Using LRU cache (maxsize 128) to avoid memory leaks when + using thousands of threads. Cache is cleared when `configure_http_backend` is called. + """ + return _GLOBAL_BACKEND_FACTORY() + + +def http_backoff( + method: HTTP_METHOD_T, + url: str, + *, + max_retries: int = 5, + base_wait_time: float = 1, + max_wait_time: float = 8, + retry_on_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = ( + requests.Timeout, + requests.ConnectionError, + ), + retry_on_status_codes: Union[int, Tuple[int, ...]] = HTTPStatus.SERVICE_UNAVAILABLE, + **kwargs, +) -> Response: + """Wrapper around requests to retry calls on an endpoint, with exponential backoff. + + Endpoint call is retried on exceptions (ex: connection timeout, proxy error,...) + and/or on specific status codes (ex: service unavailable). 
If the call fails more + than `max_retries` times, the exception is thrown or `raise_for_status` is called on the + response object. + + Re-implement mechanisms from the `backoff` library to avoid adding an external + dependency to `huggingface_hub`. See https://github.com/litl/backoff. + + Args: + method (`Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]`): + HTTP method to perform. + url (`str`): + The URL of the resource to fetch. + max_retries (`int`, *optional*, defaults to `5`): + Maximum number of retries. Set to `0` to disable retries. + base_wait_time (`float`, *optional*, defaults to `1`): + Duration (in seconds) to wait before retrying the first time. + Wait time between retries then grows exponentially, capped by + `max_wait_time`. + max_wait_time (`float`, *optional*, defaults to `8`): + Maximum duration (in seconds) to wait before retrying. + retry_on_exceptions (`Type[Exception]` or `Tuple[Type[Exception]]`, *optional*): + Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types. + By default, retry on `requests.Timeout` and `requests.ConnectionError`. + retry_on_status_codes (`int` or `Tuple[int]`, *optional*, defaults to `503`): + Define on which status codes the request must be retried. By default, only + HTTP 503 Service Unavailable is retried. + **kwargs (`dict`, *optional*): + kwargs to pass to `requests.request`. + + Example: + ``` + >>> from huggingface_hub.utils import http_backoff + + # Same usage as "requests.request". + >>> response = http_backoff("GET", "https://www.google.com") + >>> response.raise_for_status() + + # If you expect a Gateway Timeout from time to time + >>> response = http_backoff("PUT", upload_url, data=data, retry_on_status_codes=504) + >>> response.raise_for_status() + ``` + ++ + When using `requests` it is possible to stream data by passing an iterator to the + `data` argument. On http backoff this is a problem as the iterator is not reset + after a failed call. 
This issue is mitigated for file objects or any IO streams + by saving the initial position of the cursor (with `data.tell()`) and resetting the + cursor between each call (with `data.seek()`). For arbitrary iterators, http backoff + will fail. If this is a hard constraint for you, please let us know by opening an + issue on [Github](https://github.com/huggingface/huggingface_hub). + + + """ + if isinstance(retry_on_exceptions, type): # Tuple from single exception type + retry_on_exceptions = (retry_on_exceptions,) + + if isinstance(retry_on_status_codes, int): # Tuple from single status code + retry_on_status_codes = (retry_on_status_codes,) + + nb_tries = 0 + sleep_time = base_wait_time + + # If `data` is used and is a file object (or any IO), it will be consumed on the + # first HTTP request. We need to save the initial position so that the full content + # of the file is re-sent on http backoff. See warning tip in docstring. + io_obj_initial_pos = None + if "data" in kwargs and isinstance(kwargs["data"], (io.IOBase, SliceFileObj)): + io_obj_initial_pos = kwargs["data"].tell() + + session = get_session() + while True: + nb_tries += 1 + try: + # If `data` is used and is a file object (or any IO), set back cursor to + # initial position. + if io_obj_initial_pos is not None: + kwargs["data"].seek(io_obj_initial_pos) + + # Perform request and return if status_code is not in the retry list. + response = session.request(method=method, url=url, **kwargs) + if response.status_code not in retry_on_status_codes: + return response + + # Wrong status code returned (HTTP 503 for instance) + logger.warning(f"HTTP Error {response.status_code} thrown while requesting {method} {url}") + if nb_tries > max_retries: + response.raise_for_status() # Will raise uncaught exception + # We return response to avoid infinite loop in the corner case where the + # user ask for retry on a status code that doesn't raise_for_status. 
+ return response + + except retry_on_exceptions as err: + logger.warning(f"'{err}' thrown while requesting {method} {url}") + + if isinstance(err, requests.ConnectionError): + reset_sessions() # In case of SSLError it's best to reset the shared requests.Session objects + + if nb_tries > max_retries: + raise err + + # Sleep for X seconds + logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].") + time.sleep(sleep_time) + + # Update sleep time for next retry + sleep_time = min(max_wait_time, sleep_time * 2) # Exponential backoff + + +def fix_hf_endpoint_in_url(url: str, endpoint: Optional[str]) -> str: + """Replace the default endpoint in a URL by a custom one. + + This is useful when using a proxy and the Hugging Face Hub returns a URL with the default endpoint. + """ + endpoint = endpoint or constants.ENDPOINT + # check if a proxy has been set => if yes, update the returned URL to use the proxy + if endpoint not in (None, constants._HF_DEFAULT_ENDPOINT, constants._HF_DEFAULT_STAGING_ENDPOINT): + url = url.replace(constants._HF_DEFAULT_ENDPOINT, endpoint) + url = url.replace(constants._HF_DEFAULT_STAGING_ENDPOINT, endpoint) + return url + + +def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) -> None: + """ + Internal version of `response.raise_for_status()` that will refine a + potential HTTPError. Raised exception will be an instance of `HfHubHTTPError`. + + This helper is meant to be the unique method to raise_for_status when making a call + to the Hugging Face Hub. + + + Example: + ```py + import requests + from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError + + response = get_session().post(...) 
+ try: + hf_raise_for_status(response) + except HfHubHTTPError as e: + print(str(e)) # formatted message + e.request_id, e.server_message # details returned by server + + # Complete the error message with additional information once it's raised + e.append_to_message("\n`create_commit` expects the repository to exist.") + raise + ``` + + Args: + response (`Response`): + Response from the server. + endpoint_name (`str`, *optional*): + Name of the endpoint that has been called. If provided, the error message + will be more complete. + ++ + Raises when the request has failed: + + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it + doesn't exist, because `repo_type` is not set correctly, or because the repo + is `private` and you do not have access. + - [`~utils.GatedRepoError`] + If the repository exists but is gated and the user is not on the authorized + list. + - [`~utils.RevisionNotFoundError`] + If the repository exists but the revision couldn't be find. + - [`~utils.EntryNotFoundError`] + If the repository exists but the entry (e.g. the requested file) couldn't be + find. + - [`~utils.BadRequestError`] + If request failed with a HTTP 400 BadRequest error. + - [`~utils.HfHubHTTPError`] + If request failed for a reason not listed above. + + + """ + try: + response.raise_for_status() + except HTTPError as e: + error_code = response.headers.get("X-Error-Code") + error_message = response.headers.get("X-Error-Message") + + if error_code == "RevisionNotFound": + message = f"{response.status_code} Client Error." + "\n\n" + f"Revision Not Found for url: {response.url}." + raise _format(RevisionNotFoundError, message, response) from e + + elif error_code == "EntryNotFound": + message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}." 
+ raise _format(EntryNotFoundError, message, response) from e + + elif error_code == "GatedRepo": + message = ( + f"{response.status_code} Client Error." + "\n\n" + f"Cannot access gated repo for url {response.url}." + ) + raise _format(GatedRepoError, message, response) from e + + elif error_message == "Access to this resource is disabled.": + message = ( + f"{response.status_code} Client Error." + + "\n\n" + + f"Cannot access repository for url {response.url}." + + "\n" + + "Access to this resource is disabled." + ) + raise _format(DisabledRepoError, message, response) from e + + elif error_code == "RepoNotFound" or ( + response.status_code == 401 + and response.request is not None + and response.request.url is not None + and REPO_API_REGEX.search(response.request.url) is not None + ): + # 401 is misleading as it is returned for: + # - private and gated repos if user is not authenticated + # - missing repos + # => for now, we process them as `RepoNotFound` anyway. + # See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9 + message = ( + f"{response.status_code} Client Error." + + "\n\n" + + f"Repository Not Found for url: {response.url}." + + "\nPlease make sure you specified the correct `repo_id` and" + " `repo_type`.\nIf you are trying to access a private or gated repo," + " make sure you are authenticated." + ) + raise _format(RepositoryNotFoundError, message, response) from e + + elif response.status_code == 400: + message = ( + f"\n\nBad request for {endpoint_name} endpoint:" if endpoint_name is not None else "\n\nBad request:" + ) + raise _format(BadRequestError, message, response) from e + + elif response.status_code == 403: + message = ( + f"\n\n{response.status_code} Forbidden: {error_message}." + + f"\nCannot access content at: {response.url}." + + "\nMake sure your token has the correct permissions." 
+ ) + raise _format(HfHubHTTPError, message, response) from e + + elif response.status_code == 416: + range_header = response.request.headers.get("Range") + message = f"{e}. Requested range: {range_header}. Content-Range: {response.headers.get('Content-Range')}." + raise _format(HfHubHTTPError, message, response) from e + + # Convert `HTTPError` into a `HfHubHTTPError` to display request information + # as well (request id and/or server error message) + raise _format(HfHubHTTPError, str(e), response) from e + + +def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Response) -> HfHubHTTPError: + server_errors = [] + + # Retrieve server error from header + from_headers = response.headers.get("X-Error-Message") + if from_headers is not None: + server_errors.append(from_headers) + + # Retrieve server error from body + try: + # Case errors are returned in a JSON format + data = response.json() + + error = data.get("error") + if error is not None: + if isinstance(error, list): + # Case {'error': ['my error 1', 'my error 2']} + server_errors.extend(error) + else: + # Case {'error': 'my error'} + server_errors.append(error) + + errors = data.get("errors") + if errors is not None: + # Case {'errors': [{'message': 'my error 1'}, {'message': 'my error 2'}]} + for error in errors: + if "message" in error: + server_errors.append(error["message"]) + + except JSONDecodeError: + # Case error is directly returned as text + if response.text: + server_errors.append(response.text) + + # Strip all server messages + server_errors = [line.strip() for line in server_errors if line.strip()] + + # Deduplicate server messages (keep order) + # taken from https://stackoverflow.com/a/17016257 + server_errors = list(dict.fromkeys(server_errors)) + + # Format server error + server_message = "\n".join(server_errors) + + # Add server error to custom message + final_error_message = custom_message + if server_message and server_message.lower() not in custom_message.lower(): + 
if "\n\n" in custom_message: + final_error_message += "\n" + server_message + else: + final_error_message += "\n\n" + server_message + + # Add Request ID + request_id = str(response.headers.get(X_REQUEST_ID, "")) + if len(request_id) > 0 and request_id.lower() not in final_error_message.lower(): + request_id_message = f" (Request ID: {request_id})" + if "\n" in final_error_message: + newline_index = final_error_message.index("\n") + final_error_message = ( + final_error_message[:newline_index] + request_id_message + final_error_message[newline_index:] + ) + else: + final_error_message += request_id_message + + # Return + return error_type(final_error_message.strip(), response=response, server_message=server_message or None) diff --git a/huggingface_hub/utils/_lfs.py b/huggingface_hub/utils/_lfs.py new file mode 100644 index 0000000000000000000000000000000000000000..307f371ffa79a8ae726ee03458c52e230a792898 --- /dev/null +++ b/huggingface_hub/utils/_lfs.py @@ -0,0 +1,110 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Git LFS related utilities""" + +import io +import os +from contextlib import AbstractContextManager +from typing import BinaryIO + + +class SliceFileObj(AbstractContextManager): + """ + Utility context manager to read a *slice* of a seekable file-like object as a seekable, file-like object. 
+ + This is NOT thread safe. + + Inspired by stackoverflow.com/a/29838711/593036 + + Credits to @julien-c + + Args: + fileobj (`BinaryIO`): + A file-like object to slice. MUST implement `tell()` and `seek()` (and `read()` of course). + `fileobj` will be reset to its original position when exiting the context manager. + seek_from (`int`): + The start of the slice (offset from position 0 in bytes). + read_limit (`int`): + The maximum number of bytes to read from the slice. + + Attributes: + _previous_position (`int`): + Position of `fileobj` when the context manager was entered; restored on exit. + + Examples: + + Reading 200 bytes with an offset of 128 bytes from a file (i.e. bytes 128 to 327): + ```python + >>> with open("path/to/file", "rb") as file: + ... with SliceFileObj(file, seek_from=128, read_limit=200) as fslice: + ... fslice.read(...) + ``` + + Reading a file in chunks of 512 bytes + ```python + >>> import os + >>> chunk_size = 512 + >>> file_size = os.path.getsize("path/to/file") + >>> with open("path/to/file", "rb") as file: + ... for offset in range(0, file_size, chunk_size): + ... with SliceFileObj(file, seek_from=offset, read_limit=chunk_size) as fslice: + ... chunk = fslice.read(...) 
+ + ``` + """ + + def __init__(self, fileobj: BinaryIO, seek_from: int, read_limit: int): + self.fileobj = fileobj + self.seek_from = seek_from + self.read_limit = read_limit + + def __enter__(self): + self._previous_position = self.fileobj.tell() + end_of_stream = self.fileobj.seek(0, os.SEEK_END) + self._len = min(self.read_limit, end_of_stream - self.seek_from) + # ^^ The actual number of bytes that can be read from the slice + self.fileobj.seek(self.seek_from, io.SEEK_SET) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.fileobj.seek(self._previous_position, io.SEEK_SET) + + def read(self, n: int = -1): + pos = self.tell() + if pos >= self._len: + return b"" + remaining_amount = self._len - pos + data = self.fileobj.read(remaining_amount if n < 0 else min(n, remaining_amount)) + return data + + def tell(self) -> int: + return self.fileobj.tell() - self.seek_from + + def seek(self, offset: int, whence: int = os.SEEK_SET) -> int: + start = self.seek_from + end = start + self._len + if whence in (os.SEEK_SET, os.SEEK_END): + offset = start + offset if whence == os.SEEK_SET else end + offset + offset = max(start, min(offset, end)) + whence = os.SEEK_SET + elif whence == os.SEEK_CUR: + cur_pos = self.fileobj.tell() + offset = max(start - cur_pos, min(offset, end - cur_pos)) + else: + raise ValueError(f"whence value {whence} is not supported") + return self.fileobj.seek(offset, whence) - self.seek_from + + def __iter__(self): + yield self.read(n=4 * 1024 * 1024) diff --git a/huggingface_hub/utils/_pagination.py b/huggingface_hub/utils/_pagination.py new file mode 100644 index 0000000000000000000000000000000000000000..f7ab4fe7cba9bd13f01d9c81854a00fd30b7f0d9 --- /dev/null +++ b/huggingface_hub/utils/_pagination.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to handle pagination on Huggingface Hub.""" + +from typing import Dict, Iterable, Optional + +import requests + +from . import get_session, hf_raise_for_status, logging + + +logger = logging.get_logger(__name__) + + +def paginate(path: str, params: Dict, headers: Dict) -> Iterable: + """Fetch a list of models/datasets/spaces and paginate through results. + + This is using the same "Link" header format as GitHub. + See: + - https://requests.readthedocs.io/en/latest/api/#requests.Response.links + - https://docs.github.com/en/rest/guides/traversing-with-pagination#link-header + """ + session = get_session() + r = session.get(path, params=params, headers=headers) + hf_raise_for_status(r) + yield from r.json() + + # Follow pages + # Next link already contains query params + next_page = _get_next_page(r) + while next_page is not None: + logger.debug(f"Pagination detected. Requesting next page: {next_page}") + r = session.get(next_page, headers=headers) + hf_raise_for_status(r) + yield from r.json() + next_page = _get_next_page(r) + + +def _get_next_page(response: requests.Response) -> Optional[str]: + return response.links.get("next", {}).get("url") diff --git a/huggingface_hub/utils/_paths.py b/huggingface_hub/utils/_paths.py new file mode 100644 index 0000000000000000000000000000000000000000..4f2c0ebce070bbde4900e919a3aca7cfc331e747 --- /dev/null +++ b/huggingface_hub/utils/_paths.py @@ -0,0 +1,141 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to handle paths in Huggingface Hub.""" + +from fnmatch import fnmatch +from pathlib import Path +from typing import Callable, Generator, Iterable, List, Optional, TypeVar, Union + + +T = TypeVar("T") + +# Always ignore `.git` and `.cache/huggingface` folders in commits +DEFAULT_IGNORE_PATTERNS = [ + ".git", + ".git/*", + "*/.git", + "**/.git/**", + ".cache/huggingface", + ".cache/huggingface/*", + "*/.cache/huggingface", + "**/.cache/huggingface/**", +] +# Forbidden to commit these folders +FORBIDDEN_FOLDERS = [".git", ".cache"] + + +def filter_repo_objects( + items: Iterable[T], + *, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + key: Optional[Callable[[T], str]] = None, +) -> Generator[T, None, None]: + """Filter repo objects based on an allowlist and a denylist. + + Input must be a list of paths (`str` or `Path`) or a list of arbitrary objects. + In the later case, `key` must be provided and specifies a function of one argument + that is used to extract a path from each element in iterable. + + Patterns are Unix shell-style wildcards which are NOT regular expressions. See + https://docs.python.org/3/library/fnmatch.html for more details. + + Args: + items (`Iterable`): + List of items to filter. + allow_patterns (`str` or `List[str]`, *optional*): + Patterns constituting the allowlist. 
If provided, item paths must match at + least one pattern from the allowlist. + ignore_patterns (`str` or `List[str]`, *optional*): + Patterns constituting the denylist. If provided, item paths must not match + any patterns from the denylist. + key (`Callable[[T], str]`, *optional*): + Single-argument function to extract a path from each item. If not provided, + the `items` must already be `str` or `Path`. + + Returns: + Filtered list of objects, as a generator. + + Raises: + :class:`ValueError`: + If `key` is not provided and items are not `str` or `Path`. + + Example usage with paths: + ```python + >>> # Filter only PDFs that are not hidden. + >>> list(filter_repo_objects( + ... ["aaa.PDF", "bbb.jpg", ".ccc.pdf", ".ddd.png"], + ... allow_patterns=["*.pdf"], + ... ignore_patterns=[".*"], + ... )) + ["aaa.pdf"] + ``` + + Example usage with objects: + ```python + >>> list(filter_repo_objects( + ... [ + ... CommitOperationAdd(path_or_fileobj="/tmp/aaa.pdf", path_in_repo="aaa.pdf") + ... CommitOperationAdd(path_or_fileobj="/tmp/bbb.jpg", path_in_repo="bbb.jpg") + ... CommitOperationAdd(path_or_fileobj="/tmp/.ccc.pdf", path_in_repo=".ccc.pdf") + ... CommitOperationAdd(path_or_fileobj="/tmp/.ddd.png", path_in_repo=".ddd.png") + ... ], + ... allow_patterns=["*.pdf"], + ... ignore_patterns=[".*"], + ... key=lambda x: x.repo_in_path + ... 
)) + [CommitOperationAdd(path_or_fileobj="/tmp/aaa.pdf", path_in_repo="aaa.pdf")] + ``` + """ + if isinstance(allow_patterns, str): + allow_patterns = [allow_patterns] + + if isinstance(ignore_patterns, str): + ignore_patterns = [ignore_patterns] + + if allow_patterns is not None: + allow_patterns = [_add_wildcard_to_directories(p) for p in allow_patterns] + if ignore_patterns is not None: + ignore_patterns = [_add_wildcard_to_directories(p) for p in ignore_patterns] + + if key is None: + + def _identity(item: T) -> str: + if isinstance(item, str): + return item + if isinstance(item, Path): + return str(item) + raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.") + + key = _identity # Items must be `str` or `Path`, otherwise raise ValueError + + for item in items: + path = key(item) + + # Skip if there's an allowlist and path doesn't match any + if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns): + continue + + # Skip if there's a denylist and path matches any + if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns): + continue + + yield item + + +def _add_wildcard_to_directories(pattern: str) -> str: + if pattern[-1] == "/": + return pattern + "*" + return pattern diff --git a/huggingface_hub/utils/_runtime.py b/huggingface_hub/utils/_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..bdd19e6716e3735de86dd923feee80a8e8c4c80a --- /dev/null +++ b/huggingface_hub/utils/_runtime.py @@ -0,0 +1,388 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Check presence of installed packages at runtime.""" + +import importlib.metadata +import os +import platform +import sys +import warnings +from typing import Any, Dict + +from .. import __version__, constants + + +_PY_VERSION: str = sys.version.split()[0].rstrip("+") + +_package_versions = {} + +_CANDIDATES = { + "aiohttp": {"aiohttp"}, + "fastai": {"fastai"}, + "fastapi": {"fastapi"}, + "fastcore": {"fastcore"}, + "gradio": {"gradio"}, + "graphviz": {"graphviz"}, + "hf_transfer": {"hf_transfer"}, + "jinja": {"Jinja2"}, + "keras": {"keras"}, + "minijinja": {"minijinja"}, + "numpy": {"numpy"}, + "pillow": {"Pillow"}, + "pydantic": {"pydantic"}, + "pydot": {"pydot"}, + "safetensors": {"safetensors"}, + "tensorboard": {"tensorboardX"}, + "tensorflow": ( + "tensorflow", + "tensorflow-cpu", + "tensorflow-gpu", + "tf-nightly", + "tf-nightly-cpu", + "tf-nightly-gpu", + "intel-tensorflow", + "intel-tensorflow-avx512", + "tensorflow-rocm", + "tensorflow-macos", + ), + "torch": {"torch"}, +} + +# Check once at runtime +for candidate_name, package_names in _CANDIDATES.items(): + _package_versions[candidate_name] = "N/A" + for name in package_names: + try: + _package_versions[candidate_name] = importlib.metadata.version(name) + break + except importlib.metadata.PackageNotFoundError: + pass + + +def _get_version(package_name: str) -> str: + return _package_versions.get(package_name, "N/A") + + +def is_package_available(package_name: str) -> bool: + return _get_version(package_name) != "N/A" + + +# Python +def get_python_version() -> str: + return _PY_VERSION + + +# 
Huggingface Hub +def get_hf_hub_version() -> str: + return __version__ + + +# aiohttp +def is_aiohttp_available() -> bool: + return is_package_available("aiohttp") + + +def get_aiohttp_version() -> str: + return _get_version("aiohttp") + + +# FastAI +def is_fastai_available() -> bool: + return is_package_available("fastai") + + +def get_fastai_version() -> str: + return _get_version("fastai") + + +# FastAPI +def is_fastapi_available() -> bool: + return is_package_available("fastapi") + + +def get_fastapi_version() -> str: + return _get_version("fastapi") + + +# Fastcore +def is_fastcore_available() -> bool: + return is_package_available("fastcore") + + +def get_fastcore_version() -> str: + return _get_version("fastcore") + + +# FastAI +def is_gradio_available() -> bool: + return is_package_available("gradio") + + +def get_gradio_version() -> str: + return _get_version("gradio") + + +# Graphviz +def is_graphviz_available() -> bool: + return is_package_available("graphviz") + + +def get_graphviz_version() -> str: + return _get_version("graphviz") + + +# hf_transfer +def is_hf_transfer_available() -> bool: + return is_package_available("hf_transfer") + + +def get_hf_transfer_version() -> str: + return _get_version("hf_transfer") + + +# keras +def is_keras_available() -> bool: + return is_package_available("keras") + + +def get_keras_version() -> str: + return _get_version("keras") + + +# Minijinja +def is_minijinja_available() -> bool: + return is_package_available("minijinja") + + +def get_minijinja_version() -> str: + return _get_version("minijinja") + + +# Numpy +def is_numpy_available() -> bool: + return is_package_available("numpy") + + +def get_numpy_version() -> str: + return _get_version("numpy") + + +# Jinja +def is_jinja_available() -> bool: + return is_package_available("jinja") + + +def get_jinja_version() -> str: + return _get_version("jinja") + + +# Pillow +def is_pillow_available() -> bool: + return is_package_available("pillow") + + +def 
get_pillow_version() -> str: + return _get_version("pillow") + + +# Pydantic +def is_pydantic_available() -> bool: + if not is_package_available("pydantic"): + return False + # For Pydantic, we add an extra check to test whether it is correctly installed or not. If both pydantic 2.x and + # typing_extensions<=4.5.0 are installed, then pydantic will fail at import time. This should not happen when + # it is installed with `pip install huggingface_hub[inference]` but it can happen when it is installed manually + # by the user in an environment that we don't control. + # + # Usually we won't need to do this kind of check on optional dependencies. However, pydantic is a special case + # as it is automatically imported when doing `from huggingface_hub import ...` even if the user doesn't use it. + # + # See https://github.com/huggingface/huggingface_hub/pull/1829 for more details. + try: + from pydantic import validator # noqa: F401 + except ImportError: + # Example: "ImportError: cannot import name 'TypeAliasType' from 'typing_extensions'" + warnings.warn( + "Pydantic is installed but cannot be imported. Please check your installation. `huggingface_hub` will " + "default to not using Pydantic. 
Error message: '{e}'" + ) + return False + return True + + +def get_pydantic_version() -> str: + return _get_version("pydantic") + + +# Pydot +def is_pydot_available() -> bool: + return is_package_available("pydot") + + +def get_pydot_version() -> str: + return _get_version("pydot") + + +# Tensorboard +def is_tensorboard_available() -> bool: + return is_package_available("tensorboard") + + +def get_tensorboard_version() -> str: + return _get_version("tensorboard") + + +# Tensorflow +def is_tf_available() -> bool: + return is_package_available("tensorflow") + + +def get_tf_version() -> str: + return _get_version("tensorflow") + + +# Torch +def is_torch_available() -> bool: + return is_package_available("torch") + + +def get_torch_version() -> str: + return _get_version("torch") + + +# Safetensors +def is_safetensors_available() -> bool: + return is_package_available("safetensors") + + +# Shell-related helpers +try: + # Set to `True` if script is running in a Google Colab notebook. + # If running in Google Colab, git credential store is set globally which makes the + # warning disappear. See https://github.com/huggingface/huggingface_hub/issues/1043 + # + # Taken from https://stackoverflow.com/a/63519730. + _is_google_colab = "google.colab" in str(get_ipython()) # type: ignore # noqa: F821 +except NameError: + _is_google_colab = False + + +def is_notebook() -> bool: + """Return `True` if code is executed in a notebook (Jupyter, Colab, QTconsole). + + Taken from https://stackoverflow.com/a/39662359. + Adapted to make it work with Google colab as well. + """ + try: + shell_class = get_ipython().__class__ # type: ignore # noqa: F821 + for parent_class in shell_class.__mro__: # e.g. 
"is subclass of" + if parent_class.__name__ == "ZMQInteractiveShell": + return True # Jupyter notebook, Google colab or qtconsole + return False + except NameError: + return False # Probably standard Python interpreter + + +def is_google_colab() -> bool: + """Return `True` if code is executed in a Google colab. + + Taken from https://stackoverflow.com/a/63519730. + """ + return _is_google_colab + + +def is_colab_enterprise() -> bool: + """Return `True` if code is executed in a Google Colab Enterprise environment.""" + return os.environ.get("VERTEX_PRODUCT") == "COLAB_ENTERPRISE" + + +def dump_environment_info() -> Dict[str, Any]: + """Dump information about the machine to help debugging issues. + + Similar helper exist in: + - `datasets` (https://github.com/huggingface/datasets/blob/main/src/datasets/commands/env.py) + - `diffusers` (https://github.com/huggingface/diffusers/blob/main/src/diffusers/commands/env.py) + - `transformers` (https://github.com/huggingface/transformers/blob/main/src/transformers/commands/env.py) + """ + from huggingface_hub import get_token, whoami + from huggingface_hub.utils import list_credential_helpers + + token = get_token() + + # Generic machine info + info: Dict[str, Any] = { + "huggingface_hub version": get_hf_hub_version(), + "Platform": platform.platform(), + "Python version": get_python_version(), + } + + # Interpreter info + try: + shell_class = get_ipython().__class__ # type: ignore # noqa: F821 + info["Running in iPython ?"] = "Yes" + info["iPython shell"] = shell_class.__name__ + except NameError: + info["Running in iPython ?"] = "No" + info["Running in notebook ?"] = "Yes" if is_notebook() else "No" + info["Running in Google Colab ?"] = "Yes" if is_google_colab() else "No" + info["Running in Google Colab Enterprise ?"] = "Yes" if is_colab_enterprise() else "No" + # Login info + info["Token path ?"] = constants.HF_TOKEN_PATH + info["Has saved token ?"] = token is not None + if token is not None: + try: + info["Who am I ?"] = 
whoami()["name"] + except Exception: + pass + + try: + info["Configured git credential helpers"] = ", ".join(list_credential_helpers()) + except Exception: + pass + + # Installed dependencies + info["FastAI"] = get_fastai_version() + info["Tensorflow"] = get_tf_version() + info["Torch"] = get_torch_version() + info["Jinja2"] = get_jinja_version() + info["Graphviz"] = get_graphviz_version() + info["keras"] = get_keras_version() + info["Pydot"] = get_pydot_version() + info["Pillow"] = get_pillow_version() + info["hf_transfer"] = get_hf_transfer_version() + info["gradio"] = get_gradio_version() + info["tensorboard"] = get_tensorboard_version() + info["numpy"] = get_numpy_version() + info["pydantic"] = get_pydantic_version() + info["aiohttp"] = get_aiohttp_version() + + # Environment variables + info["ENDPOINT"] = constants.ENDPOINT + info["HF_HUB_CACHE"] = constants.HF_HUB_CACHE + info["HF_ASSETS_CACHE"] = constants.HF_ASSETS_CACHE + info["HF_TOKEN_PATH"] = constants.HF_TOKEN_PATH + info["HF_HUB_OFFLINE"] = constants.HF_HUB_OFFLINE + info["HF_HUB_DISABLE_TELEMETRY"] = constants.HF_HUB_DISABLE_TELEMETRY + info["HF_HUB_DISABLE_PROGRESS_BARS"] = constants.HF_HUB_DISABLE_PROGRESS_BARS + info["HF_HUB_DISABLE_SYMLINKS_WARNING"] = constants.HF_HUB_DISABLE_SYMLINKS_WARNING + info["HF_HUB_DISABLE_EXPERIMENTAL_WARNING"] = constants.HF_HUB_DISABLE_EXPERIMENTAL_WARNING + info["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = constants.HF_HUB_DISABLE_IMPLICIT_TOKEN + info["HF_HUB_ENABLE_HF_TRANSFER"] = constants.HF_HUB_ENABLE_HF_TRANSFER + info["HF_HUB_ETAG_TIMEOUT"] = constants.HF_HUB_ETAG_TIMEOUT + info["HF_HUB_DOWNLOAD_TIMEOUT"] = constants.HF_HUB_DOWNLOAD_TIMEOUT + + print("\nCopy-and-paste the text below in your GitHub issue.\n") + print("\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + "\n") + return info diff --git a/huggingface_hub/utils/_safetensors.py b/huggingface_hub/utils/_safetensors.py new file mode 100644 index 
0000000000000000000000000000000000000000..38546c6d34db786c62861e1706f747a21b7012bf --- /dev/null +++ b/huggingface_hub/utils/_safetensors.py @@ -0,0 +1,111 @@ +import functools +import operator +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, List, Literal, Optional, Tuple + + +FILENAME_T = str +TENSOR_NAME_T = str +DTYPE_T = Literal["F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL"] + + +@dataclass +class TensorInfo: + """Information about a tensor. + + For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. + + Attributes: + dtype (`str`): + The data type of the tensor ("F64", "F32", "F16", "BF16", "I64", "I32", "I16", "I8", "U8", "BOOL"). + shape (`List[int]`): + The shape of the tensor. + data_offsets (`Tuple[int, int]`): + The offsets of the data in the file as a tuple `[BEGIN, END]`. + parameter_count (`int`): + The number of parameters in the tensor. + """ + + dtype: DTYPE_T + shape: List[int] + data_offsets: Tuple[int, int] + parameter_count: int = field(init=False) + + def __post_init__(self) -> None: + # Taken from https://stackoverflow.com/a/13840436 + try: + self.parameter_count = functools.reduce(operator.mul, self.shape) + except TypeError: + self.parameter_count = 1 # scalar value has no shape + + +@dataclass +class SafetensorsFileMetadata: + """Metadata for a Safetensors file hosted on the Hub. + + This class is returned by [`parse_safetensors_file_metadata`]. + + For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. + + Attributes: + metadata (`Dict`): + The metadata contained in the file. + tensors (`Dict[str, TensorInfo]`): + A map of all tensors. Keys are tensor names and values are information about the corresponding tensor, as a + [`TensorInfo`] object. + parameter_count (`Dict[str, int]`): + A map of the number of parameters per data type. 
Keys are data types and values are the number of parameters + of that data type. + """ + + metadata: Dict[str, str] + tensors: Dict[TENSOR_NAME_T, TensorInfo] + parameter_count: Dict[DTYPE_T, int] = field(init=False) + + def __post_init__(self) -> None: + parameter_count: Dict[DTYPE_T, int] = defaultdict(int) + for tensor in self.tensors.values(): + parameter_count[tensor.dtype] += tensor.parameter_count + self.parameter_count = dict(parameter_count) + + +@dataclass +class SafetensorsRepoMetadata: + """Metadata for a Safetensors repo. + + A repo is considered to be a Safetensors repo if it contains either a 'model.safetensors' weight file (non-shared + model) or a 'model.safetensors.index.json' index file (sharded model) at its root. + + This class is returned by [`get_safetensors_metadata`]. + + For more details regarding the safetensors format, check out https://huggingface.co/docs/safetensors/index#format. + + Attributes: + metadata (`Dict`, *optional*): + The metadata contained in the 'model.safetensors.index.json' file, if it exists. Only populated for sharded + models. + sharded (`bool`): + Whether the repo contains a sharded model or not. + weight_map (`Dict[str, str]`): + A map of all weights. Keys are tensor names and values are filenames of the files containing the tensors. + files_metadata (`Dict[str, SafetensorsFileMetadata]`): + A map of all files metadata. Keys are filenames and values are the metadata of the corresponding file, as + a [`SafetensorsFileMetadata`] object. + parameter_count (`Dict[str, int]`): + A map of the number of parameters per data type. Keys are data types and values are the number of parameters + of that data type. 
+ """ + + metadata: Optional[Dict] + sharded: bool + weight_map: Dict[TENSOR_NAME_T, FILENAME_T] # tensor name -> filename + files_metadata: Dict[FILENAME_T, SafetensorsFileMetadata] # filename -> metadata + parameter_count: Dict[DTYPE_T, int] = field(init=False) + + def __post_init__(self) -> None: + parameter_count: Dict[DTYPE_T, int] = defaultdict(int) + for file_metadata in self.files_metadata.values(): + for dtype, nb_parameters_ in file_metadata.parameter_count.items(): + parameter_count[dtype] += nb_parameters_ + self.parameter_count = dict(parameter_count) diff --git a/huggingface_hub/utils/_subprocess.py b/huggingface_hub/utils/_subprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..dd51294a909e16e0c74af131d3a7d20f9c32fb1b --- /dev/null +++ b/huggingface_hub/utils/_subprocess.py @@ -0,0 +1,142 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +"""Contains utilities to easily handle subprocesses in `huggingface_hub`.""" + +import os +import subprocess +import sys +from contextlib import contextmanager +from io import StringIO +from pathlib import Path +from typing import IO, Generator, List, Optional, Tuple, Union + +from .logging import get_logger + + +logger = get_logger(__name__) + + +@contextmanager +def capture_output() -> Generator[StringIO, None, None]: + """Capture output that is printed to terminal. 
+ + Taken from https://stackoverflow.com/a/34738440 + + Example: + ```py + >>> with capture_output() as output: + ... print("hello world") + >>> assert output.getvalue() == "hello world\n" + ``` + """ + output = StringIO() + previous_output = sys.stdout + sys.stdout = output + yield output + sys.stdout = previous_output + + +def run_subprocess( + command: Union[str, List[str]], + folder: Optional[Union[str, Path]] = None, + check=True, + **kwargs, +) -> subprocess.CompletedProcess: + """ + Method to run subprocesses. Calling this will capture the `stderr` and `stdout`, + please call `subprocess.run` manually in case you would like for them not to + be captured. + + Args: + command (`str` or `List[str]`): + The command to execute as a string or list of strings. + folder (`str`, *optional*): + The folder in which to run the command. Defaults to current working + directory (from `os.getcwd()`). + check (`bool`, *optional*, defaults to `True`): + Setting `check` to `True` will raise a `subprocess.CalledProcessError` + when the subprocess has a non-zero exit code. + kwargs (`Dict[str]`): + Keyword arguments to be passed to the `subprocess.run` underlying command. + + Returns: + `subprocess.CompletedProcess`: The completed process. + """ + if isinstance(command, str): + command = command.split() + + if isinstance(folder, Path): + folder = str(folder) + + return subprocess.run( + command, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + check=check, + encoding="utf-8", + errors="replace", # if not utf-8, replace char by οΏ½ + cwd=folder or os.getcwd(), + **kwargs, + ) + + +@contextmanager +def run_interactive_subprocess( + command: Union[str, List[str]], + folder: Optional[Union[str, Path]] = None, + **kwargs, +) -> Generator[Tuple[IO[str], IO[str]], None, None]: + """Run a subprocess in an interactive mode in a context manager. + + Args: + command (`str` or `List[str]`): + The command to execute as a string or list of strings. 
+ folder (`str`, *optional*): + The folder in which to run the command. Defaults to current working + directory (from `os.getcwd()`). + kwargs (`Dict[str]`): + Keyword arguments to be passed to the `subprocess.run` underlying command. + + Returns: + `Tuple[IO[str], IO[str]]`: A tuple with `stdin` and `stdout` to interact + with the process (input and output are utf-8 encoded). + + Example: + ```python + with _interactive_subprocess("git credential-store get") as (stdin, stdout): + # Write to stdin + stdin.write("url=hf.co\nusername=obama\n".encode("utf-8")) + stdin.flush() + + # Read from stdout + output = stdout.read().decode("utf-8") + ``` + """ + if isinstance(command, str): + command = command.split() + + with subprocess.Popen( + command, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + encoding="utf-8", + errors="replace", # if not utf-8, replace char by οΏ½ + cwd=folder or os.getcwd(), + **kwargs, + ) as process: + assert process.stdin is not None, "subprocess is opened as subprocess.PIPE" + assert process.stdout is not None, "subprocess is opened as subprocess.PIPE" + yield process.stdin, process.stdout diff --git a/huggingface_hub/utils/_telemetry.py b/huggingface_hub/utils/_telemetry.py new file mode 100644 index 0000000000000000000000000000000000000000..2ba4a6349a8de1c565263ec73d235d36f88b68cf --- /dev/null +++ b/huggingface_hub/utils/_telemetry.py @@ -0,0 +1,126 @@ +from queue import Queue +from threading import Lock, Thread +from typing import Dict, Optional, Union +from urllib.parse import quote + +from .. import constants, logging +from . import build_hf_headers, get_session, hf_raise_for_status + + +logger = logging.get_logger(__name__) + +# Telemetry is sent by a separate thread to avoid blocking the main thread. +# A daemon thread is started once and consume tasks from the _TELEMETRY_QUEUE. +# If the thread stops for some reason -shouldn't happen-, we restart a new one. 
+_TELEMETRY_THREAD: Optional[Thread] = None +_TELEMETRY_THREAD_LOCK = Lock() # Lock to avoid starting multiple threads in parallel +_TELEMETRY_QUEUE: Queue = Queue() + + +def send_telemetry( + topic: str, + *, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, +) -> None: + """ + Sends telemetry that helps tracking usage of different HF libraries. + + This usage data helps us debug issues and prioritize new features. However, we understand that not everyone wants + to share additional information, and we respect your privacy. You can disable telemetry collection by setting the + `HF_HUB_DISABLE_TELEMETRY=1` as environment variable. Telemetry is also disabled in offline mode (i.e. when setting + `HF_HUB_OFFLINE=1`). + + Telemetry collection is run in a separate thread to minimize impact for the user. + + Args: + topic (`str`): + Name of the topic that is monitored. The topic is directly used to build the URL. If you want to monitor + subtopics, just use "/" separation. Examples: "gradio", "transformers/examples",... + library_name (`str`, *optional*): + The name of the library that is making the HTTP request. Will be added to the user-agent header. + library_version (`str`, *optional*): + The version of the library that is making the HTTP request. Will be added to the user-agent header. + user_agent (`str`, `dict`, *optional*): + The user agent info in the form of a dictionary or a single string. It will be completed with information about the installed packages. + + Example: + ```py + >>> from huggingface_hub.utils import send_telemetry + + # Send telemetry without library information + >>> send_telemetry("ping") + + # Send telemetry to subtopic with library information + >>> send_telemetry("gradio/local_link", library_name="gradio", library_version="3.22.1") + + # Send telemetry with additional data + >>> send_telemetry( + ... topic="examples", + ... library_name="transformers", + ... 
library_version="4.26.0", + ... user_agent={"pipeline": "text_classification", "framework": "flax"}, + ... ) + ``` + """ + if constants.HF_HUB_OFFLINE or constants.HF_HUB_DISABLE_TELEMETRY: + return + + _start_telemetry_thread() # starts thread only if doesn't exist yet + _TELEMETRY_QUEUE.put( + {"topic": topic, "library_name": library_name, "library_version": library_version, "user_agent": user_agent} + ) + + +def _start_telemetry_thread(): + """Start a daemon thread to consume tasks from the telemetry queue. + + If the thread is interrupted, start a new one. + """ + with _TELEMETRY_THREAD_LOCK: # avoid to start multiple threads if called concurrently + global _TELEMETRY_THREAD + if _TELEMETRY_THREAD is None or not _TELEMETRY_THREAD.is_alive(): + _TELEMETRY_THREAD = Thread(target=_telemetry_worker, daemon=True) + _TELEMETRY_THREAD.start() + + +def _telemetry_worker(): + """Wait for a task and consume it.""" + while True: + kwargs = _TELEMETRY_QUEUE.get() + _send_telemetry_in_thread(**kwargs) + _TELEMETRY_QUEUE.task_done() + + +def _send_telemetry_in_thread( + topic: str, + *, + library_name: Optional[str] = None, + library_version: Optional[str] = None, + user_agent: Union[Dict, str, None] = None, +) -> None: + """Contains the actual data sending data to the Hub. + + This function is called directly in gradio's analytics because + it is not possible to send telemetry from a daemon thread. + + See here: https://github.com/gradio-app/gradio/pull/8180 + + Please do not rename or remove this function. + """ + path = "/".join(quote(part) for part in topic.split("/") if len(part) > 0) + try: + r = get_session().head( + f"{constants.ENDPOINT}/api/telemetry/{path}", + headers=build_hf_headers( + token=False, # no need to send a token for telemetry + library_name=library_name, + library_version=library_version, + user_agent=user_agent, + ), + ) + hf_raise_for_status(r) + except Exception as e: + # We don't want to error in case of connection errors of any kind. 
+ logger.debug(f"Error while sending telemetry: {e}") diff --git a/huggingface_hub/utils/_token.py b/huggingface_hub/utils/_token.py new file mode 100644 index 0000000000000000000000000000000000000000..1faae9bc9843ff245be37dbd9d80853392400eda --- /dev/null +++ b/huggingface_hub/utils/_token.py @@ -0,0 +1,131 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains an helper to get the token from machine (env variable, secret or config file).""" + +import os +import warnings +from pathlib import Path +from threading import Lock +from typing import Optional + +from .. import constants +from ._runtime import is_colab_enterprise, is_google_colab + + +_IS_GOOGLE_COLAB_CHECKED = False +_GOOGLE_COLAB_SECRET_LOCK = Lock() +_GOOGLE_COLAB_SECRET: Optional[str] = None + + +def get_token() -> Optional[str]: + """ + Get token if user is logged in. + + Note: in most cases, you should use [`huggingface_hub.utils.build_hf_headers`] instead. This method is only useful + if you want to retrieve the token for other purposes than sending an HTTP request. + + Token is retrieved in priority from the `HF_TOKEN` environment variable. Otherwise, we read the token file located + in the Hugging Face home folder. Returns None if user is not logged in. To log in, use [`login`] or + `huggingface-cli login`. + + Returns: + `str` or `None`: The token, `None` if it doesn't exist. 
+ """ + return _get_token_from_google_colab() or _get_token_from_environment() or _get_token_from_file() + + +def _get_token_from_google_colab() -> Optional[str]: + """Get token from Google Colab secrets vault using `google.colab.userdata.get(...)`. + + Token is read from the vault only once per session and then stored in a global variable to avoid re-requesting + access to the vault. + """ + # If it's not a Google Colab or it's Colab Enterprise, fallback to environment variable or token file authentication + if not is_google_colab() or is_colab_enterprise(): + return None + + # `google.colab.userdata` is not thread-safe + # This can lead to a deadlock if multiple threads try to access it at the same time + # (typically when using `snapshot_download`) + # => use a lock + # See https://github.com/huggingface/huggingface_hub/issues/1952 for more details. + with _GOOGLE_COLAB_SECRET_LOCK: + global _GOOGLE_COLAB_SECRET + global _IS_GOOGLE_COLAB_CHECKED + + if _IS_GOOGLE_COLAB_CHECKED: # request access only once + return _GOOGLE_COLAB_SECRET + + try: + from google.colab import userdata + from google.colab.errors import Error as ColabError + except ImportError: + return None + + try: + token = userdata.get("HF_TOKEN") + _GOOGLE_COLAB_SECRET = _clean_token(token) + except userdata.NotebookAccessError: + # Means the user has a secret call `HF_TOKEN` and got a popup "please grand access to HF_TOKEN" and refused it + # => warn user but ignore error => do not re-request access to user + warnings.warn( + "\nAccess to the secret `HF_TOKEN` has not been granted on this notebook." + "\nYou will not be requested again." + "\nPlease restart the session if you want to be prompted again." + ) + _GOOGLE_COLAB_SECRET = None + except userdata.SecretNotFoundError: + # Means the user did not define a `HF_TOKEN` secret => warn + warnings.warn( + "\nThe secret `HF_TOKEN` does not exist in your Colab secrets." 
+ "\nTo authenticate with the Hugging Face Hub, create a token in your settings tab " + "(https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session." + "\nYou will be able to reuse this secret in all of your notebooks." + "\nPlease note that authentication is recommended but still optional to access public models or datasets." + ) + _GOOGLE_COLAB_SECRET = None + except ColabError as e: + # Something happen but we don't know what => recommend to open a GitHub issue + warnings.warn( + f"\nError while fetching `HF_TOKEN` secret value from your vault: '{str(e)}'." + "\nYou are not authenticated with the Hugging Face Hub in this notebook." + "\nIf the error persists, please let us know by opening an issue on GitHub " + "(https://github.com/huggingface/huggingface_hub/issues/new)." + ) + _GOOGLE_COLAB_SECRET = None + + _IS_GOOGLE_COLAB_CHECKED = True + return _GOOGLE_COLAB_SECRET + + +def _get_token_from_environment() -> Optional[str]: + # `HF_TOKEN` has priority (keep `HUGGING_FACE_HUB_TOKEN` for backward compatibility) + return _clean_token(os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")) + + +def _get_token_from_file() -> Optional[str]: + try: + return _clean_token(Path(constants.HF_TOKEN_PATH).read_text()) + except FileNotFoundError: + return None + + +def _clean_token(token: Optional[str]) -> Optional[str]: + """Clean token by removing trailing and leading spaces and newlines. + + If token is an empty string, return None. + """ + if token is None: + return None + return token.replace("\r", "").replace("\n", "").strip() or None diff --git a/huggingface_hub/utils/_typing.py b/huggingface_hub/utils/_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..b28a68ec12489e024b486e4ff2178cc4096d314c --- /dev/null +++ b/huggingface_hub/utils/_typing.py @@ -0,0 +1,75 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Handle typing imports based on system compatibility.""" + +import sys +from typing import Any, Callable, List, Literal, Type, TypeVar, Union, get_args, get_origin + + +UNION_TYPES: List[Any] = [Union] +if sys.version_info >= (3, 10): + from types import UnionType + + UNION_TYPES += [UnionType] + + +HTTP_METHOD_T = Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"] + +# type hint meaning "function signature not changed by decorator" +CallableT = TypeVar("CallableT", bound=Callable) + +_JSON_SERIALIZABLE_TYPES = (int, float, str, bool, type(None)) + + +def is_jsonable(obj: Any) -> bool: + """Check if an object is JSON serializable. + + This is a weak check, as it does not check for the actual JSON serialization, but only for the types of the object. + It works correctly for basic use cases but do not guarantee an exhaustive check. 
+ + Object is considered to be recursively json serializable if: + - it is an instance of int, float, str, bool, or NoneType + - it is a list or tuple and all its items are json serializable + - it is a dict and all its keys are strings and all its values are json serializable + """ + try: + if isinstance(obj, _JSON_SERIALIZABLE_TYPES): + return True + if isinstance(obj, (list, tuple)): + return all(is_jsonable(item) for item in obj) + if isinstance(obj, dict): + return all(isinstance(key, str) and is_jsonable(value) for key, value in obj.items()) + if hasattr(obj, "__json__"): + return True + return False + except RecursionError: + return False + + +def is_simple_optional_type(type_: Type) -> bool: + """Check if a type is optional, i.e. Optional[Type] or Union[Type, None] or Type | None, where Type is a non-composite type.""" + if get_origin(type_) in UNION_TYPES: + union_args = get_args(type_) + if len(union_args) == 2 and type(None) in union_args: + return True + return False + + +def unwrap_simple_optional_type(optional_type: Type) -> Type: + """Unwraps a simple optional type, i.e. returns Type from Optional[Type].""" + for arg in get_args(optional_type): + if arg is not type(None): + return arg + raise ValueError(f"'{optional_type}' is not an optional type") diff --git a/huggingface_hub/utils/_validators.py b/huggingface_hub/utils/_validators.py new file mode 100644 index 0000000000000000000000000000000000000000..27833f28e3e2030680fb72b95a547521bc08831b --- /dev/null +++ b/huggingface_hub/utils/_validators.py @@ -0,0 +1,226 @@ +# coding=utf-8 +# Copyright 2022-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains utilities to validate argument values in `huggingface_hub`.""" + +import inspect +import re +import warnings +from functools import wraps +from itertools import chain +from typing import Any, Dict + +from huggingface_hub.errors import HFValidationError + +from ._typing import CallableT + + +REPO_ID_REGEX = re.compile( + r""" + ^ + (\b[\w\-.]+\b/)? # optional namespace (username or organization) + \b # starts with a word boundary + [\w\-.]{1,96} # repo_name: alphanumeric + . _ - + \b # ends with a word boundary + $ + """, + flags=re.VERBOSE, +) + + +def validate_hf_hub_args(fn: CallableT) -> CallableT: + """Validate values received as argument for any public method of `huggingface_hub`. + + The goal of this decorator is to harmonize validation of arguments reused + everywhere. By default, all defined validators are tested. + + Validators: + - [`~utils.validate_repo_id`]: `repo_id` must be `"repo_name"` + or `"namespace/repo_name"`. Namespace is a username or an organization. + - [`~utils.smoothly_deprecate_use_auth_token`]: Use `token` instead of + `use_auth_token` (only if `use_auth_token` is not expected by the decorated + function - in practice, always the case in `huggingface_hub`). + + Example: + ```py + >>> from huggingface_hub.utils import validate_hf_hub_args + + >>> @validate_hf_hub_args + ... def my_cool_method(repo_id: str): + ... print(repo_id) + + >>> my_cool_method(repo_id="valid_repo_id") + valid_repo_id + + >>> my_cool_method("other..repo..id") + huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. 
in repo_id: 'other..repo..id'. + + >>> my_cool_method(repo_id="other..repo..id") + huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'. + + >>> @validate_hf_hub_args + ... def my_cool_auth_method(token: str): + ... print(token) + + >>> my_cool_auth_method(token="a token") + "a token" + + >>> my_cool_auth_method(use_auth_token="a use_auth_token") + "a use_auth_token" + + >>> my_cool_auth_method(token="a token", use_auth_token="a use_auth_token") + UserWarning: Both `token` and `use_auth_token` are passed (...) + "a token" + ``` + + Raises: + [`~utils.HFValidationError`]: + If an input is not valid. + """ + # TODO: add an argument to opt-out validation for specific argument? + signature = inspect.signature(fn) + + # Should the validator switch `use_auth_token` values to `token`? In practice, always + # True in `huggingface_hub`. Might not be the case in a downstream library. + check_use_auth_token = "use_auth_token" not in signature.parameters and "token" in signature.parameters + + @wraps(fn) + def _inner_fn(*args, **kwargs): + has_token = False + for arg_name, arg_value in chain( + zip(signature.parameters, args), # Args values + kwargs.items(), # Kwargs values + ): + if arg_name in ["repo_id", "from_id", "to_id"]: + validate_repo_id(arg_value) + + elif arg_name == "token" and arg_value is not None: + has_token = True + + if check_use_auth_token: + kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs) + + return fn(*args, **kwargs) + + return _inner_fn # type: ignore + + +def validate_repo_id(repo_id: str) -> None: + """Validate `repo_id` is valid. + + This is not meant to replace the proper validation made on the Hub but rather to + avoid local inconsistencies whenever possible (example: passing `repo_type` in the + `repo_id` is forbidden). + + Rules: + - Between 1 and 96 characters. + - Either "repo_name" or "namespace/repo_name" + - [a-zA-Z0-9] or "-", "_", "." 
+ - "--" and ".." are forbidden + + Valid: `"foo"`, `"foo/bar"`, `"123"`, `"Foo-BAR_foo.bar123"` + + Not valid: `"datasets/foo/bar"`, `".repo_id"`, `"foo--bar"`, `"foo.git"` + + Example: + ```py + >>> from huggingface_hub.utils import validate_repo_id + >>> validate_repo_id(repo_id="valid_repo_id") + >>> validate_repo_id(repo_id="other..repo..id") + huggingface_hub.utils._validators.HFValidationError: Cannot have -- or .. in repo_id: 'other..repo..id'. + ``` + + Discussed in https://github.com/huggingface/huggingface_hub/issues/1008. + In moon-landing (internal repository): + - https://github.com/huggingface/moon-landing/blob/main/server/lib/Names.ts#L27 + - https://github.com/huggingface/moon-landing/blob/main/server/views/components/NewRepoForm/NewRepoForm.svelte#L138 + """ + if not isinstance(repo_id, str): + # Typically, a Path is not a repo_id + raise HFValidationError(f"Repo id must be a string, not {type(repo_id)}: '{repo_id}'.") + + if repo_id.count("/") > 1: + raise HFValidationError( + "Repo id must be in the form 'repo_name' or 'namespace/repo_name':" + f" '{repo_id}'. Use `repo_type` argument if needed." + ) + + if not REPO_ID_REGEX.match(repo_id): + raise HFValidationError( + "Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are" + " forbidden, '-' and '.' cannot start or end the name, max length is 96:" + f" '{repo_id}'." + ) + + if "--" in repo_id or ".." in repo_id: + raise HFValidationError(f"Cannot have -- or .. in repo_id: '{repo_id}'.") + + if repo_id.endswith(".git"): + raise HFValidationError(f"Repo_id cannot end by '.git': '{repo_id}'.") + + +def smoothly_deprecate_use_auth_token(fn_name: str, has_token: bool, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Smoothly deprecate `use_auth_token` in the `huggingface_hub` codebase. + + The long-term goal is to remove any mention of `use_auth_token` in the codebase in + favor of a unique and less verbose `token` argument. This will be done a few steps: + + 0. 
Step 0: methods that require a read-access to the Hub use the `use_auth_token` + argument (`str`, `bool` or `None`). Methods requiring write-access have a `token` + argument (`str`, `None`). This implicit rule exists to be able to not send the + token when not necessary (`use_auth_token=False`) even if logged in. + + 1. Step 1: we want to harmonize everything and use `token` everywhere (supporting + `token=False` for read-only methods). In order not to break existing code, if + `use_auth_token` is passed to a function, the `use_auth_token` value is passed + as `token` instead, without any warning. + a. Corner case: if both `use_auth_token` and `token` values are passed, a warning + is thrown and the `use_auth_token` value is ignored. + + 2. Step 2: Once it is release, we should push downstream libraries to switch from + `use_auth_token` to `token` as much as possible, but without throwing a warning + (e.g. manually create issues on the corresponding repos). + + 3. Step 3: After a transitional period (6 months e.g. until April 2023?), we update + `huggingface_hub` to throw a warning on `use_auth_token`. Hopefully, very few + users will be impacted as it would have already been fixed. + In addition, unit tests in `huggingface_hub` must be adapted to expect warnings + to be thrown (but still use `use_auth_token` as before). + + 4. Step 4: After a normal deprecation cycle (3 releases ?), remove this validator. + `use_auth_token` will definitely not be supported. + In addition, we update unit tests in `huggingface_hub` to use `token` everywhere. + + This has been discussed in: + - https://github.com/huggingface/huggingface_hub/issues/1094. + - https://github.com/huggingface/huggingface_hub/pull/928 + - (related) https://github.com/huggingface/huggingface_hub/pull/1064 + """ + new_kwargs = kwargs.copy() # do not mutate input ! 
+ + use_auth_token = new_kwargs.pop("use_auth_token", None) # remove from kwargs + if use_auth_token is not None: + if has_token: + warnings.warn( + "Both `token` and `use_auth_token` are passed to" + f" `{fn_name}` with non-None values. `token` is now the" + " preferred argument to pass a User Access Token." + " `use_auth_token` value will be ignored." + ) + else: + # `token` argument is not passed and a non-None value is passed in + # `use_auth_token` => use `use_auth_token` value as `token` kwarg. + new_kwargs["token"] = use_auth_token + + return new_kwargs diff --git a/huggingface_hub/utils/endpoint_helpers.py b/huggingface_hub/utils/endpoint_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..85cd86011b78bcdc57034aeebc3c01e9e721ab50 --- /dev/null +++ b/huggingface_hub/utils/endpoint_helpers.py @@ -0,0 +1,66 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Helpful utility functions and classes in relation to exploring API endpoints +with the aim for a user-friendly interface. +""" + +import math +import re +from typing import TYPE_CHECKING + +from ..repocard_data import ModelCardData + + +if TYPE_CHECKING: + from ..hf_api import ModelInfo + + +def _is_emission_within_threshold(model_info: "ModelInfo", minimum_threshold: float, maximum_threshold: float) -> bool: + """Checks if a model's emission is within a given threshold. + + Args: + model_info (`ModelInfo`): + A model info object containing the model's emission information. 
+ minimum_threshold (`float`): + A minimum carbon threshold to filter by, such as 1. + maximum_threshold (`float`): + A maximum carbon threshold to filter by, such as 10. + + Returns: + `bool`: Whether the model's emission is within the given threshold. + """ + if minimum_threshold is None and maximum_threshold is None: + raise ValueError("Both `minimum_threshold` and `maximum_threshold` cannot both be `None`") + if minimum_threshold is None: + minimum_threshold = -1 + if maximum_threshold is None: + maximum_threshold = math.inf + + card_data = getattr(model_info, "card_data", None) + if card_data is None or not isinstance(card_data, (dict, ModelCardData)): + return False + + # Get CO2 emission metadata + emission = card_data.get("co2_eq_emissions", None) + if isinstance(emission, dict): + emission = emission["emissions"] + if not emission: + return False + + # Filter out if value is missing or out of range + matched = re.search(r"\d+\.\d+|\d+", str(emission)) + if matched is None: + return False + + emission_value = float(matched.group(0)) + return minimum_threshold <= emission_value <= maximum_threshold diff --git a/huggingface_hub/utils/insecure_hashlib.py b/huggingface_hub/utils/insecure_hashlib.py new file mode 100644 index 0000000000000000000000000000000000000000..f232ee0adcfc52dcc18b5ea4d9c913b206521f71 --- /dev/null +++ b/huggingface_hub/utils/insecure_hashlib.py @@ -0,0 +1,34 @@ +# Taken from https://github.com/mlflow/mlflow/pull/10119 +# +# DO NOT use this function for security purposes (e.g., password hashing). +# +# In Python >= 3.9, insecure hashing algorithms such as MD5 fail in FIPS-compliant +# environments unless `usedforsecurity=False` is explicitly passed. 
+# +# References: +# - https://github.com/mlflow/mlflow/issues/9905 +# - https://github.com/mlflow/mlflow/pull/10119 +# - https://docs.python.org/3/library/hashlib.html +# - https://github.com/huggingface/transformers/pull/27038 +# +# Usage: +# ```python +# # Use +# from huggingface_hub.utils.insecure_hashlib import sha256 +# # instead of +# from hashlib import sha256 +# +# # Use +# from huggingface_hub.utils import insecure_hashlib +# # instead of +# import hashlib +# ``` +import functools +import hashlib +import sys + + +_kwargs = {"usedforsecurity": False} if sys.version_info >= (3, 9) else {} +md5 = functools.partial(hashlib.md5, **_kwargs) +sha1 = functools.partial(hashlib.sha1, **_kwargs) +sha256 = functools.partial(hashlib.sha256, **_kwargs) diff --git a/huggingface_hub/utils/logging.py b/huggingface_hub/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..3aafdf148135397556b4bb762862377eafdafd14 --- /dev/null +++ b/huggingface_hub/utils/logging.py @@ -0,0 +1,182 @@ +# coding=utf-8 +# Copyright 2020 Optuna, Hugging Face +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Logging utilities.""" + +import logging +import os +from logging import ( + CRITICAL, # NOQA + DEBUG, # NOQA + ERROR, # NOQA + FATAL, # NOQA + INFO, # NOQA + NOTSET, # NOQA + WARN, # NOQA + WARNING, # NOQA +) +from typing import Optional + + +log_levels = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +_default_log_level = logging.WARNING + + +def _get_library_name() -> str: + return __name__.split(".")[0] + + +def _get_library_root_logger() -> logging.Logger: + return logging.getLogger(_get_library_name()) + + +def _get_default_logging_level(): + """ + If `HF_HUB_VERBOSITY` env var is set to one of the valid choices return that as the new default level. If it is not + - fall back to `_default_log_level` + """ + env_level_str = os.getenv("HF_HUB_VERBOSITY", None) + if env_level_str: + if env_level_str in log_levels: + return log_levels[env_level_str] + else: + logging.getLogger().warning( + f"Unknown option HF_HUB_VERBOSITY={env_level_str}, has to be one of: { ', '.join(log_levels.keys()) }" + ) + return _default_log_level + + +def _configure_library_root_logger() -> None: + library_root_logger = _get_library_root_logger() + library_root_logger.addHandler(logging.StreamHandler()) + library_root_logger.setLevel(_get_default_logging_level()) + + +def _reset_library_root_logger() -> None: + library_root_logger = _get_library_root_logger() + library_root_logger.setLevel(logging.NOTSET) + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Returns a logger with the specified name. This function is not supposed + to be directly accessed by library users. 
+ + Args: + name (`str`, *optional*): + The name of the logger to get, usually the filename + + Example: + + ```python + >>> from huggingface_hub import get_logger + + >>> logger = get_logger(__file__) + >>> logger.set_verbosity_info() + ``` + """ + + if name is None: + name = _get_library_name() + + return logging.getLogger(name) + + +def get_verbosity() -> int: + """Return the current level for the HuggingFace Hub's root logger. + + Returns: + Logging level, e.g., `huggingface_hub.logging.DEBUG` and + `huggingface_hub.logging.INFO`. + ++ + HuggingFace Hub has following logging levels: + + - `huggingface_hub.logging.CRITICAL`, `huggingface_hub.logging.FATAL` + - `huggingface_hub.logging.ERROR` + - `huggingface_hub.logging.WARNING`, `huggingface_hub.logging.WARN` + - `huggingface_hub.logging.INFO` + - `huggingface_hub.logging.DEBUG` + + + """ + return _get_library_root_logger().getEffectiveLevel() + + +def set_verbosity(verbosity: int) -> None: + """ + Sets the level for the HuggingFace Hub's root logger. + + Args: + verbosity (`int`): + Logging level, e.g., `huggingface_hub.logging.DEBUG` and + `huggingface_hub.logging.INFO`. + """ + _get_library_root_logger().setLevel(verbosity) + + +def set_verbosity_info(): + """ + Sets the verbosity to `logging.INFO`. + """ + return set_verbosity(INFO) + + +def set_verbosity_warning(): + """ + Sets the verbosity to `logging.WARNING`. + """ + return set_verbosity(WARNING) + + +def set_verbosity_debug(): + """ + Sets the verbosity to `logging.DEBUG`. + """ + return set_verbosity(DEBUG) + + +def set_verbosity_error(): + """ + Sets the verbosity to `logging.ERROR`. + """ + return set_verbosity(ERROR) + + +def disable_propagation() -> None: + """ + Disable propagation of the library log outputs. Note that log propagation is + disabled by default. + """ + _get_library_root_logger().propagate = False + + +def enable_propagation() -> None: + """ + Enable propagation of the library log outputs. 
Please disable the + HuggingFace Hub's default handler to prevent double logging if the root + logger has been configured. + """ + _get_library_root_logger().propagate = True + + +_configure_library_root_logger() diff --git a/huggingface_hub/utils/sha.py b/huggingface_hub/utils/sha.py new file mode 100644 index 0000000000000000000000000000000000000000..001c3fe8b2f37a64e890888ca3d521c10ec8f03b --- /dev/null +++ b/huggingface_hub/utils/sha.py @@ -0,0 +1,64 @@ +"""Utilities to efficiently compute the SHA 256 hash of a bunch of bytes.""" + +from typing import BinaryIO, Optional + +from .insecure_hashlib import sha1, sha256 + + +def sha_fileobj(fileobj: BinaryIO, chunk_size: Optional[int] = None) -> bytes: + """ + Computes the sha256 hash of the given file object, by chunks of size `chunk_size`. + + Args: + fileobj (file-like object): + The File object to compute sha256 for, typically obtained with `open(path, "rb")` + chunk_size (`int`, *optional*): + The number of bytes to read from `fileobj` at once, defaults to 1MB. + + Returns: + `bytes`: `fileobj`'s sha256 hash as bytes + """ + chunk_size = chunk_size if chunk_size is not None else 1024 * 1024 + + sha = sha256() + while True: + chunk = fileobj.read(chunk_size) + sha.update(chunk) + if not chunk: + break + return sha.digest() + + +def git_hash(data: bytes) -> str: + """ + Computes the git-sha1 hash of the given bytes, using the same algorithm as git. + + This is equivalent to running `git hash-object`. See https://git-scm.com/docs/git-hash-object + for more details. + + Note: this method is valid for regular files. For LFS files, the proper git hash is supposed to be computed on the + pointer file content, not the actual file content. However, for simplicity, we directly compare the sha256 of + the LFS file content when we want to compare LFS files. + + Args: + data (`bytes`): + The data to compute the git-hash for. + + Returns: + `str`: the git-hash of `data` as an hexadecimal string. 
+ + Example: + ```python + >>> from huggingface_hub.utils.sha import git_hash + >>> git_hash(b"Hello, World!") + 'b45ef6fec89518d314f546fd6c3025367b721684' + ``` + """ + # Taken from https://gist.github.com/msabramo/763200 + # Note: no need to optimize by reading the file in chunks as we're not supposed to hash huge files (5MB maximum). + sha = sha1() + sha.update(b"blob ") + sha.update(str(len(data)).encode()) + sha.update(b"\0") + sha.update(data) + return sha.hexdigest() diff --git a/huggingface_hub/utils/tqdm.py b/huggingface_hub/utils/tqdm.py new file mode 100644 index 0000000000000000000000000000000000000000..dce7133b444436ce500ec31880d0cacd40eb558d --- /dev/null +++ b/huggingface_hub/utils/tqdm.py @@ -0,0 +1,264 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +"""Utility helpers to handle progress bars in `huggingface_hub`. + +Example: + 1. Use `huggingface_hub.utils.tqdm` as you would use `tqdm.tqdm` or `tqdm.auto.tqdm`. + 2. To disable progress bars, either use `disable_progress_bars()` helper or set the + environment variable `HF_HUB_DISABLE_PROGRESS_BARS` to 1. + 3. To re-enable progress bars, use `enable_progress_bars()`. + 4. To check whether progress bars are disabled, use `are_progress_bars_disabled()`. + +NOTE: Environment variable `HF_HUB_DISABLE_PROGRESS_BARS` has the priority. 
+ +Example: + ```py + >>> from huggingface_hub.utils import are_progress_bars_disabled, disable_progress_bars, enable_progress_bars, tqdm + + # Disable progress bars globally + >>> disable_progress_bars() + + # Use as normal `tqdm` + >>> for _ in tqdm(range(5)): + ... pass + + # Still not showing progress bars, as `disable=False` is overwritten to `True`. + >>> for _ in tqdm(range(5), disable=False): + ... pass + + >>> are_progress_bars_disabled() + True + + # Re-enable progress bars globally + >>> enable_progress_bars() + + # Progress bar will be shown ! + >>> for _ in tqdm(range(5)): + ... pass + 100%|βββββββββββββββββββββββββββββββββββββββ| 5/5 [00:00<00:00, 117817.53it/s] + ``` + +Group-based control: + ```python + # Disable progress bars for a specific group + >>> disable_progress_bars("peft.foo") + + # Check state of different groups + >>> assert not are_progress_bars_disabled("peft")) + >>> assert not are_progress_bars_disabled("peft.something") + >>> assert are_progress_bars_disabled("peft.foo")) + >>> assert are_progress_bars_disabled("peft.foo.bar")) + + # Enable progress bars for a subgroup + >>> enable_progress_bars("peft.foo.bar") + + # Check if enabling a subgroup affects the parent group + >>> assert are_progress_bars_disabled("peft.foo")) + >>> assert not are_progress_bars_disabled("peft.foo.bar")) + + # No progress bar for `name="peft.foo"` + >>> for _ in tqdm(range(5), name="peft.foo"): + ... pass + + # Progress bar will be shown for `name="peft.foo.bar"` + >>> for _ in tqdm(range(5), name="peft.foo.bar"): + ... 
pass + 100%|βββββββββββββββββββββββββββββββββββββββ| 5/5 [00:00<00:00, 117817.53it/s] + + ``` +""" + +import io +import warnings +from contextlib import contextmanager +from pathlib import Path +from typing import Dict, Iterator, Optional, Union + +from tqdm.auto import tqdm as old_tqdm + +from ..constants import HF_HUB_DISABLE_PROGRESS_BARS + + +# The `HF_HUB_DISABLE_PROGRESS_BARS` environment variable can be True, False, or not set (None), +# allowing for control over progress bar visibility. When set, this variable takes precedence +# over programmatic settings, dictating whether progress bars should be shown or hidden globally. +# Essentially, the environment variable's setting overrides any code-based configurations. +# +# If `HF_HUB_DISABLE_PROGRESS_BARS` is not defined (None), it implies that users can manage +# progress bar visibility through code. By default, progress bars are turned on. + + +progress_bar_states: Dict[str, bool] = {} + + +def disable_progress_bars(name: Optional[str] = None) -> None: + """ + Disable progress bars either globally or for a specified group. + + This function updates the state of progress bars based on a group name. + If no group name is provided, all progress bars are disabled. The operation + respects the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable's setting. + + Args: + name (`str`, *optional*): + The name of the group for which to disable the progress bars. If None, + progress bars are disabled globally. + + Raises: + Warning: If the environment variable precludes changes. + """ + if HF_HUB_DISABLE_PROGRESS_BARS is False: + warnings.warn( + "Cannot disable progress bars: environment variable `HF_HUB_DISABLE_PROGRESS_BARS=0` is set and has priority." 
+ ) + return + + if name is None: + progress_bar_states.clear() + progress_bar_states["_global"] = False + else: + keys_to_remove = [key for key in progress_bar_states if key.startswith(f"{name}.")] + for key in keys_to_remove: + del progress_bar_states[key] + progress_bar_states[name] = False + + +def enable_progress_bars(name: Optional[str] = None) -> None: + """ + Enable progress bars either globally or for a specified group. + + This function sets the progress bars to enabled for the specified group or globally + if no group is specified. The operation is subject to the `HF_HUB_DISABLE_PROGRESS_BARS` + environment setting. + + Args: + name (`str`, *optional*): + The name of the group for which to enable the progress bars. If None, + progress bars are enabled globally. + + Raises: + Warning: If the environment variable precludes changes. + """ + if HF_HUB_DISABLE_PROGRESS_BARS is True: + warnings.warn( + "Cannot enable progress bars: environment variable `HF_HUB_DISABLE_PROGRESS_BARS=1` is set and has priority." + ) + return + + if name is None: + progress_bar_states.clear() + progress_bar_states["_global"] = True + else: + keys_to_remove = [key for key in progress_bar_states if key.startswith(f"{name}.")] + for key in keys_to_remove: + del progress_bar_states[key] + progress_bar_states[name] = True + + +def are_progress_bars_disabled(name: Optional[str] = None) -> bool: + """ + Check if progress bars are disabled globally or for a specific group. + + This function returns whether progress bars are disabled for a given group or globally. + It checks the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable first, then the programmatic + settings. + + Args: + name (`str`, *optional*): + The group name to check; if None, checks the global setting. + + Returns: + `bool`: True if progress bars are disabled, False otherwise. 
+ """ + if HF_HUB_DISABLE_PROGRESS_BARS is True: + return True + + if name is None: + return not progress_bar_states.get("_global", True) + + while name: + if name in progress_bar_states: + return not progress_bar_states[name] + name = ".".join(name.split(".")[:-1]) + + return not progress_bar_states.get("_global", True) + + +class tqdm(old_tqdm): + """ + Class to override `disable` argument in case progress bars are globally disabled. + + Taken from https://github.com/tqdm/tqdm/issues/619#issuecomment-619639324. + """ + + def __init__(self, *args, **kwargs): + name = kwargs.pop("name", None) # do not pass `name` to `tqdm` + if are_progress_bars_disabled(name): + kwargs["disable"] = True + super().__init__(*args, **kwargs) + + def __delattr__(self, attr: str) -> None: + """Fix for https://github.com/huggingface/huggingface_hub/issues/1603""" + try: + super().__delattr__(attr) + except AttributeError: + if attr != "_lock": + raise + + +@contextmanager +def tqdm_stream_file(path: Union[Path, str]) -> Iterator[io.BufferedReader]: + """ + Open a file as binary and wrap the `read` method to display a progress bar when it's streamed. + + First implemented in `transformers` in 2019 but removed when switched to git-lfs. Used in `huggingface_hub` to show + progress bar when uploading an LFS file to the Hub. See github.com/huggingface/transformers/pull/2078#discussion_r354739608 + for implementation details. + + Note: currently implementation handles only files stored on disk as it is the most common use case. Could be + extended to stream any `BinaryIO` object but we might have to debug some corner cases. 
+ + Example: + ```py + >>> with tqdm_stream_file("config.json") as f: + >>> requests.put(url, data=f) + config.json: 100%|βββββββββββββββββββββββββ| 8.19k/8.19k [00:02<00:00, 3.72kB/s] + ``` + """ + if isinstance(path, str): + path = Path(path) + + with path.open("rb") as f: + total_size = path.stat().st_size + pbar = tqdm( + unit="B", + unit_scale=True, + total=total_size, + initial=0, + desc=path.name, + ) + + f_read = f.read + + def _inner_read(size: Optional[int] = -1) -> bytes: + data = f_read(size) + pbar.update(len(data)) + return data + + f.read = _inner_read # type: ignore + + yield f + + pbar.close() diff --git a/requirements.txt b/requirements.txt index 136d55ba4597c49c4f1a599950a480ecf6a1eb13..dddd53f706557d606e4e586aba989d7e84e58f1a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,13 @@ --extra-index-url https://download.pytorch.org/whl/cu113 -torch +torch<1.13.0 torchvision autocuda findfile pyabsa -diffusers==0.10.0 +diffusers<0.11.0 scipy -git+https://github.com/huggingface/transformers.git +transformers<4.30.0 +huggingface-hub ftfy accelerate psutil @@ -18,3 +19,8 @@ opencv-python Pillow tqdm realesrgan +gradio +numpy<2.0.0 +gradio-client + + diff --git a/result.jpg b/result.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cd9fb494b0335617c452b59194334cf50442723e Binary files /dev/null and b/result.jpg differ diff --git a/scenery.png b/scenery.png new file mode 100644 index 0000000000000000000000000000000000000000..8e86e96487202dc2cfd3517cdfd612f9bf90381c --- /dev/null +++ b/scenery.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f316b848ba93cfbbf980ddee8a72ad6ca1ab463638cd5a72ef8d213685991f09 +size 1159989