sssam2021/treesampler

 
 

treesampler

[This is very much a work in progress, and is not ready for production use.]

A re-implementation of PICARD (Parsing Incrementally for Constrained Auto-Regressive Decoding from Language Models) that can be applied to code generation in any language with LSP support.

LSP (the Language Server Protocol) defines a common interface between a language server and a language client. Editors such as VS Code use it to provide code completion, hover information, diagnostics, and other features. Because the protocol is language-agnostic, it works with any language for which a language server is available.
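Under the hood, LSP messages are JSON-RPC 2.0 payloads, each preceded by a `Content-Length` header and a blank line (the "base protocol" in the LSP specification). The sketch below shows that framing in plain Python; the helper names `encode_message` and `read_message` are illustrative only and are not part of treesampler's API.

```python
import io
import json


def encode_message(payload: dict) -> bytes:
    """Frame a JSON-RPC payload as an LSP base-protocol message."""
    body = json.dumps(payload).encode("utf-8")
    # Content-Length counts bytes of the UTF-8 body, not characters.
    header = f"Content-Length: {len(body)}\r\n\r\n".encode("ascii")
    return header + body


def read_message(stream: io.BufferedIOBase) -> dict:
    """Read one framed message from a binary stream (e.g. a server's stdout)."""
    content_length = None
    while True:
        line = stream.readline()
        if not line:
            raise EOFError("stream closed before a complete header was read")
        line = line.strip()
        if not line:  # an empty line terminates the header section
            break
        name, _, value = line.decode("ascii").partition(":")
        if name.strip().lower() == "content-length":
            content_length = int(value.strip())
        # Other headers (e.g. Content-Type) are ignored in this sketch.
    if content_length is None:
        raise ValueError("message has no Content-Length header")
    return json.loads(stream.read(content_length).decode("utf-8"))
```

The same two helpers work for requests, responses and notifications, since they differ only in the JSON fields (`id`, `method`, `result`, `params`).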

Installation

Support for each language is installed as a separate optional extra. For example, to use treesampler for generating Python code, install treesampler[py]:

pip install "treesampler[py] @ git+https://github.com/nfcampos/treesampler.git"

For now only Python is supported, but more languages will be added in the future.

Usage

from transformers import AutoModelForCausalLM, AutoTokenizer
from treesampler import LspDiagnosticsProcessor, with_lsp

def generate(prompt, **kwargs):
    # Keep a Python language server (ruff_lsp) available for the whole generation.
    with with_lsp("python", server_python_module="ruff_lsp") as lsp_client:
        checkpoint = "Salesforce/codegen-350M-mono"
        model = AutoModelForCausalLM.from_pretrained(checkpoint)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        # Rescores candidate tokens using the diagnostics reported by the server.
        processor = LspDiagnosticsProcessor(tokenizer, lsp_client)

        completion = model.generate(
            **tokenizer(prompt, return_tensors="pt"),
            logits_processor=[processor],
            **kwargs,
        )

        # The decoded text contains the prompt followed by the generated code.
        return tokenizer.decode(completion[0], skip_special_tokens=True)

How it works

The idea is to have an LSP server parse the code incrementally as it is being generated and report diagnostics (linter errors and warnings), which are then used to constrain the generation process.
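One way to feed partially generated code to a server, assuming it supports full-document sync, is to send a `textDocument/didChange` notification carrying the whole text after each step, increasing the document version every time; the server then replies asynchronously with a `textDocument/publishDiagnostics` notification. The helpers `did_change` and `worst_severity` below are hypothetical and only illustrate the message shapes from the LSP specification. Treating a diagnostic without a `severity` field as an error is an assumption, since the specification leaves that interpretation to the client.

```python
from typing import Optional

# DiagnosticSeverity values from the LSP specification.
ERROR, WARNING, INFORMATION, HINT = 1, 2, 3, 4


def did_change(uri: str, version: int, text: str) -> dict:
    """Notification replacing the document with the code generated so far.

    Assumes full-document sync: one content change holding only "text".
    LSP also allows incremental, range-based changes.
    """
    return {
        "jsonrpc": "2.0",
        "method": "textDocument/didChange",
        "params": {
            "textDocument": {"uri": uri, "version": version},
            "contentChanges": [{"text": text}],
        },
    }


def worst_severity(publish_params: dict) -> Optional[int]:
    """Most severe level in a publishDiagnostics payload (lower is worse).

    Returns None when the server reports no diagnostics. A missing
    "severity" is treated as ERROR here, which is an assumption.
    """
    severities = [d.get("severity", ERROR) for d in publish_params.get("diagnostics", [])]
    return min(severities, default=None)
```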

For example, if the LSP server reports a syntax error, we can use this information to prevent sampling the tokens that would cause it. If the server instead reports a less severe warning, we can reduce the score of the tokens that would cause the warning, making them less likely without ruling them out.
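The two rules above can be sketched as a rescoring step over candidate next tokens. This is a simplified illustration, not treesampler's scorer: `adjust_scores`, the `check` callback (a stand-in for a round trip to the language server that returns the worst LSP severity for a piece of code), and the fixed `warning_penalty` are all assumptions. Setting a score to negative infinity gives the token zero probability after softmax.

```python
import math
from typing import Callable, Dict, Optional

# LSP DiagnosticSeverity: 1 = Error, 2 = Warning (3 = Information, 4 = Hint).
ERROR, WARNING = 1, 2


def adjust_scores(
    prefix: str,
    candidates: Dict[str, float],
    check: Callable[[str], Optional[int]],
    warning_penalty: float = 2.0,
) -> Dict[str, float]:
    """Rescore candidate next tokens using diagnostics for prefix + token.

    `check` is a hypothetical hook returning the worst severity reported
    for the given code, or None if the server reports nothing.
    """
    adjusted = {}
    for token, score in candidates.items():
        severity = check(prefix + token)
        if severity == ERROR:
            adjusted[token] = -math.inf  # hard constraint: never sample it
        elif severity == WARNING:
            adjusted[token] = score - warning_penalty  # soft constraint
        else:
            adjusted[token] = score  # information, hints, or clean code
    return adjusted
```

In practice a prefix such as `def f(` is incomplete rather than wrong, so a usable scorer has to distinguish errors introduced by a token from errors that only reflect unfinished code.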

How to add support for a new language

  1. Find an existing LSP server for the language. A good place to look is the list of implementations at https://microsoft.github.io/language-server-protocol/implementors/servers/

  2. Write a test that uses it; see tests/test_py.py for an example.

  3. Check the score adjustments produced by the base scorer and, optionally, write a custom scorer for the language.

  4. Contributions of new languages are very welcome; please open a PR!
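For step 1, a language server is usually a program that speaks LSP over stdin/stdout, and every session opens with an `initialize` request. The sketch below shows that starting point under stated assumptions: `SERVER_COMMANDS`, `initialize_request` and `start_server` are hypothetical names rather than treesampler's API, and the `python -m ruff_lsp` command is inferred from the `server_python_module` argument in the Usage example.

```python
import os
import subprocess
import sys
from pathlib import Path

# Hypothetical table mapping a language id to the command that starts its server.
SERVER_COMMANDS = {
    "python": [sys.executable, "-m", "ruff_lsp"],
}


def initialize_request(workspace: Path, request_id: int = 1) -> dict:
    """Build the `initialize` request that must open every LSP session."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "method": "initialize",
        "params": {
            "processId": os.getpid(),
            # as_uri() needs an absolute path, hence resolve().
            "rootUri": workspace.resolve().as_uri(),
            "capabilities": {},  # a minimal client advertises nothing extra
        },
    }


def start_server(language: str) -> subprocess.Popen:
    """Spawn the server for `language`, communicating over stdin/stdout."""
    return subprocess.Popen(
        SERVER_COMMANDS[language],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
    )
```

After the server answers `initialize`, the client is expected to send an `initialized` notification before opening any documents.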
