Support unquantized language models #248

Open
aminnasiri opened this issue Aug 8, 2024 · 0 comments
Labels
enhancement (New feature or request), Kalosm (Related to the Kalosm library)

Comments

@aminnasiri
aminnasiri commented Aug 8, 2024

Specific Demand

HuggingFace's SafeTensors is a new, simple format for safely and quickly storing tensors. I would like to use it on a GPU server: build with Candle's cuda feature for the server, and keep a quantized model for the development platform.

SafeTensors
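
As a rough illustration of the build-time split (not part of the original request; the cuda feature name and build_device are hypothetical), the backend could be gated on a Cargo feature so the server build uses Candle's CUDA device while the development build stays on CPU with a quantized model:

#[cfg(feature = "cuda")]
fn build_device() -> candle_core::Device {
    // Server build: try the first CUDA device, fall back to CPU if it is unavailable.
    candle_core::Device::new_cuda(0).unwrap_or(candle_core::Device::Cpu)
}

#[cfg(not(feature = "cuda"))]
fn build_device() -> candle_core::Device {
    // Development build: stay on CPU (e.g. alongside a quantized model).
    candle_core::Device::Cpu
}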

Implement Suggestion

// Imports assumed by this snippet: hf-hub for downloads, candle for tensors,
// tokenizers, serde for the index file, and tracing for logging. `Phi3` is taken
// to be candle_transformers' phi3 `Model`; `AppState` is the project's own state
// type and is not shown here.
use std::collections::HashSet;

use anyhow::Error as E;
use candle_core::{DType, Device};
use candle_nn::{Activation, VarBuilder};
use candle_transformers::models::phi3::{Config, Model as Phi3};
use hf_hub::api::sync::{ApiBuilder, ApiRepo};
use hf_hub::{Repo, RepoType};
use serde::{Deserialize, Deserializer};
use tokenizers::Tokenizer;
use tracing::{error, info};

/// Downloads every safetensors shard listed in the repo's index file and
/// returns the local paths.
pub fn hub_load_safe_tensors(
    repo: &ApiRepo,
    json_file: &str,
) -> anyhow::Result<Vec<std::path::PathBuf>> {
    let json_file = repo.get(json_file).map_err(candle_core::Error::wrap)?;
    let json_file = std::fs::File::open(json_file)?;
    let json: Weightmaps = serde_json::from_reader(&json_file).map_err(candle_core::Error::wrap)?;

    // Propagate download errors instead of unwrapping on the first failure.
    let pathbufs = json
        .weight_map
        .iter()
        .map(|f| repo.get(f).map_err(candle_core::Error::wrap))
        .collect::<candle_core::Result<Vec<std::path::PathBuf>>>()?;

    Ok(pathbufs)
}

/// Keeps only the distinct shard filenames from the index file's `weight_map` object.
fn deserialize_weight_map<'de, D>(deserializer: D) -> Result<HashSet<String>, D::Error>
where
    D: Deserializer<'de>,
{
    let map = serde_json::Value::deserialize(deserializer)?;
    match map {
        serde_json::Value::Object(obj) => Ok(obj
            .values()
            .filter_map(|v| v.as_str().map(ToString::to_string))
            .collect::<HashSet<String>>()),
        _ => Err(serde::de::Error::custom(
            "Expected an object for weight_map",
        )),
    }
}

fn get_tokenizer(repo: &ApiRepo) -> anyhow::Result<Tokenizer> {
    let tokenizer_filename = repo.get("tokenizer.json")?;

    Tokenizer::from_file(tokenizer_filename).map_err(E::msg)
}


/// Reads `config.json` from the repo, falling back to hard-coded Phi-3-mini
/// defaults if it cannot be decoded.
fn get_config(repo: &ApiRepo) -> anyhow::Result<Config> {
    let config_filename = repo.get("config.json")?;
    let config_bytes = std::fs::read(config_filename)?;

    info!("JSON content {}", String::from_utf8_lossy(&config_bytes));

    let default_config = Config {
        vocab_size: 32064,
        hidden_act: Activation::Silu,
        hidden_size: 3072,
        intermediate_size: 8192,
        num_hidden_layers: 32,
        num_attention_heads: 32,
        num_key_value_heads: 32,
        rms_norm_eps: 1e-05,
        rope_theta: 10000.0,
        bos_token_id: Some(1),
        eos_token_id: Some(32000),
        rope_scaling: None,
        max_position_embeddings: 4096,
    };

    let config: Config = serde_json::from_slice(&config_bytes).unwrap_or_else(|err| {
        error!("cannot decode config.json: {}", err);
        // Fall back to the Phi-3-mini-4k-instruct defaults above.
        default_config
    });

    Ok(config)
}

/// Prefer Metal, then CUDA, and fall back to the CPU.
fn get_device() -> Device {
    Device::new_metal(0)
        .or_else(|_| Device::new_cuda(0))
        .unwrap_or(Device::Cpu)
}

#[derive(Debug, Deserialize)]
struct Weightmaps {
    #[serde(deserialize_with = "deserialize_weight_map")]
    weight_map: HashSet<String>,
}
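
// Illustrative only (not from the original snippet): model.safetensors.index.json
// maps tensor names to shard filenames, and `Weightmaps` keeps just the distinct
// filenames. The tensor names and shard names below are made up.
#[cfg(test)]
mod weight_map_tests {
    use super::Weightmaps;

    #[test]
    fn collects_distinct_shard_filenames() {
        let index = r#"{
            "weight_map": {
                "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
                "model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
                "lm_head.weight": "model-00002-of-00002.safetensors"
            }
        }"#;

        let parsed: Weightmaps = serde_json::from_str(index).unwrap();
        assert_eq!(parsed.weight_map.len(), 2);
    }
}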

fn get_repo(token: String) -> anyhow::Result<ApiRepo> {
    let api = ApiBuilder::new().with_token(Some(token)).build()?;

    let model_id = "microsoft/Phi-3-mini-4k-instruct".to_string();

    Ok(api.repo(Repo::with_revision(
        model_id,
        RepoType::Model,
        "ba3e2e891adaf6b9e7471bcc80dec875d73ae4e9".to_string(),
    )))
}

pub fn initialise_model(token: String) -> anyhow::Result<AppState> {
    let repo = get_repo(token)?;
    let tokenizer = get_tokenizer(&repo)?;

    let device = get_device();

    let filenames = hub_load_safe_tensors(&repo, "model.safetensors.index.json")?;

    let config = get_config(&repo)?;

    let model = {
        let dtype = DType::F32;
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        Phi3::new(&config, vb)?
    };

    Ok((model, device, tokenizer, Some(0.7), config).into())
}
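
A minimal usage sketch (not part of the snippet above; HF_TOKEN is an assumed environment variable and AppState is the project's existing state type):

fn main() -> anyhow::Result<()> {
    // A Hugging Face access token is assumed to be provided via HF_TOKEN.
    let token = std::env::var("HF_TOKEN")?;
    let _state = initialise_model(token)?;
    // ...hand the state to the server or generation loop from here...
    Ok(())
}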
@ealmloff added the enhancement (New feature or request) and Kalosm (Related to the Kalosm library) labels Aug 10, 2024
@ealmloff changed the title from "Support HuggingFace Safetensors format" to "Support unquantized language models" Aug 10, 2024