|
package downloader |
|
|
|
import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)
|
|
|
// HuggingFaceScanResult holds the JSON body returned by the Hugging Face
// https://huggingface.co/api/models/<owner>/<repo>/scan endpoint, as decoded
// by HuggingFaceScan.
type HuggingFaceScanResult struct {
	// RepositoryId is decoded from the "repositoryId" JSON key.
	RepositoryId string `json:"repositoryId"`
	// Revision is decoded from the "revision" JSON key.
	Revision string `json:"revision"`
	// HasUnsafeFiles is decoded from the singular "hasUnsafeFile" JSON key
	// (the key name intentionally differs from the field name). When true,
	// HuggingFaceScan returns this result together with ErrUnsafeFilesFound.
	HasUnsafeFiles bool `json:"hasUnsafeFile"`
	// ClamAVInfectedFiles is decoded from "clamAVInfectedFiles"; presumably
	// the files flagged by a ClamAV scan — confirm against the HF API docs.
	ClamAVInfectedFiles []string `json:"clamAVInfectedFiles"`
	// DangerousPickles is decoded from "dangerousPickles"; presumably the
	// pickle files flagged as dangerous — confirm against the HF API docs.
	DangerousPickles []string `json:"dangerousPickles"`
	// ScansDone is decoded from "scansDone"; presumably reports whether the
	// server-side scans have completed — TODO confirm.
	ScansDone bool `json:"scansDone"`
}
|
|
|
// Sentinel errors used by HuggingFaceScan. They are returned unwrapped, so
// callers may compare with errors.Is (or ==).
var (
	// ErrNonHuggingFaceFile reports that the URI does not resolve to a
	// https://huggingface.co/<owner>/<repo>/... URL, so no scan was attempted.
	ErrNonHuggingFaceFile = errors.New("not a huggingface repo")

	// ErrUnsafeFilesFound reports that the scan result flagged the repository
	// as containing unsafe files; the result is still returned alongside it.
	ErrUnsafeFilesFound = errors.New("unsafe files found")
)
|
|
|
func HuggingFaceScan(uri URI) (*HuggingFaceScanResult, error) { |
|
cleanParts := strings.Split(uri.ResolveURL(), "/") |
|
if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" { |
|
return nil, ErrNonHuggingFaceFile |
|
} |
|
results, err := http.Get(fmt.Sprintf("https://huggingface.co/api/models/%s/%s/scan", cleanParts[3], cleanParts[4])) |
|
if err != nil { |
|
return nil, err |
|
} |
|
if results.StatusCode != 200 { |
|
return nil, fmt.Errorf("unexpected status code during HuggingFaceScan: %d", results.StatusCode) |
|
} |
|
scanResult := &HuggingFaceScanResult{} |
|
bodyBytes, err := io.ReadAll(results.Body) |
|
if err != nil { |
|
return nil, err |
|
} |
|
err = json.Unmarshal(bodyBytes, scanResult) |
|
if err != nil { |
|
return nil, err |
|
} |
|
if scanResult.HasUnsafeFiles { |
|
return scanResult, ErrUnsafeFilesFound |
|
} |
|
return scanResult, nil |
|
} |
|
|