diff options
-rw-r--r-- | src/syscall/mksyscall_windows.go | 79 | ||||
-rw-r--r-- | src/syscall/zsyscall_windows.go | 36 |
2 files changed, 100 insertions, 15 deletions
diff --git a/src/syscall/mksyscall_windows.go b/src/syscall/mksyscall_windows.go index 1cdd6b4d2..316e88d7e 100644 --- a/src/syscall/mksyscall_windows.go +++ b/src/syscall/mksyscall_windows.go @@ -138,8 +138,6 @@ func (p *Param) StringTmpVarCode() string { // TmpVarCode returns source code for temp variable. func (p *Param) TmpVarCode() string { switch { - case p.Type == "string": - return p.StringTmpVarCode() case p.Type == "bool": return p.BoolTmpVarCode() case strings.HasPrefix(p.Type, "[]"): @@ -149,19 +147,26 @@ func (p *Param) TmpVarCode() string { } } +// TmpVarHelperCode returns source code for helper's temp variable. +func (p *Param) TmpVarHelperCode() string { + if p.Type != "string" { + return "" + } + return p.StringTmpVarCode() +} + // SyscallArgList returns source code fragments representing p parameter // in syscall. Slices are translated into 2 syscall parameters: pointer to // the first element and length. func (p *Param) SyscallArgList() []string { + t := p.HelperType() var s string switch { - case p.Type[0] == '*': + case t[0] == '*': s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name) - case p.Type == "string": - s = fmt.Sprintf("unsafe.Pointer(%s)", p.tmpVar()) - case p.Type == "bool": + case t == "bool": s = p.tmpVar() - case strings.HasPrefix(p.Type, "[]"): + case strings.HasPrefix(t, "[]"): return []string{ fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()), fmt.Sprintf("uintptr(len(%s))", p.Name), @@ -177,6 +182,14 @@ func (p *Param) IsError() bool { return p.Name == "err" && p.Type == "error" } +// HelperType returns type of parameter p used in helper function. +func (p *Param) HelperType() string { + if p.Type == "string" { + return p.fn.StrconvType() + } + return p.Type +} + // join concatenates parameters ps into a string with sep separator. // Each parameter is converted into string by applying fn to it // before conversion. @@ -454,6 +467,11 @@ func (f *Fn) ParamList() string { return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ") } +// HelperParamList returns source code for helper function f parameters. +func (f *Fn) HelperParamList() string { + return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ") +} + // ParamPrintList returns source code of trace printing part correspondent // to syscall input parameters. func (f *Fn) ParamPrintList() string { @@ -510,6 +528,19 @@ func (f *Fn) SyscallParamList() string { return strings.Join(a, ", ") } +// HelperCallParamList returns source code of call into function f helper. +func (f *Fn) HelperCallParamList() string { + a := make([]string, 0, len(f.Params)) + for _, p := range f.Params { + s := p.Name + if p.Type == "string" { + s = p.tmpVar() + } + a = append(a, s) + } + return strings.Join(a, ", ") +} + // IsUTF16 is true, if f is W (utf16) function. It is false // for all A (ascii) functions. func (f *Fn) IsUTF16() bool { @@ -533,6 +564,25 @@ func (f *Fn) StrconvType() string { return "*byte" } +// HasStringParam is true, if f has at least one string parameter. +// Otherwise it is false. +func (f *Fn) HasStringParam() bool { + for _, p := range f.Params { + if p.Type == "string" { + return true + } + } + return false +} + +// HelperName returns name of function f helper. +func (f *Fn) HelperName() string { + if !f.HasStringParam() { + return f.Name + } + return "_" + f.Name +} + // Source files and functions. type Source struct { Funcs []*Fn @@ -666,7 +716,7 @@ import "syscall"{{end}} var ( {{template "dlls" .}} {{template "funcnames" .}}) -{{range .Funcs}}{{template "funcbody" .}}{{end}} +{{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}} {{end}} {{/* help functions */}} @@ -677,16 +727,27 @@ var ( {{define "funcnames"}}{{range .Funcs}} proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}") {{end}}{{end}} +{{define "helperbody"}} +func {{.Name}}({{.ParamList}}) {{template "results" .}}{ +{{template "helpertmpvars" .}} return {{.HelperName}}({{.HelperCallParamList}}) +} +{{end}} + {{define "funcbody"}} -func {{.Name}}({{.ParamList}}) {{if .Rets.List}}{{.Rets.List}} {{end}}{ +func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{ {{template "tmpvars" .}} {{template "syscall" .}} {{template "seterror" .}}{{template "printtrace" .}} return } {{end}} +{{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}} {{.TmpVarHelperCode}} +{{end}}{{end}}{{end}} + {{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}} {{.TmpVarCode}} {{end}}{{end}}{{end}} +{{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}} + {{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}} {{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}} diff --git a/src/syscall/zsyscall_windows.go b/src/syscall/zsyscall_windows.go index 1f44750b7..afc28f993 100644 --- a/src/syscall/zsyscall_windows.go +++ b/src/syscall/zsyscall_windows.go @@ -176,7 +176,11 @@ func LoadLibrary(libname string) (handle Handle, err error) { if err != nil { return } - r0, _, e1 := Syscall(procLoadLibraryW.Addr(), 1, uintptr(unsafe.Pointer(_p0)), 0, 0) + return _LoadLibrary(_p0) +} + +func _LoadLibrary(libname *uint16) (handle Handle, err error) { + r0, _, e1 := Syscall(procLoadLibraryW.Addr(), 1, uintptr(unsafe.Pointer(libname)), 0, 0) handle = Handle(r0) if handle == 0 { if e1 != 0 { @@ -206,7 +210,11 @@ func GetProcAddress(module Handle, procname string) (proc uintptr, err error) { if err != nil { return } - r0, _, e1 := Syscall(procGetProcAddress.Addr(), 2, uintptr(module), uintptr(unsafe.Pointer(_p0)), 0) + return _GetProcAddress(module, _p0) +} + +func _GetProcAddress(module Handle, procname *byte) (proc uintptr, err error) { + r0, _, e1 := Syscall(procGetProcAddress.Addr(), 2, uintptr(module), uintptr(unsafe.Pointer(procname)), 0) proc = uintptr(r0) if proc == 0 { if e1 != 0 { @@ -1558,7 +1566,11 @@ func GetHostByName(name string) (h *Hostent, err error) { if err != nil { return } - r0, _, e1 := Syscall(procgethostbyname.Addr(), 1, uintptr(unsafe.Pointer(_p0)), 0, 0) + return _GetHostByName(_p0) +} + +func _GetHostByName(name *byte) (h *Hostent, err error) { + r0, _, e1 := Syscall(procgethostbyname.Addr(), 1, uintptr(unsafe.Pointer(name)), 0, 0) h = (*Hostent)(unsafe.Pointer(r0)) if h == nil { if e1 != 0 { @@ -1581,7 +1593,11 @@ func GetServByName(name string, proto string) (s *Servent, err error) { if err != nil { return } - r0, _, e1 := Syscall(procgetservbyname.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(unsafe.Pointer(_p1)), 0) + return _GetServByName(_p0, _p1) +} + +func _GetServByName(name *byte, proto *byte) (s *Servent, err error) { + r0, _, e1 := Syscall(procgetservbyname.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(proto)), 0) s = (*Servent)(unsafe.Pointer(r0)) if s == nil { if e1 != 0 { @@ -1605,7 +1621,11 @@ func GetProtoByName(name string) (p *Protoent, err error) { if err != nil { return } - r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(_p0)), 0, 0) + return _GetProtoByName(_p0) +} + +func _GetProtoByName(name *byte) (p *Protoent, err error) { + r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(name)), 0, 0) p = (*Protoent)(unsafe.Pointer(r0)) if p == nil { if e1 != 0 { @@ -1623,7 +1643,11 @@ func DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSR if status != nil { return } - r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(_p0)), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr))) + return _DnsQuery(_p0, qtype, options, extra, qrs, pr) +} + +func _DnsQuery(name *uint16, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status error) { + r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(name)), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr))) if r0 != 0 { status = Errno(r0) } |